use super::config::{SAM_EMBED_HW, SamEncoderConfig};
use super::preprocess::{SamPreprocessWeights, extract_preprocess_weights};
use anyhow::{Result, anyhow, ensure};
use rlx_core::vision_ops_ir::{bhwc_to_nchw, conv2d_bias, conv2d_no_bias, layer_norm2d_nchw};
use rlx_core::weight_map::WeightMap;
use rlx_ir::HirGraphExt;
use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
use rlx_ir::*;
use std::collections::HashMap;
/// Mutable state threaded through the encoder graph construction: the HIR
/// module being assembled plus the data for every parameter registered in it.
struct SamBuilder {
/// HIR module under construction.
hir: HirModule,
/// Parameter name -> flat `f32` data, filled by `load_p` and the `*_param` helpers.
params: HashMap<String, Vec<f32>>,
}
impl SamBuilder {
    /// Creates a builder holding a fresh HIR module called `name` and an
    /// empty parameter table.
    fn new(name: &str) -> Self {
        let hir = HirModule::new(name);
        let params = HashMap::new();
        Self { hir, params }
    }

    /// Returns a short-lived mutable handle over the module under
    /// construction; every node-creating call goes through it.
    fn m(&mut self) -> HirMut<'_> {
        let module = &mut self.hir;
        HirMut::new(module)
    }
}
/// Lowers a finished HIR module into a `Graph`, re-wrapping the lowering
/// error as an `anyhow` error that carries the original message.
#[allow(dead_code)]
fn lower_hir(hir: HirModule) -> Result<Graph> {
    match Graph::from_hir(hir) {
        Ok(graph) => Ok(graph),
        Err(e) => Err(anyhow!("{e}")),
    }
}
pub fn build_sam_encoder_hir(
cfg: &SamEncoderConfig,
weights: &mut WeightMap,
) -> Result<(HirModule, HashMap<String, Vec<f32>>, SamPreprocessWeights)> {
let mut b = SamBuilder::new("sam_image_encoder");
let f = DType::F32;
let preprocess = extract_preprocess_weights(weights, cfg)?;
let e = cfg.embed_dim;
let nh = cfg.num_heads;
let dh = cfg.head_dim();
let scale = 1.0 / (dh as f32).sqrt();
let eps = cfg.layer_norm_eps as f32;
let hw = SAM_EMBED_HW;
let s = hw * hw;
let hidden_input = b.m().input("hidden", Shape::new(&[1, s, e], f));
let mut x = hidden_input;
for layer_idx in 0..cfg.depth {
let lp = format!("image_encoder.blocks.{layer_idx}");
let is_global = cfg.global_attn_indexes.contains(&layer_idx);
let ws = if is_global { 0 } else { cfg.window_size };
let n1_g = load_p(&mut b, weights, &format!("{lp}.norm1.weight"), false)?;
let n1_b = load_p(&mut b, weights, &format!("{lp}.norm1.bias"), false)?;
let normed = b.m().ln(x, n1_g, n1_b, eps);
let attn_out = if ws == 0 {
attention_global(
&mut b,
weights,
&lp,
normed,
e,
nh,
dh,
scale,
hw,
cfg.use_rel_pos,
cfg.qkv_bias,
)?
} else {
attention_windowed(
&mut b,
weights,
&lp,
normed,
e,
nh,
dh,
scale,
hw,
ws,
cfg.use_rel_pos,
cfg.qkv_bias,
)?
};
x = b.m().add(x, attn_out);
let n2_g = load_p(&mut b, weights, &format!("{lp}.norm2.weight"), false)?;
let n2_b = load_p(&mut b, weights, &format!("{lp}.norm2.bias"), false)?;
let normed2 = b.m().ln(x, n2_g, n2_b, eps);
let fc1_w = load_p(&mut b, weights, &format!("{lp}.mlp.lin1.weight"), true)?;
let fc1_b = load_p(&mut b, weights, &format!("{lp}.mlp.lin1.bias"), false)?;
let fc2_w = load_p(&mut b, weights, &format!("{lp}.mlp.lin2.weight"), true)?;
let fc2_b = load_p(&mut b, weights, &format!("{lp}.mlp.lin2.bias"), false)?;
let up_mm = b.m().mm(normed2, fc1_w);
let up = b.m().add(up_mm, fc1_b);
let act = b.m().gelu(up);
let down_mm = b.m().mm(act, fc2_w);
let ffn = b.m().add(down_mm, fc2_b);
x = b.m().add(x, ffn);
}
let oc = cfg.out_chans;
let nchw = bhwc_to_nchw(&mut b.m(), x, 1, hw, hw, e);
let c1_w = load_p(&mut b, weights, "image_encoder.neck.0.weight", false)?;
let c1_b = load_p(&mut b, weights, "image_encoder.neck.0.bias", false)?;
let feat = conv2d_bias(
&mut b.m(),
nchw,
c1_w,
c1_b,
1,
oc,
1,
1,
[1, 1],
[0, 0],
hw,
hw,
);
let ln1_g = load_p(&mut b, weights, "image_encoder.neck.1.weight", false)?;
let ln1_b = load_p(&mut b, weights, "image_encoder.neck.1.bias", false)?;
let feat = layer_norm2d_nchw(&mut b.m(), feat, ln1_g, ln1_b, eps);
let c2_w = load_p(&mut b, weights, "image_encoder.neck.2.weight", false)?;
let feat = conv2d_no_bias(&mut b.m(), feat, c2_w, 1, oc, 3, 3, [1, 1], [1, 1], hw, hw);
let ln2_g = load_p(&mut b, weights, "image_encoder.neck.3.weight", false)?;
let ln2_b = load_p(&mut b, weights, "image_encoder.neck.3.bias", false)?;
let out = layer_norm2d_nchw(&mut b.m(), feat, ln2_g, ln2_b, eps);
b.hir.set_outputs(vec![out]);
Ok((b.hir, b.params, preprocess))
}
/// Builds the SAM encoder through `super::flow::build_sam_encoder_built` and
/// converts the built model into a `Graph` plus its parameter table.
///
/// # Returns
/// `(graph, parameter table, preprocess weights)`.
///
/// # Errors
/// Propagates failures from the flow builder and from
/// `rlx_core::flow_util::graph_from_built`.
pub fn build_sam_encoder_graph(
    cfg: &SamEncoderConfig,
    weights: &mut WeightMap,
) -> Result<(Graph, HashMap<String, Vec<f32>>, SamPreprocessWeights)> {
    let built = super::flow::build_sam_encoder_built(cfg, weights)?;
    let preprocess = built.preprocess;
    let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
    Ok((graph, params, preprocess))
}
/// Attention over the whole `hw x hw` token grid.
///
/// Thin wrapper around `decomposed_attention` that treats the full grid as a
/// single window: batch 1, grid size `hw x hw`, sequence length `hw * hw`.
#[allow(clippy::too_many_arguments)]
fn attention_global(
    sb: &mut SamBuilder,
    w: &mut WeightMap,
    lp: &str,
    x: HirNodeId,
    e: usize,
    nh: usize,
    dh: usize,
    scale: f32,
    hw: usize,
    use_rel_pos: bool,
    qkv_bias: bool,
) -> Result<HirNodeId> {
    let (grid_h, grid_w) = (hw, hw);
    let seq_len = grid_h * grid_w;
    let batch = 1;
    decomposed_attention(
        sb,
        w,
        lp,
        x,
        e,
        nh,
        dh,
        scale,
        grid_h,
        grid_w,
        seq_len,
        batch,
        use_rel_pos,
        qkv_bias,
    )
}
/// Attention restricted to non-overlapping `ws x ws` windows of the
/// `hw x hw` token grid.
///
/// Steps:
/// 1. reshape the `[1, hw*hw, e]` tokens into a `[1, hw, hw, e]` grid;
/// 2. zero-pad bottom and right so each side is a multiple of `ws`
///    (the zeros are registered as parameters `{lp}.attn._pad_h` / `_pad_w`);
/// 3. cut the grid into windows and run `decomposed_attention` with the
///    windows as the batch dimension;
/// 4. stitch the windows back together, crop the padding, and flatten to
///    `[1, hw*hw, e]`.
///
/// `ws` must be non-zero; the caller routes `ws == 0` to `attention_global`.
#[allow(clippy::too_many_arguments)]
fn attention_windowed(
    sb: &mut SamBuilder,
    w: &mut WeightMap,
    lp: &str,
    x: HirNodeId,
    e: usize,
    nh: usize,
    dh: usize,
    scale: f32,
    hw: usize,
    ws: usize,
    use_rel_pos: bool,
    qkv_bias: bool,
) -> Result<HirNodeId> {
    let (e_i, ws_i, hw_i) = (e as i64, ws as i64, hw as i64);
    let grid = sb.m().reshape_(x, vec![1, hw_i, hw_i, e_i]);

    // Padding needed to reach the next multiple of `ws` (0 if already aligned).
    let pad = (ws - hw % ws) % ws;
    let padded_hw = hw + pad;
    let side = padded_hw / ws;
    let window_count = side * side;
    let (side_i, count_i, padded_i) = (side as i64, window_count as i64, padded_hw as i64);

    let padded = if pad == 0 {
        grid
    } else {
        // Rows first (axis 1), then columns (axis 2) of the already-taller grid.
        let bottom = pad_zero_param(sb, &format!("{lp}.attn._pad_h"), &[1, pad, hw, e]);
        let tall = sb.m().concat_(vec![grid, bottom], 1);
        let right = pad_zero_param(sb, &format!("{lp}.attn._pad_w"), &[1, padded_hw, pad, e]);
        sb.m().concat_(vec![tall, right], 2)
    };

    // Window partition: [1, side, ws, side, ws, e] -> [side*side, ws*ws, e].
    let split = sb
        .m()
        .reshape_(padded, vec![1, side_i, ws_i, side_i, ws_i, e_i]);
    let grouped = sb.m().transpose_(split, vec![0, 1, 3, 2, 4, 5]);
    let windows = sb.m().reshape_(grouped, vec![count_i, ws_i, ws_i, e_i]);
    let tokens = sb
        .m()
        .reshape_(windows, vec![count_i, (ws * ws) as i64, e_i]);

    let attended = decomposed_attention(
        sb,
        w,
        lp,
        tokens,
        e,
        nh,
        dh,
        scale,
        ws,
        ws,
        ws * ws,
        window_count,
        use_rel_pos,
        qkv_bias,
    )?;

    // Inverse of the partition above.
    let merged = sb.m().reshape_(attended, vec![count_i, ws_i, ws_i, e_i]);
    let merged = sb
        .m()
        .reshape_(merged, vec![1, side_i, side_i, ws_i, ws_i, e_i]);
    let merged = sb.m().transpose_(merged, vec![0, 1, 3, 2, 4, 5]);
    let merged = sb.m().reshape_(merged, vec![1, padded_i, padded_i, e_i]);

    // Drop the padded rows and columns again.
    let cropped = if pad == 0 {
        merged
    } else {
        let rows = sb.m().narrow_(merged, 1, 0, hw);
        sb.m().narrow_(rows, 2, 0, hw)
    };

    Ok(sb.m().reshape_(cropped, vec![1, (hw * hw) as i64, e_i]))
}
/// Multi-head self-attention with an optional decomposed relative-position
/// term, applied to `x` of shape `[batch, s, e]` where `s` is expected to be
/// `h * w_dim`.
///
/// Parameters are read from `{lp}.attn.qkv.*`, `{lp}.attn.proj.*` and, when
/// `use_rel_pos` is set, `{lp}.attn.rel_pos_h` / `rel_pos_w`. The scale is
/// registered as the scalar parameter `{lp}.attn._scale`.
///
/// The debug flags `RLX_SAM_DEBUG_ZERO_RELPOS`, `RLX_SAM_DEBUG_ZERO_RELH` and
/// `RLX_SAM_DEBUG_ZERO_RELW` zero out both / the height / the width
/// relative-position tables respectively.
///
/// # Errors
/// Fails when a required tensor is missing or a relative-position table has
/// an unexpected shape.
#[allow(clippy::too_many_arguments)]
fn decomposed_attention(
    sb: &mut SamBuilder,
    w: &mut WeightMap,
    lp: &str,
    x: HirNodeId,
    e: usize,
    nh: usize,
    dh: usize,
    scale: f32,
    h: usize,
    w_dim: usize,
    s: usize,
    batch: usize,
    use_rel_pos: bool,
    qkv_bias: bool,
) -> Result<HirNodeId> {
    // --- fused QKV projection ---
    let qkv_weight = load_p(sb, w, &format!("{lp}.attn.qkv.weight"), true)?;
    let qkv_bias_node = match qkv_bias {
        true => Some(load_p(sb, w, &format!("{lp}.attn.qkv.bias"), false)?),
        false => None,
    };
    let projected = sb.m().mm(x, qkv_weight);
    let qkv = match qkv_bias_node {
        Some(b) => sb.m().add(projected, b),
        None => projected,
    };

    // --- split into per-head Q, K, V of shape [batch*nh, s, dh] ---
    let heads = (batch * nh) as i64;
    let (s_i, dh_i) = (s as i64, dh as i64);
    let qkv5 = sb
        .m()
        .reshape_(qkv, vec![batch as i64, s_i, 3, nh as i64, dh_i]);
    let qkv_perm = sb.m().transpose_(qkv5, vec![2, 0, 3, 1, 4]);
    let stacked = sb.m().reshape_(qkv_perm, vec![3, heads, s_i, dh_i]);
    let [q, k, v] = [0usize, 1, 2].map(|slot| {
        let part = sb.m().narrow_(stacked, 0, slot, 1);
        sb.m().reshape_(part, vec![heads, s_i, dh_i])
    });

    // --- scaled dot-product scores ---
    let scale_node = scalar_param(sb, &format!("{lp}.attn._scale"), scale);
    let q_scaled = sb.m().mul(q, scale_node);
    let k_t = sb.m().transpose_(k, vec![0, 2, 1]);
    let raw_scores = sb.m().mm(q_scaled, k_t);

    let scores = if !use_rel_pos {
        raw_scores
    } else {
        let (mut r_h_data, mut r_w_data) = extract_rel_pos(w, lp, h, w_dim, dh)?;

        // All three flags are always queried, in this order.
        let zero_both = rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELPOS");
        let zero_h = rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELH");
        let zero_w = rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELW");
        if zero_both || zero_h {
            r_h_data.fill(0.0);
        }
        if zero_both || zero_w {
            r_w_data.fill(0.0);
        }

        let r_h_node = const_param(
            sb,
            &format!("{lp}.attn._rel_h_indexed"),
            &[h, h, dh],
            r_h_data,
        );
        let r_w_node = const_param(
            sb,
            &format!("{lp}.attn._rel_w_indexed"),
            &[w_dim, w_dim, dh],
            r_w_data,
        );
        // The relative-position term uses the *unscaled* queries.
        add_decomposed_rel_pos(sb, raw_scores, q, r_h_node, r_w_node, batch, nh, h, w_dim, dh)?
    };

    // --- weighted sum over values and head merge ---
    let attn_w = sb.m().sm(scores, -1);
    let context = sb.m().mm(attn_w, v);
    let context = sb
        .m()
        .reshape_(context, vec![batch as i64, nh as i64, s_i, dh_i]);
    let context = sb.m().transpose_(context, vec![0, 2, 1, 3]);
    let merged = sb
        .m()
        .reshape_(context, vec![batch as i64, s_i, e as i64]);

    // --- output projection ---
    let proj_w = load_p(sb, w, &format!("{lp}.attn.proj.weight"), true)?;
    let proj_b = load_p(sb, w, &format!("{lp}.attn.proj.bias"), false)?;
    let projected_out = sb.m().mm(merged, proj_w);
    Ok(sb.m().add(projected_out, proj_b))
}
/// Adds the decomposed relative-position bias to `scores`.
///
/// * `scores`: `[batch*nh, h*w, h*w]`
/// * `q`: `[batch*nh, h*w, dh]`
/// * `r_h`: `[h, h, dh]`, indexed `(query row, key row, channel)`
/// * `r_w`: `[w, w, dh]`, indexed `(query col, key col, channel)`
///
/// For every query row (resp. column) the matching slice of `q` is multiplied
/// with the transposed table slice; the per-row and per-column results are
/// then tiled across the missing key axis and added to `scores` viewed as
/// `[batch*nh, h, w, h, w]`. The result is reshaped back to
/// `[batch*nh, h*w, h*w]`.
#[allow(clippy::too_many_arguments)]
fn add_decomposed_rel_pos(
    sb: &mut SamBuilder,
    scores: HirNodeId,
    q: HirNodeId,
    r_h: HirNodeId,
    r_w: HirNodeId,
    batch: usize,
    nh: usize,
    h: usize,
    w: usize,
    dh: usize,
) -> Result<HirNodeId> {
    let bh = (batch * nh) as i64;
    let (h_i, w_i, dh_i) = (h as i64, w as i64, dh as i64);
    let q_grid = sb.m().reshape_(q, vec![bh, h_i, w_i, dh_i]);

    // Height term: one [bh, 1, w, h_k] slab per query row.
    let rel_h_rows: Vec<HirNodeId> = (0..h)
        .map(|h_q| {
            let q_row = sb.m().narrow_(q_grid, 1, h_q, 1);
            let q_row = sb.m().reshape_(q_row, vec![bh, w_i, dh_i]);
            let table = sb.m().narrow_(r_h, 0, h_q, 1);
            let table = sb.m().reshape_(table, vec![h_i, dh_i]);
            let table_t = sb.m().transpose_(table, vec![1, 0]);
            let bias = sb.m().mm(q_row, table_t);
            sb.m().reshape_(bias, vec![bh, 1, w_i, h_i])
        })
        .collect();
    let rel_h_4d = sb.m().concat_(rel_h_rows, 1);

    // Width term: one [bh, h, 1, w_k] slab per query column.
    let rel_w_cols: Vec<HirNodeId> = (0..w)
        .map(|w_q| {
            let q_col = sb.m().narrow_(q_grid, 2, w_q, 1);
            let q_col = sb.m().reshape_(q_col, vec![bh, h_i, dh_i]);
            let table = sb.m().narrow_(r_w, 0, w_q, 1);
            let table = sb.m().reshape_(table, vec![w_i, dh_i]);
            let table_t = sb.m().transpose_(table, vec![1, 0]);
            let bias = sb.m().mm(q_col, table_t);
            sb.m().reshape_(bias, vec![bh, h_i, 1, w_i])
        })
        .collect();
    let rel_w_4d = sb.m().concat_(rel_w_cols, 2);

    let scores_5d = sb.m().reshape_(scores, vec![bh, h_i, w_i, h_i, w_i]);

    // Tile each term along the key axis it does not depend on by
    // concatenating copies of the same node.
    let rel_h_5d = sb.m().reshape_(rel_h_4d, vec![bh, h_i, w_i, h_i, 1]);
    let rel_h_tiled = sb.m().concat_(vec![rel_h_5d; w], 4);
    let rel_w_5d = sb.m().reshape_(rel_w_4d, vec![bh, h_i, w_i, 1, w_i]);
    let rel_w_tiled = sb.m().concat_(vec![rel_w_5d; h], 3);

    let with_h = sb.m().add(scores_5d, rel_h_tiled);
    let with_hw = sb.m().add(with_h, rel_w_tiled);
    Ok(sb
        .m()
        .reshape_(with_hw, vec![bh, (h * w) as i64, (h * w) as i64]))
}
/// Takes `{lp}.attn.rel_pos_h` / `rel_pos_w` out of `weights` and expands
/// them into dense per-(query, key) tables.
///
/// The stored tables have shape `[2*h - 1, dh]` and `[2*w - 1, dh]`, indexed
/// by relative offset. The returned vectors are laid out as `[h, h, dh]` and
/// `[w, w, dh]`, where entry `(q, k)` holds the stored row `q - k + (n - 1)`.
///
/// # Errors
/// Returns an error (instead of panicking) when
/// * either tensor is missing,
/// * `h` or `w` is zero (`2 * n - 1` would underflow),
/// * a tensor's shape is not `[2*n - 1, dh]`, or
/// * a tensor's data length disagrees with that shape.
fn extract_rel_pos(
    weights: &mut WeightMap,
    lp: &str,
    h: usize,
    w: usize,
    dh: usize,
) -> Result<(Vec<f32>, Vec<f32>)> {
    let (rel_h_raw, rh_shape) = weights.take(&format!("{lp}.attn.rel_pos_h"))?;
    let (rel_w_raw, rw_shape) = weights.take(&format!("{lp}.attn.rel_pos_w"))?;

    // Guard first: `2 * n - 1` below underflows for an empty grid.
    ensure!(
        h > 0 && w > 0,
        "{lp}.attn relative-position grid must be non-empty, got {h}x{w}"
    );
    ensure!(
        rh_shape == vec![2 * h - 1, dh],
        "{lp}.attn.rel_pos_h expected [{}, {dh}], got {rh_shape:?}",
        2 * h - 1
    );
    ensure!(
        rw_shape == vec![2 * w - 1, dh],
        "{lp}.attn.rel_pos_w expected [{}, {dh}], got {rw_shape:?}",
        2 * w - 1
    );
    // The shape is metadata; make sure the payload really has that many
    // values so the slicing in the helper cannot go out of bounds.
    ensure!(
        rel_h_raw.len() == (2 * h - 1) * dh,
        "{lp}.attn.rel_pos_h holds {} values but shape {rh_shape:?} needs {}",
        rel_h_raw.len(),
        (2 * h - 1) * dh
    );
    ensure!(
        rel_w_raw.len() == (2 * w - 1) * dh,
        "{lp}.attn.rel_pos_w holds {} values but shape {rw_shape:?} needs {}",
        rel_w_raw.len(),
        (2 * w - 1) * dh
    );

    let r_h = expand_rel_pos_table(&rel_h_raw, h, dh);
    let r_w = expand_rel_pos_table(&rel_w_raw, w, dh);
    Ok((r_h, r_w))
}

/// Expands a `[2*n - 1, dh]` relative-offset table into a dense `[n, n, dh]`
/// table where entry `(q, k)` is the row for offset `q - k + (n - 1)`.
///
/// `table` must hold at least `(2*n - 1) * dh` values.
fn expand_rel_pos_table(table: &[f32], n: usize, dh: usize) -> Vec<f32> {
    let mut dense = Vec::with_capacity(n * n * dh);
    for q in 0..n {
        for k in 0..n {
            // `k <= n - 1`, so this never underflows.
            let idx = q + (n - 1) - k;
            dense.extend_from_slice(&table[idx * dh..(idx + 1) * dh]);
        }
    }
    dense
}
/// Host-side copy of the encoder neck weights, as consumed by
/// `apply_neck_host`.
///
/// Fields, as filled by `extract_neck_weights`:
/// * `conv1_w`: `image_encoder.neck.0.weight`, shape `[out_chans, embed_dim, 1, 1]`
/// * `ln1_g` / `ln1_b`: `image_encoder.neck.1.weight` / `.bias`
/// * `conv2_w`: `image_encoder.neck.2.weight`, shape `[out_chans, out_chans, 3, 3]`
/// * `ln2_g` / `ln2_b`: `image_encoder.neck.3.weight` / `.bias`
/// * `embed_dim`, `out_chans`, `eps`: copied from the encoder config
pub struct NeckWeights {
// First conv weight, then scale and shift of the first layer norm.
pub conv1_w: Vec<f32>, pub ln1_g: Vec<f32>, pub ln1_b: Vec<f32>,
// Second conv weight and scale of the second layer norm.
pub conv2_w: Vec<f32>, pub ln2_g: Vec<f32>,
/// Shift of the second layer norm.
pub ln2_b: Vec<f32>,
/// Input channel count of the first conv.
pub embed_dim: usize,
/// Output channel count of both convs.
pub out_chans: usize,
/// Epsilon added to the variance in both layer norms.
pub eps: f32,
}
/// Removes the six neck tensors from `weights` and bundles them, together
/// with the relevant config values, into a `NeckWeights`.
///
/// Only the two conv weights are shape-checked; the layer-norm tensors are
/// taken as-is.
///
/// # Errors
/// Fails when a tensor is missing or a conv weight has an unexpected shape.
#[allow(dead_code)]
fn extract_neck_weights(weights: &mut WeightMap, cfg: &SamEncoderConfig) -> Result<NeckWeights> {
    let embed_dim = cfg.embed_dim;
    let out_chans = cfg.out_chans;

    let (conv1_w, c1_shape) = weights.take("image_encoder.neck.0.weight")?;
    ensure!(
        c1_shape == vec![out_chans, embed_dim, 1, 1],
        "neck.0.weight expected [{}, {}, 1, 1], got {c1_shape:?}",
        out_chans,
        embed_dim
    );

    let (ln1_g, _) = weights.take("image_encoder.neck.1.weight")?;
    let (ln1_b, _) = weights.take("image_encoder.neck.1.bias")?;

    let (conv2_w, c2_shape) = weights.take("image_encoder.neck.2.weight")?;
    ensure!(
        c2_shape == vec![out_chans, out_chans, 3, 3],
        "neck.2.weight expected [{}, {}, 3, 3], got {c2_shape:?}",
        out_chans,
        out_chans
    );

    let (ln2_g, _) = weights.take("image_encoder.neck.3.weight")?;
    let (ln2_b, _) = weights.take("image_encoder.neck.3.bias")?;

    let eps = cfg.layer_norm_eps as f32;
    Ok(NeckWeights {
        conv1_w,
        ln1_g,
        ln1_b,
        conv2_w,
        ln2_g,
        ln2_b,
        embed_dim,
        out_chans,
        eps,
    })
}
/// Runs the encoder neck on the host.
///
/// `body_out` is read as `hw * hw` tokens of `neck.embed_dim` channels each
/// (channels-last). The pipeline is: bias-free 1x1 conv -> per-token layer
/// norm over channels -> bias-free 3x3 conv with zero padding 1 -> per-token
/// layer norm. The result is returned channels-first, i.e. as
/// `out_chans` planes of `hw * hw` values.
///
/// # Panics
/// Panics on out-of-bounds indexing if `body_out` or any weight buffer is
/// shorter than the sizes above imply.
pub fn apply_neck_host(neck: &NeckWeights, body_out: &[f32], hw: usize) -> Vec<f32> {
    let (in_c, out_c) = (neck.embed_dim, neck.out_chans);
    let eps = neck.eps;
    let tokens = hw * hw;

    // Layout shuffles between token-major (channels-last) and plane-major
    // (channels-first) storage.
    let to_channels_first = |src: &[f32]| -> Vec<f32> {
        let mut dst = vec![0f32; out_c * tokens];
        for pos in 0..tokens {
            for c in 0..out_c {
                dst[c * tokens + pos] = src[pos * out_c + c];
            }
        }
        dst
    };
    let to_channels_last = |src: &[f32]| -> Vec<f32> {
        let mut dst = vec![0f32; tokens * out_c];
        for c in 0..out_c {
            for pos in 0..tokens {
                dst[pos * out_c + c] = src[c * tokens + pos];
            }
        }
        dst
    };

    // 1x1 conv: one dot product per (token, output channel) pair.
    let mut feat = Vec::with_capacity(tokens * out_c);
    for pos in 0..tokens {
        for oi in 0..out_c {
            let token = &body_out[pos * in_c..(pos + 1) * in_c];
            let kernel = &neck.conv1_w[oi * in_c..(oi + 1) * in_c];
            let dot = token
                .iter()
                .zip(kernel)
                .fold(0f32, |acc, (a, k)| acc + a * k);
            feat.push(dot);
        }
    }
    layernorm2d_inplace(&mut feat, tokens, out_c, &neck.ln1_g, &neck.ln1_b, eps);

    // The 3x3 conv works on planes, the layer norm on tokens.
    let planes = to_channels_first(&feat);
    let conv2_out = conv2d_3x3_pad1(&planes, out_c, out_c, hw, hw, &neck.conv2_w);
    let mut normed = to_channels_last(&conv2_out);
    layernorm2d_inplace(&mut normed, tokens, out_c, &neck.ln2_g, &neck.ln2_b, eps);

    to_channels_first(&normed)
}
/// Normalises each of the `s` consecutive rows of `c` values in `data`
/// in place.
///
/// Every row is shifted to zero mean, scaled by `1 / sqrt(var + eps)` using
/// the biased variance, then multiplied element-wise by `g` and offset by `b`.
///
/// # Panics
/// Panics if `data` holds fewer than `s * c` values, or if `g` / `b` hold
/// fewer than `c` values while at least one row is processed.
fn layernorm2d_inplace(data: &mut [f32], s: usize, c: usize, g: &[f32], b: &[f32], eps: f32) {
    let width = c as f32;
    for row_idx in 0..s {
        let start = row_idx * c;
        let row = &mut data[start..start + c];

        let mean = row.iter().sum::<f32>() / width;
        let var = row
            .iter()
            .map(|v| {
                let centred = v - mean;
                centred * centred
            })
            .sum::<f32>()
            / width;
        let inv = 1.0 / (var + eps).sqrt();

        let (gamma, beta) = (&g[..c], &b[..c]);
        for ((value, gk), bk) in row.iter_mut().zip(gamma).zip(beta) {
            *value = (*value - mean) * inv * gk + bk;
        }
    }
}
/// Direct 3x3 convolution with stride 1, zero padding 1 and no bias.
///
/// * `input`: `in_c` planes of `h * w` values (channels-first)
/// * `weight`: `[out_c, in_c, 3, 3]`, flattened row-major
///
/// Returns `out_c` planes of `h * w` values. Taps that fall outside the
/// input are skipped, which is equivalent to zero padding.
///
/// # Panics
/// Panics on out-of-bounds indexing if `input` or `weight` is shorter than
/// the sizes above imply.
fn conv2d_3x3_pad1(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    let mut out = vec![0f32; out_c * plane];

    for oc in 0..out_c {
        for y in 0..h {
            for x in 0..w {
                let mut acc = 0f32;
                for ic in 0..in_c {
                    let in_base = ic * plane;
                    let w_base = (oc * in_c + ic) * 9;
                    for ky in 0..3 {
                        // Source row is `y + ky - 1`; skip it when outside the image.
                        let iy = match (y + ky).checked_sub(1) {
                            Some(row) if row < h => row,
                            _ => continue,
                        };
                        for kx in 0..3 {
                            // Source column is `x + kx - 1`.
                            let ix = match (x + kx).checked_sub(1) {
                                Some(col) if col < w => col,
                                _ => continue,
                            };
                            let v = input[in_base + iy * w + ix];
                            acc += v * weight[w_base + ky * 3 + kx];
                        }
                    }
                }
                out[oc * plane + y * w + x] = acc;
            }
        }
    }
    out
}
/// Moves the tensor `key` out of `weights` and registers it as an F32
/// parameter node named `key`, storing its data in the builder's parameter
/// table.
///
/// With `transpose` set the tensor is fetched through
/// `WeightMap::take_transposed`, otherwise through `WeightMap::take`; the
/// parameter node is given whatever shape that call returns.
///
/// # Errors
/// Returns an error naming `key` when the tensor cannot be taken.
fn load_p(
    sb: &mut SamBuilder,
    weights: &mut WeightMap,
    key: &str,
    transpose: bool,
) -> Result<HirNodeId> {
    let taken = match transpose {
        true => weights
            .take_transposed(key)
            .map_err(|e| anyhow!("transpose-load `{key}`: {e}")),
        false => weights
            .take(key)
            .map_err(|e| anyhow!("load `{key}`: {e}")),
    };
    let (data, shape) = taken?;

    let id = sb.m().param(key, Shape::new(&shape, DType::F32));
    sb.params.insert(key.to_owned(), data);
    Ok(id)
}
/// Registers a one-element F32 parameter called `name` holding `value` and
/// returns its node id.
#[allow(dead_code)]
fn scalar_param(sb: &mut SamBuilder, name: &str, value: f32) -> HirNodeId {
    let shape = Shape::new(&[1], DType::F32);
    let id = sb.m().param(name, shape);
    sb.params.insert(name.to_owned(), vec![value]);
    id
}
/// Registers an F32 parameter called `name` with the given `shape`, backed
/// by `data`, and returns its node id. `data` is stored as-is; its length is
/// not checked against `shape`.
fn const_param(sb: &mut SamBuilder, name: &str, shape: &[usize], data: Vec<f32>) -> HirNodeId {
    let param_shape = Shape::new(shape, DType::F32);
    let id = sb.m().param(name, param_shape);
    sb.params.insert(name.to_owned(), data);
    id
}
/// Registers an all-zero F32 parameter called `name` with the given `shape`
/// and returns its node id. Used as padding material for window attention.
fn pad_zero_param(sb: &mut SamBuilder, name: &str, shape: &[usize]) -> HirNodeId {
    let zeros = vec![0f32; shape.iter().product::<usize>()];
    let id = sb.m().param(name, Shape::new(shape, DType::F32));
    sb.params.insert(name.to_owned(), zeros);
    id
}