Skip to main content

rlx_sam3/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 SAM3 detector encoder/decoder flow.
17
18use anyhow::Result;
19use rlx_flow::{BuiltModel, ModelFlow, plugin_named};
20use rlx_ir::hir::{HirModule, HirMut};
21use rlx_ir::{DType, Shape};
22use rlx_runtime::Device;
23use std::collections::HashMap;
24
25use super::detector_decoder::Sam3DecoderWeights;
26use super::detector_decoder_ir::Sam3CompiledDecoder;
27use super::detector_encoder::{N_LAYERS, Sam3EncoderWeights};
28use super::detector_encoder_ir::emit_sam3_detector_encoder_layer;
29use rlx_core::flow_util::built_from_hir_with_profile;
30use rlx_flow::CompileProfile;
31
// Names under which the bind plugin publishes the flow inputs (via `set_named`)
// so that each encoder layer plugin can look them up again (via `named`).

/// Registry name for the `src` flow input.
const SAM3_SRC: &str = "sam3.encoder.src";
/// Registry name for the `src_pos` flow input.
const SAM3_SRC_POS: &str = "sam3.encoder.src_pos";
/// Registry name for the `prompt` flow input.
const SAM3_PROMPT: &str = "sam3.encoder.prompt";
/// Registry name for the `prompt_kpm_inv` flow input.
const SAM3_PROMPT_KPM: &str = "sam3.encoder.prompt_kpm_inv";

/// Size of the last dimension of the `src`, `src_pos` and `prompt` inputs.
const D_MODEL: usize = 256;
38
39#[derive(Clone)]
40pub struct Sam3DetectorEncoderFlow<'a> {
41    weights: &'a Sam3EncoderWeights,
42    batch: usize,
43    hw: usize,
44    seq: usize,
45    profile: CompileProfile,
46}
47
48impl<'a> Sam3DetectorEncoderFlow<'a> {
49    pub fn new(weights: &'a Sam3EncoderWeights, batch: usize, hw: usize, seq: usize) -> Self {
50        Self::new_with_profile(weights, batch, hw, seq, CompileProfile::sam3())
51    }
52
53    pub fn new_with_profile(
54        weights: &'a Sam3EncoderWeights,
55        batch: usize,
56        hw: usize,
57        seq: usize,
58        profile: CompileProfile,
59    ) -> Self {
60        Self {
61            weights,
62            batch,
63            hw,
64            seq,
65            profile,
66        }
67    }
68
69    pub fn build(self) -> Result<BuiltModel> {
70        let (hir, params) =
71            build_sam3_detector_encoder_model_flow(self.weights, self.batch, self.hw, self.seq)?;
72        built_from_hir_with_profile(hir, params, self.profile)
73    }
74}
75
76/// Native SAM3 detector encoder via [`ModelFlow`] + [`emit_sam3_detector_encoder_layer`].
77pub fn build_sam3_detector_encoder_model_flow(
78    weights: &Sam3EncoderWeights,
79    batch: usize,
80    hw: usize,
81    seq: usize,
82) -> Result<(HirModule, std::collections::HashMap<String, Vec<f32>>)> {
83    let f = DType::F32;
84    let tgt_shape = Shape::new(&[batch, hw, D_MODEL], f);
85    let weights_c = weights.clone();
86    let bind_out = tgt_shape.clone();
87
88    let mut flow = ModelFlow::new("sam3_detector_encoder")
89        .input("src", tgt_shape.clone())
90        .input("src_pos", tgt_shape.clone())
91        .input("prompt", Shape::new(&[batch, seq, D_MODEL], f))
92        .input("prompt_kpm_inv", Shape::new(&[batch, seq], f));
93
94    flow = flow.plugin_named("sam3.encoder.bind", move |emit, _| {
95        let src = emit.flow_input("src")?.hir_id();
96        emit.set_named(SAM3_SRC, src);
97        emit.set_named(SAM3_SRC_POS, emit.flow_input("src_pos")?.hir_id());
98        emit.set_named(SAM3_PROMPT, emit.flow_input("prompt")?.hir_id());
99        emit.set_named(SAM3_PROMPT_KPM, emit.flow_input("prompt_kpm_inv")?.hir_id());
100        Ok(Some(emit.wrap(src, bind_out.clone())))
101    });
102
103    let weights_layers = weights_c.clone();
104    let layer_count = weights_layers.layers.len().min(N_LAYERS);
105    flow = flow.repeat_layers(layer_count, move |li| {
106        let weights = weights_layers.clone();
107        let out_shape = tgt_shape.clone();
108        plugin_named(format!("sam3.encoder.l{li}"), move |emit, input| {
109            let tgt_in = input.ok_or_else(|| anyhow::anyhow!("sam3 encoder layer requires tgt"))?;
110            let src_pos = emit.named(SAM3_SRC_POS)?;
111            let prompt = emit.named(SAM3_PROMPT)?;
112            let prompt_kpm = emit.named(SAM3_PROMPT_KPM)?;
113            let hir = emit
114                .module
115                .as_hir_mut()
116                .expect("sam3 encoder flow requires HIR stage");
117            let mut gb = HirMut::new(hir);
118            let layer = weights
119                .layers
120                .get(li)
121                .ok_or_else(|| anyhow::anyhow!("sam3 encoder layer {li} missing"))?;
122            let mut typed_params = Vec::new();
123            let mut gguf_w_cache = HashMap::new();
124            let h = emit_sam3_detector_encoder_layer(
125                &mut gb,
126                emit.params,
127                &mut typed_params,
128                &mut gguf_w_cache,
129                None,
130                &weights.prefix,
131                li,
132                layer,
133                batch,
134                hw,
135                seq,
136                tgt_in.hir_id(),
137                src_pos,
138                prompt,
139                prompt_kpm,
140            )?;
141            Ok(Some(emit.wrap(h, out_shape.clone())))
142        })
143    });
144
145    flow = flow.output("tgt");
146
147    struct Sam3EncoderParams;
148    impl rlx_flow::WeightSource for Sam3EncoderParams {
149        fn take(&mut self, _key: &str, _transpose: bool) -> Result<(Vec<f32>, Vec<usize>)> {
150            anyhow::bail!("sam3 encoder flow does not load via WeightSource")
151        }
152    }
153
154    let built = flow.build(&mut Sam3EncoderParams)?;
155    built.into_parts()
156}
157
158pub fn build_sam3_detector_encoder_built(
159    weights: &Sam3EncoderWeights,
160    batch: usize,
161    hw: usize,
162    seq: usize,
163) -> Result<BuiltModel> {
164    Sam3DetectorEncoderFlow::new(weights, batch, hw, seq).build()
165}
166
167pub fn build_sam3_detector_encoder_built_with_profile(
168    weights: &Sam3EncoderWeights,
169    batch: usize,
170    hw: usize,
171    seq: usize,
172    profile: &CompileProfile,
173) -> Result<BuiltModel> {
174    Sam3DetectorEncoderFlow::new_with_profile(weights, batch, hw, seq, profile.clone()).build()
175}
176
177/// Compile-once SAM3 detector decoder (six per-layer graphs + host glue).
178pub struct Sam3DetectorDecoderBuilt {
179    pub inner: Sam3CompiledDecoder,
180}
181
182#[derive(Clone)]
183pub struct Sam3DetectorDecoderFlow<'a> {
184    weights: &'a Sam3DecoderWeights,
185    batch: usize,
186    hw: usize,
187    seq: usize,
188    device: Device,
189    profile: CompileProfile,
190}
191
192impl<'a> Sam3DetectorDecoderFlow<'a> {
193    pub fn new(
194        weights: &'a Sam3DecoderWeights,
195        batch: usize,
196        hw: usize,
197        seq: usize,
198        device: Device,
199    ) -> Self {
200        Self::new_with_profile(weights, batch, hw, seq, device, CompileProfile::sam3())
201    }
202
203    pub fn new_with_profile(
204        weights: &'a Sam3DecoderWeights,
205        batch: usize,
206        hw: usize,
207        seq: usize,
208        device: Device,
209        profile: CompileProfile,
210    ) -> Self {
211        Self {
212            weights,
213            batch,
214            hw,
215            seq,
216            device,
217            profile,
218        }
219    }
220
221    pub fn build(self) -> Result<Sam3DetectorDecoderBuilt> {
222        Ok(Sam3DetectorDecoderBuilt {
223            inner: Sam3CompiledDecoder::new_with_profile(
224                self.weights,
225                self.batch,
226                self.hw,
227                self.seq,
228                self.device,
229                &self.profile,
230            )?,
231        })
232    }
233}
234
235pub fn build_sam3_detector_decoder_built(
236    weights: &Sam3DecoderWeights,
237    batch: usize,
238    hw: usize,
239    seq: usize,
240    device: Device,
241) -> Result<Sam3DetectorDecoderBuilt> {
242    Sam3DetectorDecoderFlow::new(weights, batch, hw, seq, device).build()
243}
244
245pub fn build_sam3_detector_decoder_built_with_profile(
246    weights: &Sam3DecoderWeights,
247    batch: usize,
248    hw: usize,
249    seq: usize,
250    device: Device,
251    profile: &CompileProfile,
252) -> Result<Sam3DetectorDecoderBuilt> {
253    Sam3DetectorDecoderFlow::new_with_profile(weights, batch, hw, seq, device, profile.clone())
254        .build()
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::detector_encoder::{N_LAYERS, Sam3EncoderLayerWeights, Sam3EncoderWeights};

    /// Synthetic encoder layer: zero weights and biases, unit norm scales,
    /// and no GGUF keys.
    fn tiny_layer() -> Sam3EncoderLayerWeights {
        let d = D_MODEL;
        let ff = 2048usize;
        // Element counts shared by several tensors below.
        let qkv_bias = 3 * d;
        let qkv_weight = d * 3 * d;
        let proj_weight = d * d;
        Sam3EncoderLayerWeights {
            // Self-attention.
            self_attn_in_w_t: vec![0.0; qkv_weight],
            self_attn_in_b: vec![0.0; qkv_bias],
            self_attn_in_gguf_key: None,
            self_attn_out_w_t: vec![0.0; proj_weight],
            self_attn_out_b: vec![0.0; d],
            self_attn_out_gguf_key: None,
            // Cross-attention.
            cross_attn_in_w_t: vec![0.0; qkv_weight],
            cross_attn_in_b: vec![0.0; qkv_bias],
            cross_attn_in_gguf_key: None,
            cross_attn_out_w_t: vec![0.0; proj_weight],
            cross_attn_out_b: vec![0.0; d],
            cross_attn_out_gguf_key: None,
            // Feed-forward.
            linear1_w_t: vec![0.0; d * ff],
            linear1_b: vec![0.0; ff],
            linear1_gguf_key: None,
            linear2_w_t: vec![0.0; ff * d],
            linear2_b: vec![0.0; d],
            linear2_gguf_key: None,
            // Norms.
            norm1_w: vec![1.0; d],
            norm1_b: vec![0.0; d],
            norm2_w: vec![1.0; d],
            norm2_b: vec![0.0; d],
            norm3_w: vec![1.0; d],
            norm3_b: vec![0.0; d],
        }
    }

    /// The `ModelFlow` path must emit as many HIR nodes as `build_encoder_hir`
    /// and expose exactly one output.
    #[test]
    fn sam3_encoder_model_flow_matches_hir_node_count() {
        let (batch, hw, seq) = (1, 8, 4);
        let weights = Sam3EncoderWeights {
            loaded: true,
            prefix: "transformer.encoder".to_string(),
            layers: std::iter::repeat_with(tiny_layer).take(N_LAYERS).collect(),
        };

        let (hir_flow, _) =
            build_sam3_detector_encoder_model_flow(&weights, batch, hw, seq).unwrap();
        let parts =
            crate::detector_encoder_ir::build_encoder_hir(&weights, batch, hw, seq, None).unwrap();

        assert_eq!(hir_flow.len(), parts.hir.len());
        assert_eq!(hir_flow.outputs.len(), 1);
    }
}