Skip to main content

rlx_sam3/
detector.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 detector scaffolding.
17
18use super::config::Sam3DetectorConfig;
19use super::geometry::Sam3GeometryFeatures;
20use super::neck::Sam3FeatureLevel;
21use super::text_encoder::Sam3TextEncoded;
22use anyhow::{Result, ensure};
23
/// Weight container for the native SAM3 detector.
///
/// This is scaffolding: the only state is a single flag, and
/// `detector_forward_native` in this file accepts the struct without reading it.
#[derive(Clone, Debug, Default)]
pub struct Sam3DetectorWeights {
    /// Presumably records whether real weights have been loaded; defaults to
    /// `false`. NOTE(review): nothing in this file reads it — confirm against
    /// the loader.
    pub loaded: bool,
}
28
/// Result of one detector forward pass.
///
/// Field descriptions below reflect what `detector_forward_native` in this
/// file writes; other producers (if any) are not visible here.
#[derive(Clone, Debug)]
pub struct Sam3DetectorOutput {
    /// Flat buffer of `num_queries * dim` values, one `dim`-long vector per
    /// query, stored back to back.
    pub query_features: Vec<f32>,
    /// Flat box buffer. The scaffolding forward pass emits a single 4-value
    /// box `[0.0, 0.0, 1.0, 1.0]`; the coordinate convention is not
    /// established in this file — TODO confirm.
    pub boxes: Vec<f32>,
    /// Score buffer. The scaffolding forward pass emits a single value.
    pub scores: Vec<f32>,
    /// Number of query vectors stored in `query_features`.
    pub num_queries: usize,
    /// Length of each query vector (the detector's `d_model`).
    pub dim: usize,
}
37
38pub fn detector_forward_native(
39    _weights: &Sam3DetectorWeights,
40    cfg: &Sam3DetectorConfig,
41    levels: &[Sam3FeatureLevel],
42    text: &Sam3TextEncoded,
43    geometry: &Sam3GeometryFeatures,
44) -> Result<Sam3DetectorOutput> {
45    ensure!(
46        !levels.is_empty(),
47        "SAM3 detector needs at least one feature level"
48    );
49    let level = &levels[0];
50    ensure!(
51        level.channels == cfg.d_model,
52        "SAM3 detector feature dim mismatch"
53    );
54    let mut pooled = vec![0.0; cfg.d_model];
55    let rows = level.h * level.w;
56    for r in 0..rows {
57        for c in 0..cfg.d_model {
58            pooled[c] += level.features[r * cfg.d_model + c] / rows as f32;
59        }
60    }
61    if !text.text_memory_resized.is_empty() {
62        for c in 0..cfg.d_model {
63            pooled[c] += text.text_memory_resized[c] * 0.01;
64        }
65    }
66    for c in 0..cfg.d_model {
67        pooled[c] += geometry.features[c] * 0.001;
68    }
69    let mut query_features = vec![0.0; cfg.num_queries * cfg.d_model];
70    for q in 0..cfg.num_queries {
71        query_features[q * cfg.d_model..(q + 1) * cfg.d_model].copy_from_slice(&pooled);
72    }
73    Ok(Sam3DetectorOutput {
74        query_features,
75        boxes: vec![0.0, 0.0, 1.0, 1.0],
76        scores: vec![pooled.iter().copied().sum::<f32>() / cfg.d_model as f32],
77        num_queries: cfg.num_queries,
78        dim: cfg.d_model,
79    })
80}