use super::config::SAM_EMBED_HW;
use super::prompt_encoder::PromptEncoderWeights;
use anyhow::Result;
use rlx_flow::CompileProfile;
use rlx_runtime::Device;
use rlx_sam_ir::mask_prompt_ir::{MaskDownscaleWeights, PromptMaskCompiled};
/// Newtype around [`PromptMaskCompiled`] that is built from the mask-related
/// fields of a [`PromptEncoderWeights`].
pub struct SamPromptMaskCompiled(PromptMaskCompiled);

impl SamPromptMaskCompiled {
    /// Compiles the wrapped graph for `device` using the profile returned by
    /// [`CompileProfile::sam_encoder`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::compile_with_profile`].
    pub fn compile(w: &PromptEncoderWeights, device: Device) -> Result<Self> {
        let default_profile = CompileProfile::sam_encoder();
        Self::compile_with_profile(w, device, &default_profile)
    }

    /// Compiles the wrapped graph for `device` with a caller-supplied `profile`.
    ///
    /// The same value, `4 * SAM_EMBED_HW`, is passed for both size arguments
    /// of the underlying compile call.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by
    /// `PromptMaskCompiled::compile_with_profile`.
    pub fn compile_with_profile(
        w: &PromptEncoderWeights,
        device: Device,
        profile: &CompileProfile,
    ) -> Result<Self> {
        let side = 4 * SAM_EMBED_HW;
        let inner = PromptMaskCompiled::compile_with_profile(
            mask_weights(w),
            side,
            side,
            device,
            profile,
        )?;
        Ok(Self(inner))
    }

    /// Forwards `mask`, `in_h` and `in_w` unchanged to the wrapped
    /// [`PromptMaskCompiled::run`] and returns its result.
    ///
    /// NOTE(review): the graph is compiled for a `4 * SAM_EMBED_HW` square
    /// input; whether other `in_h`/`in_w` values are accepted is decided by
    /// the wrapped implementation — confirm there.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped `run` call.
    pub fn run(&mut self, mask: &[f32], in_h: usize, in_w: usize) -> Result<Vec<f32>> {
        let Self(inner) = self;
        inner.run(mask, in_h, in_w)
    }
}
/// Builds a [`MaskDownscaleWeights`] view over `weights`.
///
/// The two dimension fields are copied by value; every `mask_*` parameter
/// field is borrowed from `weights`, so the result cannot outlive it.
fn mask_weights(weights: &PromptEncoderWeights) -> MaskDownscaleWeights<'_> {
    MaskDownscaleWeights {
        // Dimensions.
        mask_in_chans: weights.mask_in_chans,
        embed_dim: weights.embed_dim,
        // First conv + layer-norm pair.
        mask_conv1_w: &weights.mask_conv1_w,
        mask_conv1_b: &weights.mask_conv1_b,
        mask_ln1_g: &weights.mask_ln1_g,
        mask_ln1_b: &weights.mask_ln1_b,
        // Second conv + layer-norm pair.
        mask_conv2_w: &weights.mask_conv2_w,
        mask_conv2_b: &weights.mask_conv2_b,
        mask_ln2_g: &weights.mask_ln2_g,
        mask_ln2_b: &weights.mask_ln2_b,
        // Third conv.
        mask_conv3_w: &weights.mask_conv3_w,
        mask_conv3_b: &weights.mask_conv3_b,
    }
}