1use super::memory_encoder::{Sam2CXBlockWeights, Sam2FuserWeights, Sam2MaskDownSamplerWeights};
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{conv2d_bias, conv2d_bias_groups, layer_norm2d_nchw, nchw_shape};
21use rlx_flow::CompileProfile;
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::op::Op;
24use rlx_ir::{DType, Graph, HirGraphExt, Shape};
25use rlx_runtime::{CompiledGraph, Device};
26use std::collections::HashMap;
27
/// Epsilon shared by every 2-D LayerNorm built in this module
/// (mask-downsampler levels and CXBlocks).
const LN_EPS: f32 = 1.0e-6;
29
30pub struct Sam2MemoryMaskDownCompiled {
31 graph: CompiledGraph,
32 pub embed_dim: usize,
33 pub in_h: usize,
34 pub in_w: usize,
35 pub out_h: usize,
36 pub out_w: usize,
37}
38
39pub struct Sam2MemoryPrefixCompiled {
41 graph: CompiledGraph,
42 pub in_dim: usize,
43 pub mask_in_h: usize,
44 pub mask_in_w: usize,
45 pub feat_h: usize,
46 pub feat_w: usize,
47}
48
49impl Sam2MemoryPrefixCompiled {
50 pub fn compile(
51 mask_ds: &Sam2MaskDownSamplerWeights,
52 in_dim: usize,
53 mask_in_h: usize,
54 mask_in_w: usize,
55 feat_h: usize,
56 feat_w: usize,
57 pix_w: &[f32],
58 pix_b: &[f32],
59 device: Device,
60 ) -> Result<Self> {
61 Self::compile_with_profile(
62 mask_ds,
63 in_dim,
64 mask_in_h,
65 mask_in_w,
66 feat_h,
67 feat_w,
68 pix_w,
69 pix_b,
70 device,
71 &CompileProfile::sam_encoder(),
72 )
73 }
74
75 pub fn compile_with_profile(
76 mask_ds: &Sam2MaskDownSamplerWeights,
77 in_dim: usize,
78 mask_in_h: usize,
79 mask_in_w: usize,
80 feat_h: usize,
81 feat_w: usize,
82 pix_w: &[f32],
83 pix_b: &[f32],
84 device: Device,
85 profile: &CompileProfile,
86 ) -> Result<Self> {
87 let (graph, params) = build_prefix_graph(
88 mask_ds, in_dim, mask_in_h, mask_in_w, feat_h, feat_w, pix_w, pix_b,
89 )?;
90 let mut compiled =
91 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
92 for (name, data) in ¶ms {
93 compiled.set_param(name, data);
94 }
95 Ok(Self {
96 graph: compiled,
97 in_dim,
98 mask_in_h,
99 mask_in_w,
100 feat_h,
101 feat_w,
102 })
103 }
104
105 pub fn run(&mut self, mask: &[f32], pix_feat: &[f32]) -> Result<Vec<f32>> {
107 anyhow::ensure!(
108 mask.len() == self.mask_in_h * self.mask_in_w,
109 "prefix mask len {} ≠ {}",
110 mask.len(),
111 self.mask_in_h * self.mask_in_w
112 );
113 anyhow::ensure!(
114 pix_feat.len() == self.in_dim * self.feat_h * self.feat_w,
115 "prefix pix_feat len {} ≠ {}",
116 pix_feat.len(),
117 self.in_dim * self.feat_h * self.feat_w
118 );
119 let outs = self.graph.run(&[("mask", mask), ("pix", pix_feat)]);
120 Ok(outs.into_iter().next().expect("memory prefix output"))
121 }
122}
123
124impl Sam2MemoryMaskDownCompiled {
125 pub fn compile(
126 w: &Sam2MaskDownSamplerWeights,
127 in_h: usize,
128 in_w: usize,
129 device: Device,
130 ) -> Result<Self> {
131 Self::compile_with_profile(w, in_h, in_w, device, &CompileProfile::sam_encoder())
132 }
133
134 pub fn compile_with_profile(
135 w: &Sam2MaskDownSamplerWeights,
136 in_h: usize,
137 in_w: usize,
138 device: Device,
139 profile: &CompileProfile,
140 ) -> Result<Self> {
141 let (graph, params, out_h, out_w) = build_mask_downsampler_graph(w, in_h, in_w)?;
142 let mut compiled =
143 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
144 for (name, data) in ¶ms {
145 compiled.set_param(name, data);
146 }
147 Ok(Self {
148 graph: compiled,
149 embed_dim: w.embed_dim,
150 in_h,
151 in_w,
152 out_h,
153 out_w,
154 })
155 }
156
157 pub fn run(&mut self, mask: &[f32]) -> Result<Vec<f32>> {
159 let expected = self.in_h * self.in_w;
160 anyhow::ensure!(
161 mask.len() == expected,
162 "mask len {} ≠ {expected} (1×{}×{})",
163 mask.len(),
164 self.in_h,
165 self.in_w
166 );
167 let outs = self.graph.run(&[("mask", mask)]);
168 Ok(outs.into_iter().next().expect("memory mask_down output"))
169 }
170}
171
172#[allow(clippy::type_complexity)]
173fn build_mask_downsampler_graph(
174 w: &Sam2MaskDownSamplerWeights,
175 in_h: usize,
176 in_w: usize,
177) -> Result<(Graph, HashMap<String, Vec<f32>>, usize, usize)> {
178 let f = DType::F32;
179 let mut hir = HirModule::new("sam2_memory_mask_down");
180 let mut params = HashMap::new();
181 let mut g = HirMut::new(&mut hir);
182
183 let x = g.input("mask", Shape::new(&[1, 1, in_h, in_w], f));
184 let (out, out_h, out_w) = append_mask_downsampler(&mut g, &mut params, x, w, in_h, in_w, "")?;
185
186 hir.set_outputs(vec![out]);
187 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
188 Ok((graph, params, out_h, out_w))
189}
190
191fn build_prefix_graph(
192 mask_ds: &Sam2MaskDownSamplerWeights,
193 in_dim: usize,
194 mask_in_h: usize,
195 mask_in_w: usize,
196 feat_h: usize,
197 feat_w: usize,
198 pix_w: &[f32],
199 pix_b: &[f32],
200) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
201 let f = DType::F32;
202 let mut hir = HirModule::new("sam2_memory_prefix");
203 let mut params = HashMap::new();
204 let mut g = HirMut::new(&mut hir);
205
206 let mask = g.input("mask", Shape::new(&[1, 1, mask_in_h, mask_in_w], f));
207 let (m_down, down_h, down_w) = append_mask_downsampler(
208 &mut g,
209 &mut params,
210 mask,
211 mask_ds,
212 mask_in_h,
213 mask_in_w,
214 "md_",
215 )?;
216 anyhow::ensure!(
217 down_h == feat_h && down_w == feat_w,
218 "mask down {down_h}×{down_w} ≠ pix {feat_h}×{feat_w}"
219 );
220
221 let pix = g.input("pix", nchw_shape(1, in_dim, feat_h, feat_w, f));
222 let pp_w = param(&mut g, &mut params, "pp_w", pix_w, &[in_dim, in_dim, 1, 1]);
223 let pp_b = param(&mut g, &mut params, "pp_b", pix_b, &[in_dim]);
224 let pix_y = conv2d_bias(
225 &mut g,
226 pix,
227 pp_w,
228 pp_b,
229 1,
230 in_dim,
231 1,
232 1,
233 [1, 1],
234 [0, 0],
235 feat_h,
236 feat_w,
237 );
238 let out = g.add(pix_y, m_down);
239
240 hir.set_outputs(vec![out]);
241 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
242 Ok((graph, params))
243}
244
245fn append_mask_downsampler(
247 g: &mut HirMut<'_>,
248 params: &mut HashMap<String, Vec<f32>>,
249 mut x: HirNodeId,
250 w: &Sam2MaskDownSamplerWeights,
251 in_h: usize,
252 in_w: usize,
253 pfx: &str,
254) -> Result<(HirNodeId, usize, usize)> {
255 let mut cur_h = in_h;
256 let mut cur_w = in_w;
257
258 for (li, level) in w.levels.iter().enumerate() {
259 let out_h = (cur_h + 2 * w.padding - w.kernel) / w.stride + 1;
260 let out_w = (cur_w + 2 * w.padding - w.kernel) / w.stride + 1;
261 let k = w.kernel;
262 let cw = param(
263 g,
264 params,
265 &format!("{pfx}conv{li}_w"),
266 &level.conv_w,
267 &[level.out_c, level.in_c, k, k],
268 );
269 let cb = param(
270 g,
271 params,
272 &format!("{pfx}conv{li}_b"),
273 &level.conv_b,
274 &[level.out_c],
275 );
276 x = conv2d_bias(
277 g,
278 x,
279 cw,
280 cb,
281 1,
282 level.out_c,
283 k,
284 k,
285 [w.stride, w.stride],
286 [w.padding, w.padding],
287 out_h,
288 out_w,
289 );
290 let ln_g = param(
291 g,
292 params,
293 &format!("{pfx}ln{li}_g"),
294 &level.ln_g,
295 &[level.out_c],
296 );
297 let ln_b = param(
298 g,
299 params,
300 &format!("{pfx}ln{li}_b"),
301 &level.ln_b,
302 &[level.out_c],
303 );
304 x = layer_norm2d_nchw(g, x, ln_g, ln_b, LN_EPS);
305 x = g.gelu(x);
306 cur_h = out_h;
307 cur_w = out_w;
308 }
309
310 let last_c = w.levels.last().map(|l| l.out_c).unwrap_or(1);
311 let fw = param(
312 g,
313 params,
314 &format!("{pfx}final_w"),
315 &w.final_conv_w,
316 &[w.embed_dim, last_c, 1, 1],
317 );
318 let fb = param(
319 g,
320 params,
321 &format!("{pfx}final_b"),
322 &w.final_conv_b,
323 &[w.embed_dim],
324 );
325 let out = conv2d_bias(
326 g,
327 x,
328 fw,
329 fb,
330 1,
331 w.embed_dim,
332 1,
333 1,
334 [1, 1],
335 [0, 0],
336 cur_h,
337 cur_w,
338 );
339 Ok((out, cur_h, cur_w))
340}
341
342fn param(
343 g: &mut HirMut<'_>,
344 params: &mut HashMap<String, Vec<f32>>,
345 name: &str,
346 data: &[f32],
347 shape: &[usize],
348) -> HirNodeId {
349 let id = g.param(name, Shape::new(shape, DType::F32));
350 params.insert(name.to_string(), data.to_vec());
351 id
352}
353
354pub struct Sam2MemoryConv1x1Compiled {
356 graph: CompiledGraph,
357 in_c: usize,
358 pub out_c: usize,
359 pub h: usize,
360 pub w: usize,
361}
362
363impl Sam2MemoryConv1x1Compiled {
364 pub fn compile(
365 in_c: usize,
366 out_c: usize,
367 h: usize,
368 w: usize,
369 weight: &[f32],
370 bias: &[f32],
371 device: Device,
372 ) -> Result<Self> {
373 Self::compile_with_profile(
374 in_c,
375 out_c,
376 h,
377 w,
378 weight,
379 bias,
380 device,
381 &CompileProfile::sam_encoder(),
382 )
383 }
384
385 pub fn compile_with_profile(
386 in_c: usize,
387 out_c: usize,
388 h: usize,
389 w: usize,
390 weight: &[f32],
391 bias: &[f32],
392 device: Device,
393 profile: &CompileProfile,
394 ) -> Result<Self> {
395 let (graph, params) = build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?;
396 let mut compiled =
397 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
398 for (name, data) in ¶ms {
399 compiled.set_param(name, data);
400 }
401 Ok(Self {
402 graph: compiled,
403 in_c,
404 out_c,
405 h,
406 w,
407 })
408 }
409
410 pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
411 let expected = self.in_c * self.h * self.w;
412 anyhow::ensure!(
413 x.len() == expected,
414 "conv1x1 input len {} ≠ {} ({}×{}×{})",
415 x.len(),
416 expected,
417 self.in_c,
418 self.h,
419 self.w
420 );
421 let outs = self.graph.run(&[("x", x)]);
422 Ok(outs.into_iter().next().expect("conv1x1 output"))
423 }
424}
425
426fn build_conv1x1_graph(
427 in_c: usize,
428 out_c: usize,
429 h: usize,
430 w: usize,
431 weight: &[f32],
432 bias: &[f32],
433) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
434 let f = DType::F32;
435 let mut hir = HirModule::new("sam2_conv1x1");
436 let mut params = HashMap::new();
437 let mut g = HirMut::new(&mut hir);
438
439 let x = g.input("x", nchw_shape(1, in_c, h, w, f));
440 let wt = param(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
441 let bt = param(&mut g, &mut params, "b", bias, &[out_c]);
442 let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
443
444 hir.set_outputs(vec![y]);
445 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
446 Ok((graph, params))
447}
448
449pub struct Sam2MemoryFuserCompiled {
451 graph: CompiledGraph,
452 pub dim: usize,
453 pub h: usize,
454 pub w: usize,
455}
456
457impl Sam2MemoryFuserCompiled {
458 pub fn compile(w: &Sam2FuserWeights, h: usize, ww: usize, device: Device) -> Result<Self> {
459 Self::compile_with_profile(w, h, ww, device, &CompileProfile::sam_encoder())
460 }
461
462 pub fn compile_with_profile(
463 w: &Sam2FuserWeights,
464 h: usize,
465 ww: usize,
466 device: Device,
467 profile: &CompileProfile,
468 ) -> Result<Self> {
469 let (graph, params) = build_fuser_graph(w, h, ww)?;
470 let mut compiled =
471 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
472 for (name, data) in ¶ms {
473 compiled.set_param(name, data);
474 }
475 Ok(Self {
476 graph: compiled,
477 dim: w.dim,
478 h,
479 w: ww,
480 })
481 }
482
483 pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
484 let expected = self.dim * self.h * self.w;
485 anyhow::ensure!(
486 x.len() == expected,
487 "fuser input len {} ≠ {expected}",
488 x.len()
489 );
490 let outs = self.graph.run(&[("x", x)]);
491 Ok(outs.into_iter().next().expect("fuser output"))
492 }
493}
494
495fn build_fuser_graph(
496 w: &Sam2FuserWeights,
497 h: usize,
498 ww: usize,
499) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
500 let f = DType::F32;
501 let dim = w.dim;
502 let mut hir = HirModule::new("sam2_memory_fuser");
503 let mut params = HashMap::new();
504 let mut g = HirMut::new(&mut hir);
505
506 let mut x = g.input("x", nchw_shape(1, dim, h, ww, f));
507
508 if let (Some(pw), Some(pb)) = (&w.input_proj_w, &w.input_proj_b) {
509 let wt = param(&mut g, &mut params, "input_proj_w", pw, &[dim, dim, 1, 1]);
510 let bt = param(&mut g, &mut params, "input_proj_b", pb, &[dim]);
511 x = conv2d_bias(&mut g, x, wt, bt, 1, dim, 1, 1, [1, 1], [0, 0], h, ww);
512 }
513
514 for (li, layer) in w.layers.iter().enumerate() {
515 x = cx_block_hir(&mut g, &mut params, x, layer, li, h, ww)?;
516 }
517
518 hir.set_outputs(vec![x]);
519 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
520 Ok((graph, params))
521}
522
523fn cx_block_hir(
524 g: &mut HirMut<'_>,
525 params: &mut HashMap<String, Vec<f32>>,
526 x: HirNodeId,
527 w: &Sam2CXBlockWeights,
528 li: usize,
529 h: usize,
530 ww: usize,
531) -> Result<HirNodeId> {
532 let dim = w.dim;
533 let k = w.kernel;
534 let p = w.padding;
535 let residual = x;
536
537 let dw_w = param(
538 g,
539 params,
540 &format!("l{li}_dw_w"),
541 &w.dw_conv_w,
542 &[dim, 1, k, k],
543 );
544 let dw_b = param(g, params, &format!("l{li}_dw_b"), &w.dw_conv_b, &[dim]);
545 let mut y = conv2d_bias_groups(g, x, dw_w, dw_b, 1, dim, k, k, [1, 1], [p, p], dim, h, ww);
546
547 let ln_g = param(g, params, &format!("l{li}_ln_g"), &w.ln_g, &[dim]);
548 let ln_b = param(g, params, &format!("l{li}_ln_b"), &w.ln_b, &[dim]);
549 y = layer_norm2d_nchw(g, y, ln_g, ln_b, LN_EPS);
550
551 let pw1_w = param(
552 g,
553 params,
554 &format!("l{li}_pw1_w"),
555 &w.pw1_w,
556 &[4 * dim, dim, 1, 1],
557 );
558 let pw1_b = param(g, params, &format!("l{li}_pw1_b"), &w.pw1_b, &[4 * dim]);
559 y = conv2d_bias(g, y, pw1_w, pw1_b, 1, 4 * dim, 1, 1, [1, 1], [0, 0], h, ww);
560 y = g.gelu(y);
561
562 let pw2_w = param(
563 g,
564 params,
565 &format!("l{li}_pw2_w"),
566 &w.pw2_w,
567 &[dim, 4 * dim, 1, 1],
568 );
569 let pw2_b = param(g, params, &format!("l{li}_pw2_b"), &w.pw2_b, &[dim]);
570 y = conv2d_bias(g, y, pw2_w, pw2_b, 1, dim, 1, 1, [1, 1], [0, 0], h, ww);
571
572 if let Some(gamma) = &w.gamma {
573 let gparam = param(g, params, &format!("l{li}_gamma"), gamma, &[dim]);
574 let out_shape = g.shape(y).clone();
575 let g4 = g.reshape_(gparam, vec![1, dim as i64, 1, 1]);
576 let scaled = g.add_node(
577 Op::Expand {
578 target_shape: vec![1, dim as i64, h as i64, ww as i64],
579 },
580 vec![g4],
581 out_shape.clone(),
582 );
583 y = g.mul(y, scaled);
584 }
585
586 Ok(g.add(residual, y))
587}