1use super::mask_decoder::Sam2MaskDecoderWeights;
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{conv_transpose2d_stride2_k2_bias, conv2d_bias, layer_norm2d_nchw};
21use rlx_flow::CompileProfile;
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::{DType, Graph, HirGraphExt, Shape};
24use rlx_runtime::{CompiledGraph, Device};
25use std::collections::HashMap;
26
27pub struct Sam2MaskUpscaleCompiled {
28 graph: CompiledGraph,
29 e: usize,
30 use_high_res: bool,
31}
32
33impl Sam2MaskUpscaleCompiled {
34 pub fn compile(w: &Sam2MaskDecoderWeights, grid: usize, device: Device) -> Result<Self> {
35 Self::compile_with_profile(w, grid, device, &CompileProfile::sam_encoder())
36 }
37
38 pub fn compile_with_profile(
39 w: &Sam2MaskDecoderWeights,
40 grid: usize,
41 device: Device,
42 profile: &CompileProfile,
43 ) -> Result<Self> {
44 let (graph, params) = build_mask_upscale_graph(w, grid)?;
45 let mut compiled =
46 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
47 for (name, data) in ¶ms {
48 compiled.set_param(name, data);
49 }
50 Ok(Self {
51 graph: compiled,
52 e: w.transformer_dim,
53 use_high_res: w.use_high_res_features,
54 })
55 }
56
57 pub fn run(
60 &mut self,
61 src_nchw: &[f32],
62 feat_s1: &[f32],
63 feat_s0: &[f32],
64 grid: usize,
65 ) -> Result<Vec<f32>> {
66 let e = self.e;
67 let g = grid;
68 anyhow::ensure!(src_nchw.len() == e * g * g);
69 let mut inputs = vec![("src", src_nchw)];
70 let s1_buf;
71 let s0_buf;
72 if self.use_high_res {
73 let h1 = g * 2;
74 let h2 = g * 4;
75 anyhow::ensure!(feat_s1.len() == e * h1 * h1 && feat_s0.len() == e * h2 * h2);
76 s1_buf = feat_s1;
77 s0_buf = feat_s0;
78 inputs.push(("feat_s1", s1_buf));
79 inputs.push(("feat_s0", s0_buf));
80 }
81 let outs = self
82 .graph
83 .run(&inputs.iter().map(|(n, d)| (*n, *d)).collect::<Vec<_>>());
84 Ok(outs.into_iter().next().expect("sam2 upscale output"))
85 }
86}
87
88pub fn build_mask_upscale_graph(
89 w: &Sam2MaskDecoderWeights,
90 grid: usize,
91) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
92 let e = w.transformer_dim;
93 let g = grid;
94 let q4 = e / 4;
95 let q8 = e / 8;
96 let eps = 1e-6f32;
97 let f = DType::F32;
98
99 let mut hir = HirModule::new("sam2_mask_upscale");
100 let mut params = HashMap::new();
101 let mut hg = HirMut::new(&mut hir);
102
103 let src = hg.input("src", Shape::new(&[1, e, g, g], f));
104
105 let up1_w = p(
106 &mut hg,
107 &mut params,
108 "upscale_conv1_w",
109 w.upscale_conv1_w.clone(),
110 &[e, q4, 2, 2],
111 );
112 let up1_b = p(
113 &mut hg,
114 &mut params,
115 "upscale_conv1_b",
116 w.upscale_conv1_b.clone(),
117 &[q4],
118 );
119 let mut up1 = conv_transpose2d_stride2_k2_bias(&mut hg, src, up1_w, up1_b, 1, q4, g, g);
120
121 if w.use_high_res_features {
122 let h1 = g * 2;
123 let feat_s1 = hg.input("feat_s1", Shape::new(&[1, e, h1, h1], f));
124 let s1_w = p(
125 &mut hg,
126 &mut params,
127 "conv_s1_w",
128 w.conv_s1_w.clone().unwrap(),
129 &[q4, e, 1, 1],
130 );
131 let s1_b = p(
132 &mut hg,
133 &mut params,
134 "conv_s1_b",
135 w.conv_s1_b.clone().unwrap(),
136 &[q4],
137 );
138 let s1_proj = conv2d_bias(
139 &mut hg,
140 feat_s1,
141 s1_w,
142 s1_b,
143 1,
144 q4,
145 1,
146 1,
147 [1, 1],
148 [0, 0],
149 h1,
150 h1,
151 );
152 up1 = hg.add(up1, s1_proj);
153 }
154
155 let ln_g = p(
156 &mut hg,
157 &mut params,
158 "upscale_ln_g",
159 w.upscale_ln_g.clone(),
160 &[q4],
161 );
162 let ln_b = p(
163 &mut hg,
164 &mut params,
165 "upscale_ln_b",
166 w.upscale_ln_b.clone(),
167 &[q4],
168 );
169 up1 = layer_norm2d_nchw(&mut hg, up1, ln_g, ln_b, eps);
170 up1 = hg.gelu(up1);
171
172 let h1 = g * 2;
173 let up2_w = p(
174 &mut hg,
175 &mut params,
176 "upscale_conv2_w",
177 w.upscale_conv2_w.clone(),
178 &[q4, q8, 2, 2],
179 );
180 let up2_b = p(
181 &mut hg,
182 &mut params,
183 "upscale_conv2_b",
184 w.upscale_conv2_b.clone(),
185 &[q8],
186 );
187 let mut up2 = conv_transpose2d_stride2_k2_bias(&mut hg, up1, up2_w, up2_b, 1, q8, h1, h1);
188
189 if w.use_high_res_features {
190 let h2 = g * 4;
191 let feat_s0 = hg.input("feat_s0", Shape::new(&[1, e, h2, h2], f));
192 let s0_w = p(
193 &mut hg,
194 &mut params,
195 "conv_s0_w",
196 w.conv_s0_w.clone().unwrap(),
197 &[q8, e, 1, 1],
198 );
199 let s0_b = p(
200 &mut hg,
201 &mut params,
202 "conv_s0_b",
203 w.conv_s0_b.clone().unwrap(),
204 &[q8],
205 );
206 let s0_proj = conv2d_bias(
207 &mut hg,
208 feat_s0,
209 s0_w,
210 s0_b,
211 1,
212 q8,
213 1,
214 1,
215 [1, 1],
216 [0, 0],
217 h2,
218 h2,
219 );
220 up2 = hg.add(up2, s0_proj);
221 }
222
223 let up2 = hg.gelu(up2);
224 hir.set_outputs(vec![up2]);
225 Graph::from_hir(hir)
226 .map_err(|e| anyhow::anyhow!("{e}"))
227 .map(|g| (g, params))
228}
229
230fn p(
231 g: &mut HirMut<'_>,
232 params: &mut HashMap<String, Vec<f32>>,
233 name: &str,
234 data: Vec<f32>,
235 shape: &[usize],
236) -> HirNodeId {
237 let id = g.param(name, Shape::new(shape, DType::F32));
238 params.insert(name.to_string(), data);
239 id
240}