// rlx-sam 0.2.4
//
// Segment Anything Model (SAM v1) for RLX
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! SAM v1 prompt-encoder mask downscale (IR).

use super::config::SAM_EMBED_HW;
use super::prompt_encoder::PromptEncoderWeights;
use anyhow::Result;
use rlx_flow::CompileProfile;
use rlx_runtime::Device;
use rlx_sam_ir::mask_prompt_ir::{MaskDownscaleWeights, PromptMaskCompiled};

/// Newtype wrapper around [`PromptMaskCompiled`] (from `rlx_sam_ir`).
///
/// Instances are built from [`PromptEncoderWeights`] through the `compile*`
/// constructors on this type; the wrapped value is private and is only
/// reachable through the methods defined here.
pub struct SamPromptMaskCompiled(PromptMaskCompiled);

impl SamPromptMaskCompiled {
    /// Compiles the mask-downscale graph using the default
    /// [`CompileProfile::sam_encoder`] profile.
    ///
    /// This is a convenience front-end for [`Self::compile_with_profile`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Self::compile_with_profile`] reports.
    pub fn compile(w: &PromptEncoderWeights, device: Device) -> Result<Self> {
        let default_profile = CompileProfile::sam_encoder();
        Self::compile_with_profile(w, device, &default_profile)
    }

    /// Compiles the mask-downscale graph for `device` under an explicit
    /// compile `profile`.
    ///
    /// The input extent handed to the underlying compiler is square:
    /// `4 * SAM_EMBED_HW` is used for both the height and the width.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by
    /// `PromptMaskCompiled::compile_with_profile`.
    pub fn compile_with_profile(
        w: &PromptEncoderWeights,
        device: Device,
        profile: &CompileProfile,
    ) -> Result<Self> {
        // One value serves as both input height and input width.
        let input_side = 4 * SAM_EMBED_HW;
        let weights_view = mask_weights(w);
        let compiled = PromptMaskCompiled::compile_with_profile(
            weights_view,
            input_side,
            input_side,
            device,
            profile,
        )?;
        Ok(Self(compiled))
    }

    /// Runs the compiled graph on `mask`, forwarding `mask`, `in_h` and
    /// `in_w` unchanged to the wrapped [`PromptMaskCompiled`].
    ///
    /// NOTE(review): compilation fixes the input extent to
    /// `4 * SAM_EMBED_HW` on both axes; presumably `in_h` / `in_w` are
    /// expected to match that — confirm against `PromptMaskCompiled::run`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped `run` call reports.
    pub fn run(&mut self, mask: &[f32], in_h: usize, in_w: usize) -> Result<Vec<f32>> {
        let engine = &mut self.0;
        engine.run(mask, in_h, in_w)
    }
}

/// Builds a borrowed [`MaskDownscaleWeights`] view over the mask-related
/// fields of `w`.
///
/// The two dimension fields (`mask_in_chans`, `embed_dim`) are copied by
/// value; every weight/bias/gain field is passed along as a reference tied
/// to the lifetime of `w`, so nothing is cloned here.
fn mask_weights<'a>(w: &'a PromptEncoderWeights) -> MaskDownscaleWeights<'a> {
    // Scalar dimensions.
    let in_chans = w.mask_in_chans;
    let dim = w.embed_dim;

    // `conv1` / `ln1` parameter group.
    let (conv1_w, conv1_b) = (&w.mask_conv1_w, &w.mask_conv1_b);
    let (ln1_g, ln1_b) = (&w.mask_ln1_g, &w.mask_ln1_b);

    // `conv2` / `ln2` parameter group.
    let (conv2_w, conv2_b) = (&w.mask_conv2_w, &w.mask_conv2_b);
    let (ln2_g, ln2_b) = (&w.mask_ln2_g, &w.mask_ln2_b);

    // `conv3` parameters (no matching `ln3` fields are forwarded).
    let (conv3_w, conv3_b) = (&w.mask_conv3_w, &w.mask_conv3_b);

    MaskDownscaleWeights {
        mask_in_chans: in_chans,
        embed_dim: dim,
        mask_conv1_w: conv1_w,
        mask_conv1_b: conv1_b,
        mask_ln1_g: ln1_g,
        mask_ln1_b: ln1_b,
        mask_conv2_w: conv2_w,
        mask_conv2_b: conv2_b,
        mask_ln2_g: ln2_g,
        mask_ln2_b: ln2_b,
        mask_conv3_w: conv3_w,
        mask_conv3_b: conv3_b,
    }
}