1use crate::gguf_ir::{linear_gguf_bias, packed_linear_for_key};
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{conv2d_bias, nchw_shape};
21use rlx_flow::{CompileProfile, GgufPackedParams};
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::{DType, Graph, HirGraphExt, Shape};
24use rlx_runtime::{CompiledGraph, Device};
25use std::collections::HashMap;
26
/// Channel count used for every pixel-decoder feature map and for the 3×3
/// convolution / group-norm parameters built in this file.
const D_MODEL: usize = 256;
/// Number of groups passed to `group_norm` in the pixel-decoder step.
const GN_GROUPS: usize = 8;
29
30type ConvGraphParts = (
31 Graph,
32 HashMap<String, Vec<f32>>,
33 Vec<(String, Vec<u8>, DType)>,
34);
35
36pub struct Sam3PixelDecoderStepCompiled {
38 graph: CompiledGraph,
39 pub out_h: usize,
40 pub out_w: usize,
41}
42
43impl Sam3PixelDecoderStepCompiled {
44 pub fn compile(
45 prev_h: usize,
46 prev_w: usize,
47 out_h: usize,
48 out_w: usize,
49 conv_w: &[f32],
50 conv_b: &[f32],
51 gn_w: &[f32],
52 gn_b: &[f32],
53 device: Device,
54 ) -> Result<Self> {
55 Self::compile_with_profile(
56 prev_h,
57 prev_w,
58 out_h,
59 out_w,
60 conv_w,
61 conv_b,
62 gn_w,
63 gn_b,
64 device,
65 &CompileProfile::sam3(),
66 )
67 }
68
69 pub fn compile_with_profile(
70 prev_h: usize,
71 prev_w: usize,
72 out_h: usize,
73 out_w: usize,
74 conv_w: &[f32],
75 conv_b: &[f32],
76 gn_w: &[f32],
77 gn_b: &[f32],
78 device: Device,
79 profile: &CompileProfile,
80 ) -> Result<Self> {
81 anyhow::ensure!(
82 out_h == prev_h * 2 && out_w == prev_w * 2,
83 "pixel_decoder step expects 2× upsample {prev_h}×{prev_w} → {out_h}×{out_w}"
84 );
85 let (graph, params) =
86 build_pixel_step_graph(prev_h, prev_w, out_h, out_w, conv_w, conv_b, gn_w, gn_b)?;
87 let mut compiled =
88 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
89 for (name, data) in ¶ms {
90 compiled.set_param(name, data);
91 }
92 Ok(Self {
93 graph: compiled,
94 out_h,
95 out_w,
96 })
97 }
98
99 pub fn run(&mut self, prev: &[f32], curr: &[f32]) -> Result<Vec<f32>> {
100 let n = D_MODEL * self.out_h * self.out_w;
101 anyhow::ensure!(prev.len() == n / 4 && curr.len() == n);
102 let outs = self.graph.run(&[("prev", prev), ("curr", curr)]);
103 Ok(outs.into_iter().next().expect("pixel_decoder step output"))
104 }
105}
106
107pub struct Sam3Conv1x1Compiled {
108 graph: CompiledGraph,
109 pub out_c: usize,
110 pub h: usize,
111 pub w: usize,
112}
113
114impl Sam3Conv1x1Compiled {
115 pub fn compile(
116 in_c: usize,
117 out_c: usize,
118 h: usize,
119 w: usize,
120 weight: &[f32],
121 bias: &[f32],
122 device: Device,
123 ) -> Result<Self> {
124 Self::compile_with_profile(
125 in_c,
126 out_c,
127 h,
128 w,
129 weight,
130 bias,
131 device,
132 &CompileProfile::sam3(),
133 )
134 }
135
136 pub fn compile_with_profile(
137 in_c: usize,
138 out_c: usize,
139 h: usize,
140 w: usize,
141 weight: &[f32],
142 bias: &[f32],
143 device: Device,
144 profile: &CompileProfile,
145 ) -> Result<Self> {
146 let (graph, params, typed) = build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?;
147 let mut compiled =
148 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
149 for (name, data) in ¶ms {
150 compiled.set_param(name, data);
151 }
152 rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
153 Ok(Self {
154 graph: compiled,
155 out_c,
156 h,
157 w,
158 })
159 }
160
161 pub fn compile_with_gguf(
163 in_c: usize,
164 out_c: usize,
165 h: usize,
166 w: usize,
167 weight: &[f32],
168 bias: &[f32],
169 gguf_key: Option<&str>,
170 gguf_packed: &GgufPackedParams,
171 device: Device,
172 profile: &CompileProfile,
173 ) -> Result<Self> {
174 let (graph, params, typed) = if let (Some(key), Some(p)) = (
175 gguf_key,
176 gguf_key.and_then(|k| packed_linear_for_key(Some(gguf_packed), k)),
177 ) {
178 build_conv1x1_graph_gguf(in_c, out_c, h, w, p, bias, key)?
179 } else {
180 anyhow::ensure!(
181 !weight.is_empty(),
182 "conv1x1: missing F32 weights and GGUF key"
183 );
184 build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?
185 };
186 let mut compiled =
187 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
188 for (name, data) in ¶ms {
189 compiled.set_param(name, data);
190 }
191 rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
192 Ok(Self {
193 graph: compiled,
194 out_c,
195 h,
196 w,
197 })
198 }
199
200 pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
201 anyhow::ensure!(x.len() == D_MODEL * self.h * self.w);
202 let outs = self.graph.run(&[("x", x)]);
203 Ok(outs.into_iter().next().expect("conv1x1 output"))
204 }
205}
206
207fn build_pixel_step_graph(
208 prev_h: usize,
209 prev_w: usize,
210 out_h: usize,
211 out_w: usize,
212 conv_w: &[f32],
213 conv_b: &[f32],
214 gn_w: &[f32],
215 gn_b: &[f32],
216) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
217 let f = DType::F32;
218 let mut hir = HirModule::new("sam3_pixel_decoder_step");
219 let mut params = HashMap::new();
220 let mut g = HirMut::new(&mut hir);
221
222 let prev = g.input("prev", Shape::new(&[1, D_MODEL, prev_h, prev_w], f));
223 let curr = g.input("curr", Shape::new(&[1, D_MODEL, out_h, out_w], f));
224
225 let up = g.resize_nearest_2x(prev);
226 let combined = g.add(curr, up);
227
228 let cw = param_f32(
229 &mut g,
230 &mut params,
231 "conv_w",
232 conv_w,
233 &[D_MODEL, D_MODEL, 3, 3],
234 );
235 let cb = param_f32(&mut g, &mut params, "conv_b", conv_b, &[D_MODEL]);
236 let mut y = conv2d_bias(
237 &mut g,
238 combined,
239 cw,
240 cb,
241 1,
242 D_MODEL,
243 3,
244 3,
245 [1, 1],
246 [1, 1],
247 out_h,
248 out_w,
249 );
250
251 let gnw = param_f32(&mut g, &mut params, "gn_w", gn_w, &[D_MODEL]);
252 let gnb = param_f32(&mut g, &mut params, "gn_b", gn_b, &[D_MODEL]);
253 y = g.group_norm(y, gnw, gnb, GN_GROUPS, 1e-5);
254 let out = g.relu(y);
255
256 hir.set_outputs(vec![out]);
257 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
258 Ok((graph, params))
259}
260
261fn build_conv1x1_graph(
262 in_c: usize,
263 out_c: usize,
264 h: usize,
265 w: usize,
266 weight: &[f32],
267 bias: &[f32],
268) -> Result<ConvGraphParts> {
269 let f = DType::F32;
270 let mut hir = HirModule::new("sam3_conv1x1");
271 let mut params = HashMap::new();
272 let mut g = HirMut::new(&mut hir);
273
274 let x = g.input("x", nchw_shape(1, in_c, h, w, f));
275 let wt = param_f32(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
276 let bt = param_f32(&mut g, &mut params, "b", bias, &[out_c]);
277 let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
278
279 hir.set_outputs(vec![y]);
280 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
281 Ok((graph, params, Vec::new()))
282}
283
284fn build_conv1x1_graph_gguf(
285 in_c: usize,
286 out_c: usize,
287 h: usize,
288 w: usize,
289 p: &rlx_flow::GgufPackedLinear,
290 bias: &[f32],
291 gguf_key: &str,
292) -> Result<ConvGraphParts> {
293 let f = DType::F32;
294 let mut hir = HirModule::new("sam3_conv1x1_gguf");
295 let mut params = HashMap::new();
296 let mut typed = Vec::new();
297 let mut gguf_cache = HashMap::new();
298 let mut g = HirMut::new(&mut hir);
299
300 let x = g.input("x", nchw_shape(1, in_c, h, w, f));
301 let spatial = (h * w) as i64;
302 let flat = g.reshape_(x, vec![1, spatial, in_c as i64]);
303 let stem = gguf_key.strip_suffix(".weight").unwrap_or(gguf_key);
304 let y_flat = linear_gguf_bias(
305 &mut g,
306 &mut params,
307 &mut typed,
308 &mut gguf_cache,
309 stem,
310 p,
311 flat,
312 bias,
313 in_c,
314 out_c,
315 )?;
316 let y = g.reshape_(y_flat, vec![1, out_c as i64, h as i64, w as i64]);
317
318 hir.set_outputs(vec![y]);
319 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
320 Ok((graph, params, typed))
321}
322
323fn param_f32(
324 g: &mut HirMut<'_>,
325 params: &mut HashMap<String, Vec<f32>>,
326 name: &str,
327 data: &[f32],
328 shape: &[usize],
329) -> HirNodeId {
330 let id = g.param(name, Shape::new(shape, DType::F32));
331 params.insert(name.to_string(), data.to_vec());
332 id
333}
334
335pub fn compile_pixel_decoder_steps(
337 pixel_conv_w: &[Vec<f32>],
338 pixel_conv_b: &[Vec<f32>],
339 pixel_gn_w: &[Vec<f32>],
340 pixel_gn_b: &[Vec<f32>],
341 trunk_grid: usize,
342 device: Device,
343 profile: &CompileProfile,
344) -> Result<Vec<Sam3PixelDecoderStepCompiled>> {
345 let g0 = trunk_grid;
348 let g1 = trunk_grid * 2;
349 let g2 = trunk_grid * 4;
350 let steps = [(g0, g0, g1, g1, 0usize), (g1, g1, g2, g2, 1usize)];
351 steps
352 .iter()
353 .map(|(ph, pw, oh, ow, i)| {
354 Sam3PixelDecoderStepCompiled::compile_with_profile(
355 *ph,
356 *pw,
357 *oh,
358 *ow,
359 &pixel_conv_w[*i],
360 &pixel_conv_b[*i],
361 &pixel_gn_w[*i],
362 &pixel_gn_b[*i],
363 device,
364 profile,
365 )
366 })
367 .collect()
368}