Skip to main content

rlx_sam3/
segmentation_pixel_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM3 pixel-decoder fusion steps + instance/semantic 1×1 heads (IR).
17
18use crate::gguf_ir::{linear_gguf_bias, packed_linear_for_key};
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{conv2d_bias, nchw_shape};
21use rlx_flow::{CompileProfile, GgufPackedParams};
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::{DType, Graph, HirGraphExt, Shape};
24use rlx_runtime::{CompiledGraph, Device};
25use std::collections::HashMap;
26
/// Channel count of every pixel-decoder activation (inputs, 3×3 conv, GroupNorm).
const D_MODEL: usize = 256;
/// Number of GroupNorm groups applied after each pixel-decoder conv.
const GN_GROUPS: usize = 8;
29
30type ConvGraphParts = (
31    Graph,
32    HashMap<String, Vec<f32>>,
33    Vec<(String, Vec<u8>, DType)>,
34);
35
36/// One pixel-decoder layer: upsample `prev` 2×, add `curr`, conv3×3, GN, ReLU.
37pub struct Sam3PixelDecoderStepCompiled {
38    graph: CompiledGraph,
39    pub out_h: usize,
40    pub out_w: usize,
41}
42
43impl Sam3PixelDecoderStepCompiled {
44    pub fn compile(
45        prev_h: usize,
46        prev_w: usize,
47        out_h: usize,
48        out_w: usize,
49        conv_w: &[f32],
50        conv_b: &[f32],
51        gn_w: &[f32],
52        gn_b: &[f32],
53        device: Device,
54    ) -> Result<Self> {
55        Self::compile_with_profile(
56            prev_h,
57            prev_w,
58            out_h,
59            out_w,
60            conv_w,
61            conv_b,
62            gn_w,
63            gn_b,
64            device,
65            &CompileProfile::sam3(),
66        )
67    }
68
69    pub fn compile_with_profile(
70        prev_h: usize,
71        prev_w: usize,
72        out_h: usize,
73        out_w: usize,
74        conv_w: &[f32],
75        conv_b: &[f32],
76        gn_w: &[f32],
77        gn_b: &[f32],
78        device: Device,
79        profile: &CompileProfile,
80    ) -> Result<Self> {
81        anyhow::ensure!(
82            out_h == prev_h * 2 && out_w == prev_w * 2,
83            "pixel_decoder step expects 2× upsample {prev_h}×{prev_w} → {out_h}×{out_w}"
84        );
85        let (graph, params) =
86            build_pixel_step_graph(prev_h, prev_w, out_h, out_w, conv_w, conv_b, gn_w, gn_b)?;
87        let mut compiled =
88            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
89        for (name, data) in &params {
90            compiled.set_param(name, data);
91        }
92        Ok(Self {
93            graph: compiled,
94            out_h,
95            out_w,
96        })
97    }
98
99    pub fn run(&mut self, prev: &[f32], curr: &[f32]) -> Result<Vec<f32>> {
100        let n = D_MODEL * self.out_h * self.out_w;
101        anyhow::ensure!(prev.len() == n / 4 && curr.len() == n);
102        let outs = self.graph.run(&[("prev", prev), ("curr", curr)]);
103        Ok(outs.into_iter().next().expect("pixel_decoder step output"))
104    }
105}
106
107pub struct Sam3Conv1x1Compiled {
108    graph: CompiledGraph,
109    pub out_c: usize,
110    pub h: usize,
111    pub w: usize,
112}
113
114impl Sam3Conv1x1Compiled {
115    pub fn compile(
116        in_c: usize,
117        out_c: usize,
118        h: usize,
119        w: usize,
120        weight: &[f32],
121        bias: &[f32],
122        device: Device,
123    ) -> Result<Self> {
124        Self::compile_with_profile(
125            in_c,
126            out_c,
127            h,
128            w,
129            weight,
130            bias,
131            device,
132            &CompileProfile::sam3(),
133        )
134    }
135
136    pub fn compile_with_profile(
137        in_c: usize,
138        out_c: usize,
139        h: usize,
140        w: usize,
141        weight: &[f32],
142        bias: &[f32],
143        device: Device,
144        profile: &CompileProfile,
145    ) -> Result<Self> {
146        let (graph, params, typed) = build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?;
147        let mut compiled =
148            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
149        for (name, data) in &params {
150            compiled.set_param(name, data);
151        }
152        rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
153        Ok(Self {
154            graph: compiled,
155            out_c,
156            h,
157            w,
158        })
159    }
160
161    /// F32 conv when materialized, or `DequantMatMul` when `gguf_key` is set.
162    pub fn compile_with_gguf(
163        in_c: usize,
164        out_c: usize,
165        h: usize,
166        w: usize,
167        weight: &[f32],
168        bias: &[f32],
169        gguf_key: Option<&str>,
170        gguf_packed: &GgufPackedParams,
171        device: Device,
172        profile: &CompileProfile,
173    ) -> Result<Self> {
174        let (graph, params, typed) = if let (Some(key), Some(p)) = (
175            gguf_key,
176            gguf_key.and_then(|k| packed_linear_for_key(Some(gguf_packed), k)),
177        ) {
178            build_conv1x1_graph_gguf(in_c, out_c, h, w, p, bias, key)?
179        } else {
180            anyhow::ensure!(
181                !weight.is_empty(),
182                "conv1x1: missing F32 weights and GGUF key"
183            );
184            build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?
185        };
186        let mut compiled =
187            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
188        for (name, data) in &params {
189            compiled.set_param(name, data);
190        }
191        rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
192        Ok(Self {
193            graph: compiled,
194            out_c,
195            h,
196            w,
197        })
198    }
199
200    pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
201        anyhow::ensure!(x.len() == D_MODEL * self.h * self.w);
202        let outs = self.graph.run(&[("x", x)]);
203        Ok(outs.into_iter().next().expect("conv1x1 output"))
204    }
205}
206
207fn build_pixel_step_graph(
208    prev_h: usize,
209    prev_w: usize,
210    out_h: usize,
211    out_w: usize,
212    conv_w: &[f32],
213    conv_b: &[f32],
214    gn_w: &[f32],
215    gn_b: &[f32],
216) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
217    let f = DType::F32;
218    let mut hir = HirModule::new("sam3_pixel_decoder_step");
219    let mut params = HashMap::new();
220    let mut g = HirMut::new(&mut hir);
221
222    let prev = g.input("prev", Shape::new(&[1, D_MODEL, prev_h, prev_w], f));
223    let curr = g.input("curr", Shape::new(&[1, D_MODEL, out_h, out_w], f));
224
225    let up = g.resize_nearest_2x(prev);
226    let combined = g.add(curr, up);
227
228    let cw = param_f32(
229        &mut g,
230        &mut params,
231        "conv_w",
232        conv_w,
233        &[D_MODEL, D_MODEL, 3, 3],
234    );
235    let cb = param_f32(&mut g, &mut params, "conv_b", conv_b, &[D_MODEL]);
236    let mut y = conv2d_bias(
237        &mut g,
238        combined,
239        cw,
240        cb,
241        1,
242        D_MODEL,
243        3,
244        3,
245        [1, 1],
246        [1, 1],
247        out_h,
248        out_w,
249    );
250
251    let gnw = param_f32(&mut g, &mut params, "gn_w", gn_w, &[D_MODEL]);
252    let gnb = param_f32(&mut g, &mut params, "gn_b", gn_b, &[D_MODEL]);
253    y = g.group_norm(y, gnw, gnb, GN_GROUPS, 1e-5);
254    let out = g.relu(y);
255
256    hir.set_outputs(vec![out]);
257    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
258    Ok((graph, params))
259}
260
261fn build_conv1x1_graph(
262    in_c: usize,
263    out_c: usize,
264    h: usize,
265    w: usize,
266    weight: &[f32],
267    bias: &[f32],
268) -> Result<ConvGraphParts> {
269    let f = DType::F32;
270    let mut hir = HirModule::new("sam3_conv1x1");
271    let mut params = HashMap::new();
272    let mut g = HirMut::new(&mut hir);
273
274    let x = g.input("x", nchw_shape(1, in_c, h, w, f));
275    let wt = param_f32(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
276    let bt = param_f32(&mut g, &mut params, "b", bias, &[out_c]);
277    let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
278
279    hir.set_outputs(vec![y]);
280    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
281    Ok((graph, params, Vec::new()))
282}
283
284fn build_conv1x1_graph_gguf(
285    in_c: usize,
286    out_c: usize,
287    h: usize,
288    w: usize,
289    p: &rlx_flow::GgufPackedLinear,
290    bias: &[f32],
291    gguf_key: &str,
292) -> Result<ConvGraphParts> {
293    let f = DType::F32;
294    let mut hir = HirModule::new("sam3_conv1x1_gguf");
295    let mut params = HashMap::new();
296    let mut typed = Vec::new();
297    let mut gguf_cache = HashMap::new();
298    let mut g = HirMut::new(&mut hir);
299
300    let x = g.input("x", nchw_shape(1, in_c, h, w, f));
301    let spatial = (h * w) as i64;
302    let flat = g.reshape_(x, vec![1, spatial, in_c as i64]);
303    let stem = gguf_key.strip_suffix(".weight").unwrap_or(gguf_key);
304    let y_flat = linear_gguf_bias(
305        &mut g,
306        &mut params,
307        &mut typed,
308        &mut gguf_cache,
309        stem,
310        p,
311        flat,
312        bias,
313        in_c,
314        out_c,
315    )?;
316    let y = g.reshape_(y_flat, vec![1, out_c as i64, h as i64, w as i64]);
317
318    hir.set_outputs(vec![y]);
319    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
320    Ok((graph, params, typed))
321}
322
323fn param_f32(
324    g: &mut HirMut<'_>,
325    params: &mut HashMap<String, Vec<f32>>,
326    name: &str,
327    data: &[f32],
328    shape: &[usize],
329) -> HirNodeId {
330    let id = g.param(name, Shape::new(shape, DType::F32));
331    params.insert(name.to_string(), data.to_vec());
332    id
333}
334
335/// Compile pixel-decoder steps for SAM3 base (neck scales 4× / 2× on 72×72 trunk).
336pub fn compile_pixel_decoder_steps(
337    pixel_conv_w: &[Vec<f32>],
338    pixel_conv_b: &[Vec<f32>],
339    pixel_gn_w: &[Vec<f32>],
340    pixel_gn_b: &[Vec<f32>],
341    trunk_grid: usize,
342    device: Device,
343    profile: &CompileProfile,
344) -> Result<Vec<Sam3PixelDecoderStepCompiled>> {
345    // After scalp=1: FPN levels are 288×288, 144×144, 72×72 (fine→coarse indices 0,1,2).
346    // pop finest 72; fuse 144 then 288.
347    let g0 = trunk_grid;
348    let g1 = trunk_grid * 2;
349    let g2 = trunk_grid * 4;
350    let steps = [(g0, g0, g1, g1, 0usize), (g1, g1, g2, g2, 1usize)];
351    steps
352        .iter()
353        .map(|(ph, pw, oh, ow, i)| {
354            Sam3PixelDecoderStepCompiled::compile_with_profile(
355                *ph,
356                *pw,
357                *oh,
358                *ow,
359                &pixel_conv_w[*i],
360                &pixel_conv_b[*i],
361                &pixel_gn_w[*i],
362                &pixel_gn_b[*i],
363                device,
364                profile,
365            )
366        })
367        .collect()
368}