1use anyhow::Result;
19use rlx_flow::{BuiltModel, ModelFlow, plugin_named};
20use rlx_ir::hir::{HirModule, HirMut};
21use rlx_ir::{DType, Shape};
22use rlx_runtime::Device;
23use std::collections::HashMap;
24
25use super::detector_decoder::Sam3DecoderWeights;
26use super::detector_decoder_ir::Sam3CompiledDecoder;
27use super::detector_encoder::{N_LAYERS, Sam3EncoderWeights};
28use super::detector_encoder_ir::emit_sam3_detector_encoder_layer;
29use rlx_core::flow_util::built_from_hir_with_profile;
30use rlx_flow::CompileProfile;
31
/// Size of the last axis of the `src`, `src_pos` and `prompt` inputs declared by the
/// encoder flow in this file.
const D_MODEL: usize = 256;

// Names under which the `sam3.encoder.bind` plugin registers the flow inputs
// (via `set_named`) so that the per-layer plugins can look them up again.
/// Registered name of the `src` flow input.
const SAM3_SRC: &str = "sam3.encoder.src";
/// Registered name of the `src_pos` flow input.
const SAM3_SRC_POS: &str = "sam3.encoder.src_pos";
/// Registered name of the `prompt` flow input.
const SAM3_PROMPT: &str = "sam3.encoder.prompt";
/// Registered name of the `prompt_kpm_inv` flow input.
const SAM3_PROMPT_KPM: &str = "sam3.encoder.prompt_kpm_inv";
38
39#[derive(Clone)]
40pub struct Sam3DetectorEncoderFlow<'a> {
41 weights: &'a Sam3EncoderWeights,
42 batch: usize,
43 hw: usize,
44 seq: usize,
45 profile: CompileProfile,
46}
47
48impl<'a> Sam3DetectorEncoderFlow<'a> {
49 pub fn new(weights: &'a Sam3EncoderWeights, batch: usize, hw: usize, seq: usize) -> Self {
50 Self::new_with_profile(weights, batch, hw, seq, CompileProfile::sam3())
51 }
52
53 pub fn new_with_profile(
54 weights: &'a Sam3EncoderWeights,
55 batch: usize,
56 hw: usize,
57 seq: usize,
58 profile: CompileProfile,
59 ) -> Self {
60 Self {
61 weights,
62 batch,
63 hw,
64 seq,
65 profile,
66 }
67 }
68
69 pub fn build(self) -> Result<BuiltModel> {
70 let (hir, params) =
71 build_sam3_detector_encoder_model_flow(self.weights, self.batch, self.hw, self.seq)?;
72 built_from_hir_with_profile(hir, params, self.profile)
73 }
74}
75
76pub fn build_sam3_detector_encoder_model_flow(
78 weights: &Sam3EncoderWeights,
79 batch: usize,
80 hw: usize,
81 seq: usize,
82) -> Result<(HirModule, std::collections::HashMap<String, Vec<f32>>)> {
83 let f = DType::F32;
84 let tgt_shape = Shape::new(&[batch, hw, D_MODEL], f);
85 let weights_c = weights.clone();
86 let bind_out = tgt_shape.clone();
87
88 let mut flow = ModelFlow::new("sam3_detector_encoder")
89 .input("src", tgt_shape.clone())
90 .input("src_pos", tgt_shape.clone())
91 .input("prompt", Shape::new(&[batch, seq, D_MODEL], f))
92 .input("prompt_kpm_inv", Shape::new(&[batch, seq], f));
93
94 flow = flow.plugin_named("sam3.encoder.bind", move |emit, _| {
95 let src = emit.flow_input("src")?.hir_id();
96 emit.set_named(SAM3_SRC, src);
97 emit.set_named(SAM3_SRC_POS, emit.flow_input("src_pos")?.hir_id());
98 emit.set_named(SAM3_PROMPT, emit.flow_input("prompt")?.hir_id());
99 emit.set_named(SAM3_PROMPT_KPM, emit.flow_input("prompt_kpm_inv")?.hir_id());
100 Ok(Some(emit.wrap(src, bind_out.clone())))
101 });
102
103 let weights_layers = weights_c.clone();
104 let layer_count = weights_layers.layers.len().min(N_LAYERS);
105 flow = flow.repeat_layers(layer_count, move |li| {
106 let weights = weights_layers.clone();
107 let out_shape = tgt_shape.clone();
108 plugin_named(format!("sam3.encoder.l{li}"), move |emit, input| {
109 let tgt_in = input.ok_or_else(|| anyhow::anyhow!("sam3 encoder layer requires tgt"))?;
110 let src_pos = emit.named(SAM3_SRC_POS)?;
111 let prompt = emit.named(SAM3_PROMPT)?;
112 let prompt_kpm = emit.named(SAM3_PROMPT_KPM)?;
113 let hir = emit
114 .module
115 .as_hir_mut()
116 .expect("sam3 encoder flow requires HIR stage");
117 let mut gb = HirMut::new(hir);
118 let layer = weights
119 .layers
120 .get(li)
121 .ok_or_else(|| anyhow::anyhow!("sam3 encoder layer {li} missing"))?;
122 let mut typed_params = Vec::new();
123 let mut gguf_w_cache = HashMap::new();
124 let h = emit_sam3_detector_encoder_layer(
125 &mut gb,
126 emit.params,
127 &mut typed_params,
128 &mut gguf_w_cache,
129 None,
130 &weights.prefix,
131 li,
132 layer,
133 batch,
134 hw,
135 seq,
136 tgt_in.hir_id(),
137 src_pos,
138 prompt,
139 prompt_kpm,
140 )?;
141 Ok(Some(emit.wrap(h, out_shape.clone())))
142 })
143 });
144
145 flow = flow.output("tgt");
146
147 struct Sam3EncoderParams;
148 impl rlx_flow::WeightSource for Sam3EncoderParams {
149 fn take(&mut self, _key: &str, _transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
150 anyhow::bail!("sam3 encoder flow does not load via WeightSource")
151 }
152 }
153
154 let built = flow.build(&mut Sam3EncoderParams)?;
155 built.into_parts()
156}
157
158pub fn build_sam3_detector_encoder_built(
159 weights: &Sam3EncoderWeights,
160 batch: usize,
161 hw: usize,
162 seq: usize,
163) -> Result<BuiltModel> {
164 Sam3DetectorEncoderFlow::new(weights, batch, hw, seq).build()
165}
166
167pub fn build_sam3_detector_encoder_built_with_profile(
168 weights: &Sam3EncoderWeights,
169 batch: usize,
170 hw: usize,
171 seq: usize,
172 profile: &CompileProfile,
173) -> Result<BuiltModel> {
174 Sam3DetectorEncoderFlow::new_with_profile(weights, batch, hw, seq, profile.clone()).build()
175}
176
177pub struct Sam3DetectorDecoderBuilt {
179 pub inner: Sam3CompiledDecoder,
180}
181
182#[derive(Clone)]
183pub struct Sam3DetectorDecoderFlow<'a> {
184 weights: &'a Sam3DecoderWeights,
185 batch: usize,
186 hw: usize,
187 seq: usize,
188 device: Device,
189 profile: CompileProfile,
190}
191
192impl<'a> Sam3DetectorDecoderFlow<'a> {
193 pub fn new(
194 weights: &'a Sam3DecoderWeights,
195 batch: usize,
196 hw: usize,
197 seq: usize,
198 device: Device,
199 ) -> Self {
200 Self::new_with_profile(weights, batch, hw, seq, device, CompileProfile::sam3())
201 }
202
203 pub fn new_with_profile(
204 weights: &'a Sam3DecoderWeights,
205 batch: usize,
206 hw: usize,
207 seq: usize,
208 device: Device,
209 profile: CompileProfile,
210 ) -> Self {
211 Self {
212 weights,
213 batch,
214 hw,
215 seq,
216 device,
217 profile,
218 }
219 }
220
221 pub fn build(self) -> Result<Sam3DetectorDecoderBuilt> {
222 Ok(Sam3DetectorDecoderBuilt {
223 inner: Sam3CompiledDecoder::new_with_profile(
224 self.weights,
225 self.batch,
226 self.hw,
227 self.seq,
228 self.device,
229 &self.profile,
230 )?,
231 })
232 }
233}
234
235pub fn build_sam3_detector_decoder_built(
236 weights: &Sam3DecoderWeights,
237 batch: usize,
238 hw: usize,
239 seq: usize,
240 device: Device,
241) -> Result<Sam3DetectorDecoderBuilt> {
242 Sam3DetectorDecoderFlow::new(weights, batch, hw, seq, device).build()
243}
244
245pub fn build_sam3_detector_decoder_built_with_profile(
246 weights: &Sam3DecoderWeights,
247 batch: usize,
248 hw: usize,
249 seq: usize,
250 device: Device,
251 profile: &CompileProfile,
252) -> Result<Sam3DetectorDecoderBuilt> {
253 Sam3DetectorDecoderFlow::new_with_profile(weights, batch, hw, seq, device, profile.clone())
254 .build()
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::detector_encoder::{N_LAYERS, Sam3EncoderLayerWeights, Sam3EncoderWeights};

    /// Returns one synthetic encoder layer: all projection weights and biases
    /// are zero, all norm scales are one, and no GGUF keys are set.
    fn tiny_layer() -> Sam3EncoderLayerWeights {
        let d = D_MODEL;
        let ff = 2048usize;
        // Small helpers so each field only has to state its element count.
        let zeros = |len: usize| vec![0.0; len];
        let ones = |len: usize| vec![1.0; len];
        Sam3EncoderLayerWeights {
            self_attn_in_w_t: zeros(d * 3 * d),
            self_attn_in_b: zeros(3 * d),
            self_attn_in_gguf_key: None,
            self_attn_out_w_t: zeros(d * d),
            self_attn_out_b: zeros(d),
            self_attn_out_gguf_key: None,
            cross_attn_in_w_t: zeros(d * 3 * d),
            cross_attn_in_b: zeros(3 * d),
            cross_attn_in_gguf_key: None,
            cross_attn_out_w_t: zeros(d * d),
            cross_attn_out_b: zeros(d),
            cross_attn_out_gguf_key: None,
            linear1_w_t: zeros(d * ff),
            linear1_b: zeros(ff),
            linear1_gguf_key: None,
            linear2_w_t: zeros(ff * d),
            linear2_b: zeros(d),
            linear2_gguf_key: None,
            norm1_w: ones(d),
            norm1_b: zeros(d),
            norm2_w: ones(d),
            norm2_b: zeros(d),
            norm3_w: ones(d),
            norm3_b: zeros(d),
        }
    }

    /// The flow-based builder must emit as many HIR nodes as `build_encoder_hir`
    /// does for the same weights and sizes, and expose exactly one output.
    #[test]
    fn sam3_encoder_model_flow_matches_hir_node_count() {
        let (batch, hw, seq) = (1, 8, 4);
        let weights = Sam3EncoderWeights {
            loaded: true,
            prefix: String::from("transformer.encoder"),
            layers: std::iter::repeat_with(tiny_layer).take(N_LAYERS).collect(),
        };

        let (hir_flow, _) =
            build_sam3_detector_encoder_model_flow(&weights, batch, hw, seq).unwrap();
        let parts =
            crate::detector_encoder_ir::build_encoder_hir(&weights, batch, hw, seq, None).unwrap();

        assert_eq!(hir_flow.len(), parts.hir.len());
        assert_eq!(hir_flow.outputs.len(), 1);
    }
}
305}