Skip to main content

rlx_sam3/
geometry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 geometry prompt scaffolding.
17
18use super::config::SAM3_DET_DIM;
19
/// Parameters for the native SAM3 geometry-prompt encoder.
///
/// Placeholder for now: it only records whether weights were loaded, and
/// `encode_geometry_native` does not read it yet.
#[derive(Default, Clone, Debug)]
pub struct Sam3GeometryWeights {
    /// `true` once geometry-encoder weights have been loaded.
    pub loaded: bool,
}
24
/// Encoded geometry-prompt tokens stored as one flat, row-major buffer.
///
/// Row `t` occupies `features[t * dim..(t + 1) * dim]`. Values built by
/// `encode_geometry_native` satisfy `features.len() == tokens * dim`.
/// Values constructed by hand should keep that invariant too.
#[derive(Debug, Clone, PartialEq)]
pub struct Sam3GeometryFeatures {
    /// Flattened `tokens x dim` matrix, row-major.
    pub features: Vec<f32>,
    /// Number of token rows in `features`.
    pub tokens: usize,
    /// Width of each token row.
    pub dim: usize,
}
31
32pub fn encode_geometry_native(
33    _weights: &Sam3GeometryWeights,
34    boxes: Option<&[f32]>,
35    points: Option<(&[f32], &[f32])>,
36) -> Sam3GeometryFeatures {
37    let box_tokens = boxes.map(|b| b.len() / 4).unwrap_or(0);
38    let point_tokens = points.map(|(p, _)| p.len() / 2).unwrap_or(0);
39    let tokens = (box_tokens + point_tokens).max(1);
40    let mut features = vec![0.0; tokens * SAM3_DET_DIM];
41    if let Some(b) = boxes {
42        for (i, chunk) in b.chunks_exact(4).enumerate() {
43            for (j, v) in chunk.iter().enumerate() {
44                features[i * SAM3_DET_DIM + j] = *v;
45            }
46        }
47    }
48    if let Some((coords, labels)) = points {
49        let base = box_tokens;
50        for (i, xy) in coords.chunks_exact(2).enumerate() {
51            let row = base + i;
52            if row >= tokens {
53                break;
54            }
55            features[row * SAM3_DET_DIM] = xy[0];
56            features[row * SAM3_DET_DIM + 1] = xy[1];
57            features[row * SAM3_DET_DIM + 2] = labels.get(i).copied().unwrap_or(0.0);
58        }
59    }
60    Sam3GeometryFeatures {
61        features,
62        tokens,
63        dim: SAM3_DET_DIM,
64    }
65}