Skip to main content

rlx_sam/
upscale_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 mask-decoder upscaling subgraph (ConvTranspose2d + LN2d + GELU).
17
18use super::config::SAM_EMBED_HW;
19use super::mask_decoder::MaskDecoderWeights;
20use anyhow::Result;
21use rlx_core::vision_ops_ir::{conv_transpose2d_stride2_k2_bias, layer_norm2d_nchw};
22use rlx_flow::CompileProfile;
23use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
24use rlx_ir::{DType, Graph, HirGraphExt, Shape};
25use rlx_runtime::{CompiledGraph, Device};
26use std::collections::HashMap;
27
/// Compiled upscale stack: `src_nchw` → `up2` after two transposed convs.
pub struct SamMaskUpscaleCompiled {
    /// Compiled graph executed by `run`; its parameters are bound in
    /// `compile_with_profile`.
    graph: CompiledGraph,
    /// Input channel count, copied from `MaskDecoderWeights::transformer_dim`.
    e: usize,
    /// Spatial side length of the input, set from `SAM_EMBED_HW`.
    hw: usize,
}
34
35impl SamMaskUpscaleCompiled {
36    pub fn compile(w: &MaskDecoderWeights, device: Device) -> Result<Self> {
37        Self::compile_with_profile(w, device, &CompileProfile::sam_encoder())
38    }
39
40    pub fn compile_with_profile(
41        w: &MaskDecoderWeights,
42        device: Device,
43        profile: &CompileProfile,
44    ) -> Result<Self> {
45        let (graph, params) = build_mask_upscale_graph(w)?;
46        let mut compiled =
47            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
48        for (name, data) in &params {
49            compiled.set_param(name, data);
50        }
51        Ok(Self {
52            graph: compiled,
53            e: w.transformer_dim,
54            hw: SAM_EMBED_HW,
55        })
56    }
57
58    /// `src_nchw` is `[E, hw, hw]` NCHW (same layout as mask decoder).
59    /// Returns `up2` `[E/8, 4·hw, 4·hw]`.
60    pub fn run(&mut self, src_nchw: &[f32]) -> Result<Vec<f32>> {
61        let e = self.e;
62        let hw = self.hw;
63        anyhow::ensure!(
64            src_nchw.len() == e * hw * hw,
65            "src_nchw len {} ≠ E·hw·hw",
66            src_nchw.len()
67        );
68        let outs = self.graph.run(&[("src", src_nchw)]);
69        Ok(outs.into_iter().next().expect("upscale output"))
70    }
71}
72
73pub fn build_mask_upscale_hir(
74    w: &MaskDecoderWeights,
75) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
76    let e = w.transformer_dim;
77    let hw = SAM_EMBED_HW;
78    let q4 = e / 4;
79    let q8 = e / 8;
80    let eps = 1e-6f32;
81    let f = DType::F32;
82
83    let mut hir = HirModule::new("sam_mask_upscale");
84    let mut params = HashMap::new();
85    let mut g = HirMut::new(&mut hir);
86
87    let src = g.input("src", Shape::new(&[1, e, hw, hw], f));
88
89    let up1_w = param(
90        &mut g,
91        &mut params,
92        "upscale_conv1_w",
93        w.upscale_conv1_w.clone(),
94        &[e, q4, 2, 2],
95    );
96    let up1_b = param(
97        &mut g,
98        &mut params,
99        "upscale_conv1_b",
100        w.upscale_conv1_b.clone(),
101        &[q4],
102    );
103    let mut up1 = conv_transpose2d_stride2_k2_bias(&mut g, src, up1_w, up1_b, 1, q4, hw, hw);
104
105    let ln_g = param(
106        &mut g,
107        &mut params,
108        "upscale_ln_g",
109        w.upscale_ln_g.clone(),
110        &[q4],
111    );
112    let ln_b = param(
113        &mut g,
114        &mut params,
115        "upscale_ln_b",
116        w.upscale_ln_b.clone(),
117        &[q4],
118    );
119    up1 = layer_norm2d_nchw(&mut g, up1, ln_g, ln_b, eps);
120    up1 = g.gelu(up1);
121
122    let h1 = hw * 2;
123    let up2_w = param(
124        &mut g,
125        &mut params,
126        "upscale_conv2_w",
127        w.upscale_conv2_w.clone(),
128        &[q4, q8, 2, 2],
129    );
130    let up2_b = param(
131        &mut g,
132        &mut params,
133        "upscale_conv2_b",
134        w.upscale_conv2_b.clone(),
135        &[q8],
136    );
137    let up2 = conv_transpose2d_stride2_k2_bias(&mut g, up1, up2_w, up2_b, 1, q8, h1, h1);
138    let up2 = g.gelu(up2);
139
140    hir.set_outputs(vec![up2]);
141    Ok((hir, params))
142}
143
144pub fn build_mask_upscale_graph(
145    w: &MaskDecoderWeights,
146) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
147    let (hir, params) = build_mask_upscale_hir(w)?;
148    Graph::from_hir(hir)
149        .map_err(|e| anyhow::anyhow!("{e}"))
150        .map(|g| (g, params))
151}
152
153fn param(
154    g: &mut HirMut<'_>,
155    params: &mut HashMap<String, Vec<f32>>,
156    name: &str,
157    data: Vec<f32>,
158    shape: &[usize],
159) -> HirNodeId {
160    let id = g.param(name, Shape::new(shape, DType::F32));
161    params.insert(name.to_string(), data);
162    id
163}