use crate::gguf_ir::{linear_gguf_bias, packed_linear_for_key};
use anyhow::Result;
use rlx_core::vision_ops_ir::{conv2d_bias, nchw_shape};
use rlx_flow::{CompileProfile, GgufPackedParams};
use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
use rlx_ir::{DType, Graph, HirGraphExt, Shape};
use rlx_runtime::{CompiledGraph, Device};
use std::collections::HashMap;
/// Channel count of the pixel-decoder feature maps (`prev`, `curr`, and the 3×3 conv).
const D_MODEL: usize = 256;
/// Number of groups used by the pixel-decoder GroupNorm.
const GN_GROUPS: usize = 8;
/// Result of the 1×1-conv graph builders.
///
/// Contains the lowered graph, the F32 parameter values keyed by parameter name, and a list of
/// typed parameters (presumably `(name, raw bytes, dtype)` — confirm against
/// `attach_built_params`).
type ConvGraphParts = (Graph, HashMap<String, Vec<f32>>, Vec<(String, Vec<u8>, DType)>);
/// A compiled single pixel-decoder step.
///
/// The step computes
/// `relu(group_norm(conv3x3(curr + resize_nearest_2x(prev))))`, with every feature map at
/// `D_MODEL` channels.
pub struct Sam3PixelDecoderStepCompiled {
    graph: CompiledGraph,
    /// Output height; always twice the previous level's height.
    pub out_h: usize,
    /// Output width; always twice the previous level's width.
    pub out_w: usize,
}

impl Sam3PixelDecoderStepCompiled {
    /// Compiles the step using `CompileProfile::sam3()`.
    ///
    /// See [`Self::compile_with_profile`] for parameters and errors.
    pub fn compile(
        prev_h: usize,
        prev_w: usize,
        out_h: usize,
        out_w: usize,
        conv_w: &[f32],
        conv_b: &[f32],
        gn_w: &[f32],
        gn_b: &[f32],
        device: Device,
    ) -> Result<Self> {
        Self::compile_with_profile(
            prev_h,
            prev_w,
            out_h,
            out_w,
            conv_w,
            conv_b,
            gn_w,
            gn_b,
            device,
            &CompileProfile::sam3(),
        )
    }

    /// Builds, compiles and uploads parameters for one pixel-decoder step.
    ///
    /// `prev_h`×`prev_w` is the coarser input level and `out_h`×`out_w` the output level.
    /// `conv_w`/`conv_b` are the 3×3 conv weight and bias, and `gn_w`/`gn_b` the GroupNorm
    /// affine parameters.
    ///
    /// # Errors
    /// Fails if the output size is not exactly twice the previous size, or if building or
    /// compiling the graph fails.
    pub fn compile_with_profile(
        prev_h: usize,
        prev_w: usize,
        out_h: usize,
        out_w: usize,
        conv_w: &[f32],
        conv_b: &[f32],
        gn_w: &[f32],
        gn_b: &[f32],
        device: Device,
        profile: &CompileProfile,
    ) -> Result<Self> {
        anyhow::ensure!(
            out_h == prev_h * 2 && out_w == prev_w * 2,
            "pixel_decoder step expects 2× upsample {prev_h}×{prev_w} → {out_h}×{out_w}"
        );
        let (graph, params) =
            build_pixel_step_graph(prev_h, prev_w, out_h, out_w, conv_w, conv_b, gn_w, gn_b)?;
        let mut compiled =
            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
        for (name, data) in &params {
            compiled.set_param(name, data);
        }
        Ok(Self {
            graph: compiled,
            out_h,
            out_w,
        })
    }

    /// Runs the step.
    ///
    /// `prev` holds the `[1, D_MODEL, out_h / 2, out_w / 2]` coarser level and `curr` the
    /// `[1, D_MODEL, out_h, out_w]` current level. Returns the first graph output.
    ///
    /// # Errors
    /// Fails if either input has the wrong length, or if the graph yields no output.
    pub fn run(&mut self, prev: &[f32], curr: &[f32]) -> Result<Vec<f32>> {
        let curr_len = D_MODEL * self.out_h * self.out_w;
        // Compilation guarantees out = 2 × prev in both axes, so prev is a quarter of curr.
        let prev_len = curr_len / 4;
        anyhow::ensure!(
            prev.len() == prev_len,
            "pixel_decoder step: prev has {} values, expected {prev_len}",
            prev.len()
        );
        anyhow::ensure!(
            curr.len() == curr_len,
            "pixel_decoder step: curr has {} values, expected {curr_len}",
            curr.len()
        );
        let outs = self.graph.run(&[("prev", prev), ("curr", curr)]);
        outs.into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("pixel_decoder step: graph produced no output"))
    }
}
pub struct Sam3Conv1x1Compiled {
graph: CompiledGraph,
pub out_c: usize,
pub h: usize,
pub w: usize,
}
impl Sam3Conv1x1Compiled {
pub fn compile(
in_c: usize,
out_c: usize,
h: usize,
w: usize,
weight: &[f32],
bias: &[f32],
device: Device,
) -> Result<Self> {
Self::compile_with_profile(
in_c,
out_c,
h,
w,
weight,
bias,
device,
&CompileProfile::sam3(),
)
}
pub fn compile_with_profile(
in_c: usize,
out_c: usize,
h: usize,
w: usize,
weight: &[f32],
bias: &[f32],
device: Device,
profile: &CompileProfile,
) -> Result<Self> {
let (graph, params, typed) = build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?;
let mut compiled =
rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
for (name, data) in ¶ms {
compiled.set_param(name, data);
}
rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
Ok(Self {
graph: compiled,
out_c,
h,
w,
})
}
pub fn compile_with_gguf(
in_c: usize,
out_c: usize,
h: usize,
w: usize,
weight: &[f32],
bias: &[f32],
gguf_key: Option<&str>,
gguf_packed: &GgufPackedParams,
device: Device,
profile: &CompileProfile,
) -> Result<Self> {
let (graph, params, typed) = if let (Some(key), Some(p)) = (
gguf_key,
gguf_key.and_then(|k| packed_linear_for_key(Some(gguf_packed), k)),
) {
build_conv1x1_graph_gguf(in_c, out_c, h, w, p, bias, key)?
} else {
anyhow::ensure!(
!weight.is_empty(),
"conv1x1: missing F32 weights and GGUF key"
);
build_conv1x1_graph(in_c, out_c, h, w, weight, bias)?
};
let mut compiled =
rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
for (name, data) in ¶ms {
compiled.set_param(name, data);
}
rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
Ok(Self {
graph: compiled,
out_c,
h,
w,
})
}
pub fn run(&mut self, x: &[f32]) -> Result<Vec<f32>> {
anyhow::ensure!(x.len() == D_MODEL * self.h * self.w);
let outs = self.graph.run(&[("x", x)]);
Ok(outs.into_iter().next().expect("conv1x1 output"))
}
}
/// Builds the HIR graph for one pixel-decoder step.
///
/// The step computes `relu(group_norm(conv3x3(curr + resize_nearest_2x(prev))))`. Inputs are
/// `prev` (`[1, D_MODEL, prev_h, prev_w]`) and `curr` (`[1, D_MODEL, out_h, out_w]`).
///
/// Returns the graph plus the F32 parameter values (`conv_w`, `conv_b`, `gn_w`, `gn_b`)
/// keyed by parameter name.
///
/// # Errors
/// Fails in any of these cases:
/// - `out_h`×`out_w` is not exactly twice `prev_h`×`prev_w`;
/// - a parameter slice's length does not match its declared shape;
/// - the HIR cannot be lowered to a `Graph`.
fn build_pixel_step_graph(
    prev_h: usize,
    prev_w: usize,
    out_h: usize,
    out_w: usize,
    conv_w: &[f32],
    conv_b: &[f32],
    gn_w: &[f32],
    gn_b: &[f32],
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    anyhow::ensure!(
        out_h == prev_h * 2 && out_w == prev_w * 2,
        "pixel_decoder step expects 2× upsample {prev_h}×{prev_w} → {out_h}×{out_w}"
    );
    // Reject parameter data that cannot fill the declared param shapes, before it reaches
    // the compiled graph.
    let conv_w_len = D_MODEL * D_MODEL * 3 * 3;
    anyhow::ensure!(
        conv_w.len() == conv_w_len,
        "pixel_decoder step: conv_w has {} values, expected {conv_w_len}",
        conv_w.len()
    );
    for (name, data) in [("conv_b", conv_b), ("gn_w", gn_w), ("gn_b", gn_b)] {
        anyhow::ensure!(
            data.len() == D_MODEL,
            "pixel_decoder step: {name} has {} values, expected {D_MODEL}",
            data.len()
        );
    }

    let f = DType::F32;
    let mut hir = HirModule::new("sam3_pixel_decoder_step");
    let mut params = HashMap::new();
    let mut g = HirMut::new(&mut hir);
    let prev = g.input("prev", Shape::new(&[1, D_MODEL, prev_h, prev_w], f));
    let curr = g.input("curr", Shape::new(&[1, D_MODEL, out_h, out_w], f));
    // Bring the coarser level up to this level's resolution and fuse by addition.
    let up = g.resize_nearest_2x(prev);
    let combined = g.add(curr, up);
    let cw = param_f32(
        &mut g,
        &mut params,
        "conv_w",
        conv_w,
        &[D_MODEL, D_MODEL, 3, 3],
    );
    let cb = param_f32(&mut g, &mut params, "conv_b", conv_b, &[D_MODEL]);
    // 3×3 kernel with [1, 1] stride/padding; the output size is out_h × out_w.
    let y = conv2d_bias(
        &mut g,
        combined,
        cw,
        cb,
        1,
        D_MODEL,
        3,
        3,
        [1, 1],
        [1, 1],
        out_h,
        out_w,
    );
    let gnw = param_f32(&mut g, &mut params, "gn_w", gn_w, &[D_MODEL]);
    let gnb = param_f32(&mut g, &mut params, "gn_b", gn_b, &[D_MODEL]);
    let y = g.group_norm(y, gnw, gnb, GN_GROUPS, 1e-5);
    let out = g.relu(y);
    hir.set_outputs(vec![out]);
    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
    Ok((graph, params))
}
/// Builds the F32 HIR graph for a 1×1 convolution with bias.
///
/// The input `x` is `[1, in_c, h, w]`, the weight `w` is `[out_c, in_c, 1, 1]` and the bias
/// `b` is `[out_c]`. The output keeps the `h`×`w` spatial size.
///
/// Returns the graph, the F32 params (`"w"`, `"b"`) and an empty typed-param list.
///
/// # Errors
/// Fails if `weight` or `bias` does not match its declared shape, or if the HIR cannot be
/// lowered to a `Graph`.
fn build_conv1x1_graph(
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Result<ConvGraphParts> {
    let weight_len = out_c * in_c;
    anyhow::ensure!(
        weight.len() == weight_len,
        "conv1x1: weight has {} values, expected {weight_len} ({out_c}×{in_c}×1×1)",
        weight.len()
    );
    anyhow::ensure!(
        bias.len() == out_c,
        "conv1x1: bias has {} values, expected {out_c}",
        bias.len()
    );

    let f = DType::F32;
    let mut hir = HirModule::new("sam3_conv1x1");
    let mut params = HashMap::new();
    let mut g = HirMut::new(&mut hir);
    let x = g.input("x", nchw_shape(1, in_c, h, w, f));
    let wt = param_f32(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
    let bt = param_f32(&mut g, &mut params, "b", bias, &[out_c]);
    // 1×1 kernel, stride [1, 1], no padding: spatial size is unchanged.
    let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
    hir.set_outputs(vec![y]);
    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
    // Pure F32 path: there are no typed (non-F32) params to attach.
    Ok((graph, params, Vec::new()))
}
/// Builds the HIR graph for a 1×1 convolution whose weight comes from a packed GGUF tensor.
///
/// The conv is expressed as a linear layer, in three steps:
/// 1. The `[1, in_c, h, w]` input is reshaped to `[1, h*w, in_c]`.
/// 2. It passes through `linear_gguf_bias`, with packed weight `p` and F32 `bias`.
/// 3. The `[1, h*w, out_c]` result is reshaped to `[1, out_c, h, w]`.
///
/// The tensor stem handed to `linear_gguf_bias` is `gguf_key` with any trailing `.weight`
/// removed. Returns the graph plus the F32 params map and typed-param list, presumably as
/// populated by `linear_gguf_bias` (both are passed to it mutably).
///
/// NOTE(review): both reshapes only reinterpret the buffer; neither moves the channel axis.
/// With a `[1, in_c, h, w]` input, row `i` of the `[1, h*w, in_c]` view is not pixel `i`'s
/// channel vector. Unless `reshape_` is layout-aware, a transpose seems to be needed:
/// `[1, in_c, h*w]` → `[1, h*w, in_c]` on the input, and the inverse on the output.
/// Verify against the F32 `conv2d_bias` path (`build_conv1x1_graph`) before relying on this.
fn build_conv1x1_graph_gguf(
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    p: &rlx_flow::GgufPackedLinear,
    bias: &[f32],
    gguf_key: &str,
) -> Result<ConvGraphParts> {
    let f = DType::F32;
    let mut hir = HirModule::new("sam3_conv1x1_gguf");
    let mut params = HashMap::new();
    let mut typed = Vec::new();
    // Scratch cache for linear_gguf_bias; it is discarded when this function returns.
    let mut gguf_cache = HashMap::new();
    let mut g = HirMut::new(&mut hir);
    let x = g.input("x", nchw_shape(1, in_c, h, w, f));
    let spatial = (h * w) as i64;
    let flat = g.reshape_(x, vec![1, spatial, in_c as i64]);
    // Tensor name prefix without the ".weight" suffix.
    let stem = gguf_key.strip_suffix(".weight").unwrap_or(gguf_key);
    let y_flat = linear_gguf_bias(
        &mut g,
        &mut params,
        &mut typed,
        &mut gguf_cache,
        stem,
        p,
        flat,
        bias,
        in_c,
        out_c,
    )?;
    let y = g.reshape_(y_flat, vec![1, out_c as i64, h as i64, w as i64]);
    hir.set_outputs(vec![y]);
    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
    Ok((graph, params, typed))
}
/// Declares an F32 graph parameter and records its data.
///
/// The parameter is named `name` and has the given `shape`. A copy of `data` is stored in
/// `params` under the same name. Returns the HIR node for the new parameter.
fn param_f32(
    g: &mut HirMut<'_>,
    params: &mut HashMap<String, Vec<f32>>,
    name: &str,
    data: &[f32],
    shape: &[usize],
) -> HirNodeId {
    params.insert(name.to_owned(), data.to_vec());
    g.param(name, Shape::new(shape, DType::F32))
}
pub fn compile_pixel_decoder_steps(
pixel_conv_w: &[Vec<f32>],
pixel_conv_b: &[Vec<f32>],
pixel_gn_w: &[Vec<f32>],
pixel_gn_b: &[Vec<f32>],
trunk_grid: usize,
device: Device,
profile: &CompileProfile,
) -> Result<Vec<Sam3PixelDecoderStepCompiled>> {
let g0 = trunk_grid;
let g1 = trunk_grid * 2;
let g2 = trunk_grid * 4;
let steps = [(g0, g0, g1, g1, 0usize), (g1, g1, g2, g2, 1usize)];
steps
.iter()
.map(|(ph, pw, oh, ow, i)| {
Sam3PixelDecoderStepCompiled::compile_with_profile(
*ph,
*pw,
*oh,
*ow,
&pixel_conv_w[*i],
&pixel_conv_b[*i],
&pixel_gn_w[*i],
&pixel_gn_b[*i],
device,
profile,
)
})
.collect()
}