Skip to main content

rlx_sam2/
upscale_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM2 mask-decoder upscaling (ConvTranspose2d + LN2d + optional high-res 1×1 fuse).
17
18use super::mask_decoder::Sam2MaskDecoderWeights;
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{conv_transpose2d_stride2_k2_bias, conv2d_bias, layer_norm2d_nchw};
21use rlx_flow::CompileProfile;
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::{DType, Graph, HirGraphExt, Shape};
24use rlx_runtime::{CompiledGraph, Device};
25use std::collections::HashMap;
26
27pub struct Sam2MaskUpscaleCompiled {
28    graph: CompiledGraph,
29    e: usize,
30    use_high_res: bool,
31}
32
33impl Sam2MaskUpscaleCompiled {
34    pub fn compile(w: &Sam2MaskDecoderWeights, grid: usize, device: Device) -> Result<Self> {
35        Self::compile_with_profile(w, grid, device, &CompileProfile::sam_encoder())
36    }
37
38    pub fn compile_with_profile(
39        w: &Sam2MaskDecoderWeights,
40        grid: usize,
41        device: Device,
42        profile: &CompileProfile,
43    ) -> Result<Self> {
44        let (graph, params) = build_mask_upscale_graph(w, grid)?;
45        let mut compiled =
46            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
47        for (name, data) in &params {
48            compiled.set_param(name, data);
49        }
50        Ok(Self {
51            graph: compiled,
52            e: w.transformer_dim,
53            use_high_res: w.use_high_res_features,
54        })
55    }
56
57    /// `src_nchw` `[E, g, g]`. When `use_high_res`, pass `feat_s1` `[E, 2g, 2g]`
58    /// and `feat_s0` `[E, 4g, 4g]`; otherwise pass empty slices.
59    pub fn run(
60        &mut self,
61        src_nchw: &[f32],
62        feat_s1: &[f32],
63        feat_s0: &[f32],
64        grid: usize,
65    ) -> Result<Vec<f32>> {
66        let e = self.e;
67        let g = grid;
68        anyhow::ensure!(src_nchw.len() == e * g * g);
69        let mut inputs = vec![("src", src_nchw)];
70        let s1_buf;
71        let s0_buf;
72        if self.use_high_res {
73            let h1 = g * 2;
74            let h2 = g * 4;
75            anyhow::ensure!(feat_s1.len() == e * h1 * h1 && feat_s0.len() == e * h2 * h2);
76            s1_buf = feat_s1;
77            s0_buf = feat_s0;
78            inputs.push(("feat_s1", s1_buf));
79            inputs.push(("feat_s0", s0_buf));
80        }
81        let outs = self
82            .graph
83            .run(&inputs.iter().map(|(n, d)| (*n, *d)).collect::<Vec<_>>());
84        Ok(outs.into_iter().next().expect("sam2 upscale output"))
85    }
86}
87
88pub fn build_mask_upscale_graph(
89    w: &Sam2MaskDecoderWeights,
90    grid: usize,
91) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
92    let e = w.transformer_dim;
93    let g = grid;
94    let q4 = e / 4;
95    let q8 = e / 8;
96    let eps = 1e-6f32;
97    let f = DType::F32;
98
99    let mut hir = HirModule::new("sam2_mask_upscale");
100    let mut params = HashMap::new();
101    let mut hg = HirMut::new(&mut hir);
102
103    let src = hg.input("src", Shape::new(&[1, e, g, g], f));
104
105    let up1_w = p(
106        &mut hg,
107        &mut params,
108        "upscale_conv1_w",
109        w.upscale_conv1_w.clone(),
110        &[e, q4, 2, 2],
111    );
112    let up1_b = p(
113        &mut hg,
114        &mut params,
115        "upscale_conv1_b",
116        w.upscale_conv1_b.clone(),
117        &[q4],
118    );
119    let mut up1 = conv_transpose2d_stride2_k2_bias(&mut hg, src, up1_w, up1_b, 1, q4, g, g);
120
121    if w.use_high_res_features {
122        let h1 = g * 2;
123        let feat_s1 = hg.input("feat_s1", Shape::new(&[1, e, h1, h1], f));
124        let s1_w = p(
125            &mut hg,
126            &mut params,
127            "conv_s1_w",
128            w.conv_s1_w.clone().unwrap(),
129            &[q4, e, 1, 1],
130        );
131        let s1_b = p(
132            &mut hg,
133            &mut params,
134            "conv_s1_b",
135            w.conv_s1_b.clone().unwrap(),
136            &[q4],
137        );
138        let s1_proj = conv2d_bias(
139            &mut hg,
140            feat_s1,
141            s1_w,
142            s1_b,
143            1,
144            q4,
145            1,
146            1,
147            [1, 1],
148            [0, 0],
149            h1,
150            h1,
151        );
152        up1 = hg.add(up1, s1_proj);
153    }
154
155    let ln_g = p(
156        &mut hg,
157        &mut params,
158        "upscale_ln_g",
159        w.upscale_ln_g.clone(),
160        &[q4],
161    );
162    let ln_b = p(
163        &mut hg,
164        &mut params,
165        "upscale_ln_b",
166        w.upscale_ln_b.clone(),
167        &[q4],
168    );
169    up1 = layer_norm2d_nchw(&mut hg, up1, ln_g, ln_b, eps);
170    up1 = hg.gelu(up1);
171
172    let h1 = g * 2;
173    let up2_w = p(
174        &mut hg,
175        &mut params,
176        "upscale_conv2_w",
177        w.upscale_conv2_w.clone(),
178        &[q4, q8, 2, 2],
179    );
180    let up2_b = p(
181        &mut hg,
182        &mut params,
183        "upscale_conv2_b",
184        w.upscale_conv2_b.clone(),
185        &[q8],
186    );
187    let mut up2 = conv_transpose2d_stride2_k2_bias(&mut hg, up1, up2_w, up2_b, 1, q8, h1, h1);
188
189    if w.use_high_res_features {
190        let h2 = g * 4;
191        let feat_s0 = hg.input("feat_s0", Shape::new(&[1, e, h2, h2], f));
192        let s0_w = p(
193            &mut hg,
194            &mut params,
195            "conv_s0_w",
196            w.conv_s0_w.clone().unwrap(),
197            &[q8, e, 1, 1],
198        );
199        let s0_b = p(
200            &mut hg,
201            &mut params,
202            "conv_s0_b",
203            w.conv_s0_b.clone().unwrap(),
204            &[q8],
205        );
206        let s0_proj = conv2d_bias(
207            &mut hg,
208            feat_s0,
209            s0_w,
210            s0_b,
211            1,
212            q8,
213            1,
214            1,
215            [1, 1],
216            [0, 0],
217            h2,
218            h2,
219        );
220        up2 = hg.add(up2, s0_proj);
221    }
222
223    let up2 = hg.gelu(up2);
224    hir.set_outputs(vec![up2]);
225    Graph::from_hir(hir)
226        .map_err(|e| anyhow::anyhow!("{e}"))
227        .map(|g| (g, params))
228}
229
230fn p(
231    g: &mut HirMut<'_>,
232    params: &mut HashMap<String, Vec<f32>>,
233    name: &str,
234    data: Vec<f32>,
235    shape: &[usize],
236) -> HirNodeId {
237    let id = g.param(name, Shape::new(shape, DType::F32));
238    params.insert(name.to_string(), data);
239    id
240}