// rlx_sam2/memory_mask_ir.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! SAM2 memory-encoder IR subgraphs (`MaskDownSampler`, mask + pixel-feature prefix fuse, standalone 1×1 conv, `Fuser`).
17
18use super::memory_encoder::{Sam2CXBlockWeights, Sam2FuserWeights, Sam2MaskDownSamplerWeights};
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{conv2d_bias, conv2d_bias_groups, layer_norm2d_nchw, nchw_shape};
21use rlx_flow::CompileProfile;
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::op::Op;
24use rlx_ir::{DType, Graph, HirGraphExt, Shape};
25use rlx_runtime::{CompiledGraph, Device};
26use std::collections::HashMap;
27
/// Epsilon handed to every `layer_norm2d_nchw` call in this module
/// (each `MaskDownSampler` level and each `CXBlock` norm).
const LN_EPS: f32 = 1.0e-6;
29
30pub struct Sam2MemoryMaskDownCompiled {
31    graph: CompiledGraph,
32    pub embed_dim: usize,
33    pub in_h: usize,
34    pub in_w: usize,
35    pub out_h: usize,
36    pub out_w: usize,
37}
38
39/// Fused mask down + `pix_feat_proj` + add → fuser input.
40pub struct Sam2MemoryPrefixCompiled {
41    graph: CompiledGraph,
42    pub in_dim: usize,
43    pub mask_in_h: usize,
44    pub mask_in_w: usize,
45    pub feat_h: usize,
46    pub feat_w: usize,
47}
48
49impl Sam2MemoryPrefixCompiled {
50    pub fn compile(
51        mask_ds: &Sam2MaskDownSamplerWeights,
52        in_dim: usize,
53        mask_in_h: usize,
54        mask_in_w: usize,
55        feat_h: usize,
56        feat_w: usize,
57        pix_w: &[f32],
58        pix_b: &[f32],
59        device: Device,
60    ) -> Result<Self> {
61        Self::compile_with_profile(
62            mask_ds,
63            in_dim,
64            mask_in_h,
65            mask_in_w,
66            feat_h,
67            feat_w,
68            pix_w,
69            pix_b,
70            device,
71            &CompileProfile::sam_encoder(),
72        )
73    }
74
75    pub fn compile_with_profile(
76        mask_ds: &Sam2MaskDownSamplerWeights,
77        in_dim: usize,
78        mask_in_h: usize,
79        mask_in_w: usize,
80        feat_h: usize,
81        feat_w: usize,
82        pix_w: &[f32],
83        pix_b: &[f32],
84        device: Device,
85        profile: &CompileProfile,
86    ) -> Result<Self> {
87        let (graph, params) = build_prefix_graph(
88            mask_ds, in_dim, mask_in_h, mask_in_w, feat_h, feat_w, pix_w, pix_b,
89        )?;
90        let mut compiled =
91            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
92        for (name, data) in &params {
93            compiled.set_param(name, data);
94        }
95        Ok(Self {
96            graph: compiled,
97            in_dim,
98            mask_in_h,
99            mask_in_w,
100            feat_h,
101            feat_w,
102        })
103    }
104
105    /// Post-sigmoid mask `[1, H, W]` flat + `pix_feat` NCHW `[in_dim, feat_h, feat_w]`.
106    pub fn run(&mut self, mask: &[f32], pix_feat: &[f32]) -> Result<Vec<f32>> {
107        anyhow::ensure!(
108            mask.len() == self.mask_in_h * self.mask_in_w,
109            "prefix mask len {} ≠ {}",
110            mask.len(),
111            self.mask_in_h * self.mask_in_w
112        );
113        anyhow::ensure!(
114            pix_feat.len() == self.in_dim * self.feat_h * self.feat_w,
115            "prefix pix_feat len {} ≠ {}",
116            pix_feat.len(),
117            self.in_dim * self.feat_h * self.feat_w
118        );
119        let outs = self.graph.run(&[("mask", mask), ("pix", pix_feat)]);
120        Ok(outs.into_iter().next().expect("memory prefix output"))
121    }
122}
123
124impl Sam2MemoryMaskDownCompiled {
125    pub fn compile(
126        w: &Sam2MaskDownSamplerWeights,
127        in_h: usize,
128        in_w: usize,
129        device: Device,
130    ) -> Result<Self> {
131        Self::compile_with_profile(w, in_h, in_w, device, &CompileProfile::sam_encoder())
132    }
133
134    pub fn compile_with_profile(
135        w: &Sam2MaskDownSamplerWeights,
136        in_h: usize,
137        in_w: usize,
138        device: Device,
139        profile: &CompileProfile,
140    ) -> Result<Self> {
141        let (graph, params, out_h, out_w) = build_mask_downsampler_graph(w, in_h, in_w)?;
142        let mut compiled =
143            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
144        for (name, data) in &params {
145            compiled.set_param(name, data);
146        }
147        Ok(Self {
148            graph: compiled,
149            embed_dim: w.embed_dim,
150            in_h,
151            in_w,
152            out_h,
153            out_w,
154        })
155    }
156
157    /// Flat `[1, H, W]` mask (post-sigmoid).
158    pub fn run(&mut self, mask: &[f32]) -> Result<Vec<f32>> {
159        let expected = self.in_h * self.in_w;
160        anyhow::ensure!(
161            mask.len() == expected,
162            "mask len {} ≠ {expected} (1×{}×{})",
163            mask.len(),
164            self.in_h,
165            self.in_w
166        );
167        let outs = self.graph.run(&[("mask", mask)]);
168        Ok(outs.into_iter().next().expect("memory mask_down output"))
169    }
170}
171
172#[allow(clippy::type_complexity)]
173fn build_mask_downsampler_graph(
174    w: &Sam2MaskDownSamplerWeights,
175    in_h: usize,
176    in_w: usize,
177) -> Result<(Graph, HashMap<String, Vec<f32>>, usize, usize)> {
178    let f = DType::F32;
179    let mut hir = HirModule::new("sam2_memory_mask_down");
180    let mut params = HashMap::new();
181    let mut g = HirMut::new(&mut hir);
182
183    let x = g.input("mask", Shape::new(&[1, 1, in_h, in_w], f));
184    let (out, out_h, out_w) = append_mask_downsampler(&mut g, &mut params, x, w, in_h, in_w, "")?;
185
186    hir.set_outputs(vec![out]);
187    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
188    Ok((graph, params, out_h, out_w))
189}
190
191fn build_prefix_graph(
192    mask_ds: &Sam2MaskDownSamplerWeights,
193    in_dim: usize,
194    mask_in_h: usize,
195    mask_in_w: usize,
196    feat_h: usize,
197    feat_w: usize,
198    pix_w: &[f32],
199    pix_b: &[f32],
200) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
201    let f = DType::F32;
202    let mut hir = HirModule::new("sam2_memory_prefix");
203    let mut params = HashMap::new();
204    let mut g = HirMut::new(&mut hir);
205
206    let mask = g.input("mask", Shape::new(&[1, 1, mask_in_h, mask_in_w], f));
207    let (m_down, down_h, down_w) = append_mask_downsampler(
208        &mut g,
209        &mut params,
210        mask,
211        mask_ds,
212        mask_in_h,
213        mask_in_w,
214        "md_",
215    )?;
216    anyhow::ensure!(
217        down_h == feat_h && down_w == feat_w,
218        "mask down {down_h}×{down_w} ≠ pix {feat_h}×{feat_w}"
219    );
220
221    let pix = g.input("pix", nchw_shape(1, in_dim, feat_h, feat_w, f));
222    let pp_w = param(&mut g, &mut params, "pp_w", pix_w, &[in_dim, in_dim, 1, 1]);
223    let pp_b = param(&mut g, &mut params, "pp_b", pix_b, &[in_dim]);
224    let pix_y = conv2d_bias(
225        &mut g,
226        pix,
227        pp_w,
228        pp_b,
229        1,
230        in_dim,
231        1,
232        1,
233        [1, 1],
234        [0, 0],
235        feat_h,
236        feat_w,
237    );
238    let out = g.add(pix_y, m_down);
239
240    hir.set_outputs(vec![out]);
241    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
242    Ok((graph, params))
243}
244
245/// Append MaskDownSampler ops; `pfx` prefixes parameter names.
246fn append_mask_downsampler(
247    g: &mut HirMut<'_>,
248    params: &mut HashMap<String, Vec<f32>>,
249    mut x: HirNodeId,
250    w: &Sam2MaskDownSamplerWeights,
251    in_h: usize,
252    in_w: usize,
253    pfx: &str,
254) -> Result<(HirNodeId, usize, usize)> {
255    let mut cur_h = in_h;
256    let mut cur_w = in_w;
257
258    for (li, level) in w.levels.iter().enumerate() {
259        let out_h = (cur_h + 2 * w.padding - w.kernel) / w.stride + 1;
260        let out_w = (cur_w + 2 * w.padding - w.kernel) / w.stride + 1;
261        let k = w.kernel;
262        let cw = param(
263            g,
264            params,
265            &format!("{pfx}conv{li}_w"),
266            &level.conv_w,
267            &[level.out_c, level.in_c, k, k],
268        );
269        let cb = param(
270            g,
271            params,
272            &format!("{pfx}conv{li}_b"),
273            &level.conv_b,
274            &[level.out_c],
275        );
276        x = conv2d_bias(
277            g,
278            x,
279            cw,
280            cb,
281            1,
282            level.out_c,
283            k,
284            k,
285            [w.stride, w.stride],
286            [w.padding, w.padding],
287            out_h,
288            out_w,
289        );
290        let ln_g = param(
291            g,
292            params,
293            &format!("{pfx}ln{li}_g"),
294            &level.ln_g,
295            &[level.out_c],
296        );
297        let ln_b = param(
298            g,
299            params,
300            &format!("{pfx}ln{li}_b"),
301            &level.ln_b,
302            &[level.out_c],
303        );
304        x = layer_norm2d_nchw(g, x, ln_g, ln_b, LN_EPS);
305        x = g.gelu(x);
306        cur_h = out_h;
307        cur_w = out_w;
308    }
309
310    let last_c = w.levels.last().map(|l| l.out_c).unwrap_or(1);
311    let fw = param(
312        g,
313        params,
314        &format!("{pfx}final_w"),
315        &w.final_conv_w,
316        &[w.embed_dim, last_c, 1, 1],
317    );
318    let fb = param(
319        g,
320        params,
321        &format!("{pfx}final_b"),
322        &w.final_conv_b,
323        &[w.embed_dim],
324    );
325    let out = conv2d_bias(
326        g,
327        x,
328        fw,
329        fb,
330        1,
331        w.embed_dim,
332        1,
333        1,
334        [1, 1],
335        [0, 0],
336        cur_h,
337        cur_w,
338    );
339    Ok((out, cur_h, cur_w))
340}
341
342fn param(
343    g: &mut HirMut<'_>,
344    params: &mut HashMap<String, Vec<f32>>,
345    name: &str,
346    data: &[f32],
347    shape: &[usize],
348) -> HirNodeId {
349    let id = g.param(name, Shape::new(shape, DType::F32));
350    params.insert(name.to_string(), data.to_vec());
351    id
352}
353
354/// Compiled 1×1 conv on NCHW `[C, H, W]` flat.
355pub struct Sam2MemoryConv1x1Compiled {
356    graph: CompiledGraph,
357    in_c: usize,
358    pub out_c: usize,
359    pub h: usize,
360    pub w: usize,
361}
362
363impl Sam2MemoryConv1x1Compiled {
364    pub fn compile(
365        in_c: usize,
366        out_c: usize,
367        h: usize,
368        w: usize,
369        weight: &[f32],
370        bias: &[f32],
371        device: Device,
372    ) -> Result<Self> {
373        Self::compile_with_profile(
374            in_c,
375            out_c,
376            h,
377            w,
378            weight,
379            bias,
380            device,
381            &CompileProfile::sam_encoder(),
382        )
383    }
384
385    pub fn compile_with_profile(
386        in_c: usize,
387        out_c: usize,
388        h: usize,
389        w: usize,
390        weight: &[f32],
391        bias: &[f32],
392        device: Device,
393        profile: &CompileProfile,
394    ) -> Result<Self> {
395        let (graph, params) = build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?;
396        let mut compiled =
397            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
398        for (name, data) in &params {
399            compiled.set_param(name, data);
400        }
401        Ok(Self {
402            graph: compiled,
403            in_c,
404            out_c,
405            h,
406            w,
407        })
408    }
409
410    pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
411        let expected = self.in_c * self.h * self.w;
412        anyhow::ensure!(
413            x.len() == expected,
414            "conv1x1 input len {} ≠ {} ({}×{}×{})",
415            x.len(),
416            expected,
417            self.in_c,
418            self.h,
419            self.w
420        );
421        let outs = self.graph.run(&[("x", x)]);
422        Ok(outs.into_iter().next().expect("conv1x1 output"))
423    }
424}
425
426fn build_conv1x1_graph(
427    in_c: usize,
428    out_c: usize,
429    h: usize,
430    w: usize,
431    weight: &[f32],
432    bias: &[f32],
433) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
434    let f = DType::F32;
435    let mut hir = HirModule::new("sam2_conv1x1");
436    let mut params = HashMap::new();
437    let mut g = HirMut::new(&mut hir);
438
439    let x = g.input("x", nchw_shape(1, in_c, h, w, f));
440    let wt = param(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
441    let bt = param(&mut g, &mut params, "b", bias, &[out_c]);
442    let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
443
444    hir.set_outputs(vec![y]);
445    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
446    Ok((graph, params))
447}
448
449/// ConvNeXt-style `Fuser` (optional input 1×1 + `CXBlock` stack).
450pub struct Sam2MemoryFuserCompiled {
451    graph: CompiledGraph,
452    pub dim: usize,
453    pub h: usize,
454    pub w: usize,
455}
456
457impl Sam2MemoryFuserCompiled {
458    pub fn compile(w: &Sam2FuserWeights, h: usize, ww: usize, device: Device) -> Result<Self> {
459        Self::compile_with_profile(w, h, ww, device, &CompileProfile::sam_encoder())
460    }
461
462    pub fn compile_with_profile(
463        w: &Sam2FuserWeights,
464        h: usize,
465        ww: usize,
466        device: Device,
467        profile: &CompileProfile,
468    ) -> Result<Self> {
469        let (graph, params) = build_fuser_graph(w, h, ww)?;
470        let mut compiled =
471            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
472        for (name, data) in &params {
473            compiled.set_param(name, data);
474        }
475        Ok(Self {
476            graph: compiled,
477            dim: w.dim,
478            h,
479            w: ww,
480        })
481    }
482
483    pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
484        let expected = self.dim * self.h * self.w;
485        anyhow::ensure!(
486            x.len() == expected,
487            "fuser input len {} ≠ {expected}",
488            x.len()
489        );
490        let outs = self.graph.run(&[("x", x)]);
491        Ok(outs.into_iter().next().expect("fuser output"))
492    }
493}
494
495fn build_fuser_graph(
496    w: &Sam2FuserWeights,
497    h: usize,
498    ww: usize,
499) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
500    let f = DType::F32;
501    let dim = w.dim;
502    let mut hir = HirModule::new("sam2_memory_fuser");
503    let mut params = HashMap::new();
504    let mut g = HirMut::new(&mut hir);
505
506    let mut x = g.input("x", nchw_shape(1, dim, h, ww, f));
507
508    if let (Some(pw), Some(pb)) = (&w.input_proj_w, &w.input_proj_b) {
509        let wt = param(&mut g, &mut params, "input_proj_w", pw, &[dim, dim, 1, 1]);
510        let bt = param(&mut g, &mut params, "input_proj_b", pb, &[dim]);
511        x = conv2d_bias(&mut g, x, wt, bt, 1, dim, 1, 1, [1, 1], [0, 0], h, ww);
512    }
513
514    for (li, layer) in w.layers.iter().enumerate() {
515        x = cx_block_hir(&mut g, &mut params, x, layer, li, h, ww)?;
516    }
517
518    hir.set_outputs(vec![x]);
519    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
520    Ok((graph, params))
521}
522
523fn cx_block_hir(
524    g: &mut HirMut<'_>,
525    params: &mut HashMap<String, Vec<f32>>,
526    x: HirNodeId,
527    w: &Sam2CXBlockWeights,
528    li: usize,
529    h: usize,
530    ww: usize,
531) -> Result<HirNodeId> {
532    let dim = w.dim;
533    let k = w.kernel;
534    let p = w.padding;
535    let residual = x;
536
537    let dw_w = param(
538        g,
539        params,
540        &format!("l{li}_dw_w"),
541        &w.dw_conv_w,
542        &[dim, 1, k, k],
543    );
544    let dw_b = param(g, params, &format!("l{li}_dw_b"), &w.dw_conv_b, &[dim]);
545    let mut y = conv2d_bias_groups(g, x, dw_w, dw_b, 1, dim, k, k, [1, 1], [p, p], dim, h, ww);
546
547    let ln_g = param(g, params, &format!("l{li}_ln_g"), &w.ln_g, &[dim]);
548    let ln_b = param(g, params, &format!("l{li}_ln_b"), &w.ln_b, &[dim]);
549    y = layer_norm2d_nchw(g, y, ln_g, ln_b, LN_EPS);
550
551    let pw1_w = param(
552        g,
553        params,
554        &format!("l{li}_pw1_w"),
555        &w.pw1_w,
556        &[4 * dim, dim, 1, 1],
557    );
558    let pw1_b = param(g, params, &format!("l{li}_pw1_b"), &w.pw1_b, &[4 * dim]);
559    y = conv2d_bias(g, y, pw1_w, pw1_b, 1, 4 * dim, 1, 1, [1, 1], [0, 0], h, ww);
560    y = g.gelu(y);
561
562    let pw2_w = param(
563        g,
564        params,
565        &format!("l{li}_pw2_w"),
566        &w.pw2_w,
567        &[dim, 4 * dim, 1, 1],
568    );
569    let pw2_b = param(g, params, &format!("l{li}_pw2_b"), &w.pw2_b, &[dim]);
570    y = conv2d_bias(g, y, pw2_w, pw2_b, 1, dim, 1, 1, [1, 1], [0, 0], h, ww);
571
572    if let Some(gamma) = &w.gamma {
573        let gparam = param(g, params, &format!("l{li}_gamma"), gamma, &[dim]);
574        let out_shape = g.shape(y).clone();
575        let g4 = g.reshape_(gparam, vec![1, dim as i64, 1, 1]);
576        let scaled = g.add_node(
577            Op::Expand {
578                target_shape: vec![1, dim as i64, h as i64, ww as i64],
579            },
580            vec![g4],
581            out_shape.clone(),
582        );
583        y = g.mul(y, scaled);
584    }
585
586    Ok(g.add(residual, y))
587}