1use super::config::SAM_EMBED_HW;
19use super::mask_decoder::MaskDecoderWeights;
20use anyhow::Result;
21use rlx_core::vision_ops_ir::{conv_transpose2d_stride2_k2_bias, layer_norm2d_nchw};
22use rlx_flow::CompileProfile;
23use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
24use rlx_ir::{DType, Graph, HirGraphExt, Shape};
25use rlx_runtime::{CompiledGraph, Device};
26use std::collections::HashMap;
27
28pub struct SamMaskUpscaleCompiled {
30 graph: CompiledGraph,
31 e: usize,
32 hw: usize,
33}
34
35impl SamMaskUpscaleCompiled {
36 pub fn compile(w: &MaskDecoderWeights, device: Device) -> Result<Self> {
37 Self::compile_with_profile(w, device, &CompileProfile::sam_encoder())
38 }
39
40 pub fn compile_with_profile(
41 w: &MaskDecoderWeights,
42 device: Device,
43 profile: &CompileProfile,
44 ) -> Result<Self> {
45 let (graph, params) = build_mask_upscale_graph(w)?;
46 let mut compiled =
47 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
48 for (name, data) in ¶ms {
49 compiled.set_param(name, data);
50 }
51 Ok(Self {
52 graph: compiled,
53 e: w.transformer_dim,
54 hw: SAM_EMBED_HW,
55 })
56 }
57
58 pub fn run(&mut self, src_nchw: &[f32]) -> Result<Vec<f32>> {
61 let e = self.e;
62 let hw = self.hw;
63 anyhow::ensure!(
64 src_nchw.len() == e * hw * hw,
65 "src_nchw len {} ≠ E·hw·hw",
66 src_nchw.len()
67 );
68 let outs = self.graph.run(&[("src", src_nchw)]);
69 Ok(outs.into_iter().next().expect("upscale output"))
70 }
71}
72
73pub fn build_mask_upscale_hir(
74 w: &MaskDecoderWeights,
75) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
76 let e = w.transformer_dim;
77 let hw = SAM_EMBED_HW;
78 let q4 = e / 4;
79 let q8 = e / 8;
80 let eps = 1e-6f32;
81 let f = DType::F32;
82
83 let mut hir = HirModule::new("sam_mask_upscale");
84 let mut params = HashMap::new();
85 let mut g = HirMut::new(&mut hir);
86
87 let src = g.input("src", Shape::new(&[1, e, hw, hw], f));
88
89 let up1_w = param(
90 &mut g,
91 &mut params,
92 "upscale_conv1_w",
93 w.upscale_conv1_w.clone(),
94 &[e, q4, 2, 2],
95 );
96 let up1_b = param(
97 &mut g,
98 &mut params,
99 "upscale_conv1_b",
100 w.upscale_conv1_b.clone(),
101 &[q4],
102 );
103 let mut up1 = conv_transpose2d_stride2_k2_bias(&mut g, src, up1_w, up1_b, 1, q4, hw, hw);
104
105 let ln_g = param(
106 &mut g,
107 &mut params,
108 "upscale_ln_g",
109 w.upscale_ln_g.clone(),
110 &[q4],
111 );
112 let ln_b = param(
113 &mut g,
114 &mut params,
115 "upscale_ln_b",
116 w.upscale_ln_b.clone(),
117 &[q4],
118 );
119 up1 = layer_norm2d_nchw(&mut g, up1, ln_g, ln_b, eps);
120 up1 = g.gelu(up1);
121
122 let h1 = hw * 2;
123 let up2_w = param(
124 &mut g,
125 &mut params,
126 "upscale_conv2_w",
127 w.upscale_conv2_w.clone(),
128 &[q4, q8, 2, 2],
129 );
130 let up2_b = param(
131 &mut g,
132 &mut params,
133 "upscale_conv2_b",
134 w.upscale_conv2_b.clone(),
135 &[q8],
136 );
137 let up2 = conv_transpose2d_stride2_k2_bias(&mut g, up1, up2_w, up2_b, 1, q8, h1, h1);
138 let up2 = g.gelu(up2);
139
140 hir.set_outputs(vec![up2]);
141 Ok((hir, params))
142}
143
144pub fn build_mask_upscale_graph(
145 w: &MaskDecoderWeights,
146) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
147 let (hir, params) = build_mask_upscale_hir(w)?;
148 Graph::from_hir(hir)
149 .map_err(|e| anyhow::anyhow!("{e}"))
150 .map(|g| (g, params))
151}
152
153fn param(
154 g: &mut HirMut<'_>,
155 params: &mut HashMap<String, Vec<f32>>,
156 name: &str,
157 data: Vec<f32>,
158 shape: &[usize],
159) -> HirNodeId {
160 let id = g.param(name, Shape::new(shape, DType::F32));
161 params.insert(name.to_string(), data);
162 id
163}