use super::config::{SAM2_PATCH_GRID, SAM2_Q_POOL_COUNT, SAM2_Q_STRIDE, Sam2HieraConfig};
use super::fpn_neck::{FpnNeckWeights, extract_fpn_weights};
use super::preprocess::{Sam2PreprocessWeights, extract_preprocess_weights};
use anyhow::{Result, anyhow};
use rlx_core::weight_map::WeightMap;
use rlx_ir::HirGraphExt;
use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
use rlx_ir::op::ReduceOp;
use rlx_ir::*;
use std::collections::HashMap;
/// Mutable state threaded through HIR construction of the SAM2 image encoder.
///
/// `hir` accumulates the graph nodes; `params` maps parameter names to the
/// host-side `f32` data backing them.
struct Sam2Builder {
    hir: HirModule,
    params: HashMap<String, Vec<f32>>,
}
impl Sam2Builder {
    /// Creates a builder around an empty HIR module called `name`.
    fn new(name: &str) -> Self {
        let hir = HirModule::new(name);
        Self {
            hir,
            params: HashMap::default(),
        }
    }
    /// Returns a fresh `HirMut` editing handle borrowing the module.
    fn m(&mut self) -> HirMut<'_> {
        HirMut::new(&mut self.hir)
    }
}
/// Lowers a finished HIR module into a `Graph`, turning the lowering error
/// into an `anyhow::Error` built from its `Display` text.
#[allow(dead_code)] // no caller at present; silences the unused-function lint
fn lower_hir(hir: HirModule) -> Result<Graph> {
    match Graph::from_hir(hir) {
        Ok(graph) => Ok(graph),
        Err(e) => Err(anyhow!("{e}")),
    }
}
/// Appends a max-reduction of `x` over `axes` and returns the new node.
///
/// `keep_dim` is forwarded to the `Reduce` op unchanged. `out_shape` comes from
/// the caller and is not inferred here, so it must be the shape the reduction
/// actually produces.
fn max_reduce(
    sb: &mut Sam2Builder,
    x: HirNodeId,
    axes: Vec<usize>,
    keep_dim: bool,
    out_shape: Shape,
) -> HirNodeId {
    let reduce = Op::Reduce {
        op: ReduceOp::Max,
        axes,
        keep_dim,
    };
    let inputs = vec![x];
    sb.m().add_node(reduce, inputs, out_shape)
}
#[allow(clippy::type_complexity)]
pub fn build_sam2_image_encoder_hir(
cfg: &Sam2HieraConfig,
weights: &mut WeightMap,
) -> Result<(
HirModule,
HashMap<String, Vec<f32>>,
Sam2PreprocessWeights,
FpnNeckWeights,
)> {
let mut b = Sam2Builder::new("sam2_hiera_image_encoder");
let f = DType::F32;
let preprocess = extract_preprocess_weights(weights, cfg)?;
let grid0 = SAM2_PATCH_GRID;
let e0 = cfg.embed_dim;
let eps = cfg.layer_norm_eps as f32;
let mut x = b.m().input("hidden", Shape::new(&[1, grid0, grid0, e0], f));
let q_pool_blocks = cfg.q_pool_block_indices();
let mut stage = 0usize;
let mut h_curr = grid0;
let mut w_curr = grid0;
let mut dim_curr = e0;
let mut stage_outputs: Vec<HirNodeId> = Vec::with_capacity(cfg.stages.len());
let total = cfg.total_blocks();
for i in 0..total {
let lp = format!("image_encoder.trunk.blocks.{i}");
let is_q_pool = q_pool_blocks.contains(&i);
let dim_in = dim_curr;
let stage_after = if is_q_pool { stage + 1 } else { stage };
let dim_out = cfg.embed_dim_at_stage(stage_after);
let num_heads = cfg.num_heads_at_stage(stage_after);
let head_dim = dim_out / num_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let is_global = cfg.global_att_blocks.contains(&i);
let ws_old = if is_global {
0
} else {
cfg.window_size_at_stage(stage)
};
x = multi_scale_block(
&mut b,
weights,
&lp,
x,
h_curr,
w_curr,
dim_in,
dim_out,
num_heads,
head_dim,
scale,
ws_old,
is_q_pool,
eps,
cfg.mlp_ratio,
cfg.qkv_bias,
)?;
if is_q_pool {
stage += 1;
h_curr /= SAM2_Q_STRIDE;
w_curr /= SAM2_Q_STRIDE;
dim_curr = dim_out;
}
let stage_end = (i + 1 == total) || q_pool_blocks.contains(&(i + 1));
if stage_end {
stage_outputs.push(x);
}
}
debug_assert_eq!(stage_outputs.len(), cfg.stages.len());
debug_assert_eq!(stage_outputs.len(), SAM2_Q_POOL_COUNT + 1);
b.hir.set_outputs(stage_outputs);
let fpn = extract_fpn_weights(weights, cfg)?;
Ok((b.hir, b.params, preprocess, fpn))
}
/// Builds the SAM2 image encoder through the flow builder and lowers it to a `Graph`.
///
/// Returns the graph, its parameter data, and the preprocessing and FPN-neck
/// weights that the flow builder produced alongside the model.
///
/// # Errors
///
/// Propagates failures from the flow builder and from graph construction.
#[allow(clippy::type_complexity)]
pub fn build_sam2_image_encoder_graph(
    cfg: &Sam2HieraConfig,
    weights: &mut WeightMap,
) -> Result<(
    Graph,
    HashMap<String, Vec<f32>>,
    Sam2PreprocessWeights,
    FpnNeckWeights,
)> {
    let built = super::flow::build_sam2_image_encoder_built(cfg, weights)?;
    let model = built.model;
    let (graph, params) = rlx_core::flow_util::graph_from_built(model)?;
    let preprocess = built.preprocess;
    let neck = built.neck;
    Ok((graph, params, preprocess, neck))
}
#[allow(clippy::too_many_arguments)]
fn multi_scale_block(
sb: &mut Sam2Builder,
w: &mut WeightMap,
lp: &str,
x: HirNodeId, h: usize,
wd: usize,
dim_in: usize,
dim_out: usize,
num_heads: usize,
head_dim: usize,
scale: f32,
ws_old: usize, is_q_pool: bool,
eps: f32,
mlp_ratio: f64,
qkv_bias: bool,
) -> Result<HirNodeId> {
let n1_g = load_p(sb, w, &format!("{lp}.norm1.weight"), false)?;
let n1_b = load_p(sb, w, &format!("{lp}.norm1.bias"), false)?;
let normed = sb.m().ln(x, n1_g, n1_b, eps);
let shortcut = if dim_in != dim_out {
let proj_w = load_p(sb, w, &format!("{lp}.proj.weight"), true)?;
let proj_b = load_p(sb, w, &format!("{lp}.proj.bias"), false)?;
let proj_mm = sb.m().mm(normed, proj_w);
let projected = sb.m().add(proj_mm, proj_b);
if is_q_pool {
qpool_2x2(sb, projected, 1, h, wd, dim_out)
} else {
projected
}
} else {
x
};
let (attn_out, h_new, w_new) = if ws_old == 0 {
let out = multi_scale_attention_global(
sb, w, lp, normed, h, wd, dim_in, dim_out, num_heads, head_dim, scale, is_q_pool,
qkv_bias,
)?;
let (hh, ww) = if is_q_pool {
(h / SAM2_Q_STRIDE, wd / SAM2_Q_STRIDE)
} else {
(h, wd)
};
(out, hh, ww)
} else {
let out = multi_scale_attention_windowed(
sb, w, lp, normed, h, wd, dim_in, dim_out, num_heads, head_dim, scale, ws_old,
is_q_pool, qkv_bias,
)?;
let (hh, ww) = if is_q_pool {
(h / SAM2_Q_STRIDE, wd / SAM2_Q_STRIDE)
} else {
(h, wd)
};
(out, hh, ww)
};
let _ = (h_new, w_new);
let x = sb.m().add(shortcut, attn_out);
let n2_g = load_p(sb, w, &format!("{lp}.norm2.weight"), false)?;
let n2_b = load_p(sb, w, &format!("{lp}.norm2.bias"), false)?;
let normed2 = sb.m().ln(x, n2_g, n2_b, eps);
let hidden = (dim_out as f64 * mlp_ratio) as usize;
let fc1_w = load_p(sb, w, &format!("{lp}.mlp.layers.0.weight"), true)?;
let fc1_b = load_p(sb, w, &format!("{lp}.mlp.layers.0.bias"), false)?;
let fc2_w = load_p(sb, w, &format!("{lp}.mlp.layers.1.weight"), true)?;
let fc2_b = load_p(sb, w, &format!("{lp}.mlp.layers.1.bias"), false)?;
let _ = hidden; let up_mm = sb.m().mm(normed2, fc1_w);
let up = sb.m().add(up_mm, fc1_b);
let act = sb.m().gelu(up); let down_mm = sb.m().mm(act, fc2_w);
let down = sb.m().add(down_mm, fc2_b);
Ok(sb.m().add(x, down))
}
/// Appends a 2x2, stride-2 spatial max-pool over a `[batch, h, w, c]` tensor.
/// The result has shape `[batch, h / 2, w / 2, c]`.
///
/// The pool is built as two reshape + max-reduce passes: first over row pairs,
/// then over column pairs.
///
/// Odd `h` or `w` are handled like PyTorch's `MaxPool2d(kernel_size=2,
/// stride=2)` with `ceil_mode=False`: the trailing row/column is cropped before
/// pooling, so each output extent is the floor of half the input extent. This
/// matches the `h / SAM2_Q_STRIDE` arithmetic used by the callers, assuming
/// `SAM2_Q_STRIDE == 2`. Previously an odd extent made the reshape element
/// count wrong in release builds.
fn qpool_2x2(
    sb: &mut Sam2Builder,
    x: HirNodeId,
    batch: usize,
    h: usize,
    w: usize,
    c: usize,
) -> HirNodeId {
    let f = DType::F32;
    let (h2, w2) = (h / 2, w / 2);
    let (h_even, w_even) = (h2 * 2, w2 * 2);
    // Drop a trailing odd row / column so both reshapes below are exact.
    // For even inputs no extra nodes are added.
    let x = if h_even != h {
        sb.m().narrow_(x, 1, 0, h_even)
    } else {
        x
    };
    let x = if w_even != w {
        sb.m().narrow_(x, 2, 0, w_even)
    } else {
        x
    };
    // Rows: [b, 2*h2, W, c] -> [b, h2, 2, W, c], then max over the pair axis.
    let rs_h = sb.m().reshape_(
        x,
        vec![batch as i64, h2 as i64, 2, w_even as i64, c as i64],
    );
    let pool_h = max_reduce(
        sb,
        rs_h,
        vec![2],
        false,
        Shape::new(&[batch, h2, w_even, c], f),
    );
    // Columns: [b, h2, 2*w2, c] -> [b, h2, w2, 2, c], then max over the pair axis.
    let rs_w = sb.m().reshape_(
        pool_h,
        vec![batch as i64, h2 as i64, w2 as i64, 2, c as i64],
    );
    max_reduce(
        sb,
        rs_w,
        vec![3],
        false,
        Shape::new(&[batch, h2, w2, c], f),
    )
}
/// Windowed multi-scale attention over a `[1, h, wd, dim_in]` feature map.
///
/// Steps built below:
/// 1. Zero-pad rows (axis 1), then columns (axis 2). The map becomes
///    `[1, hp, wp, dim_in]`, with `hp` and `wp` multiples of `ws`.
/// 2. Partition it into `n_win = (hp / ws) * (wp / ws)` non-overlapping
///    `ws x ws` windows, stacked as a batch `[n_win, ws, ws, dim_in]`.
///    Windows are ordered row-major over the window grid.
/// 3. Run `mask_unit_attention` on each window independently. With `q_pool`,
///    each window shrinks to `ws / SAM2_Q_STRIDE` per side.
/// 4. Undo the partition to get `[1, hp_new, wp_new, dim_out]`, then crop away
///    the padding. The result is `[1, h_new, w_new, dim_out]`, where `h_new` and
///    `w_new` are `h` and `wd` divided by `SAM2_Q_STRIDE` when `q_pool` is set.
///
/// `ws` is used as a divisor and must be non-zero. `multi_scale_block` sends
/// window size 0 (global attention) to `multi_scale_attention_global` instead.
#[allow(clippy::too_many_arguments)]
fn multi_scale_attention_windowed(
    sb: &mut Sam2Builder,
    w: &mut WeightMap,
    lp: &str,
    x: HirNodeId,
    h: usize,
    wd: usize,
    dim_in: usize,
    dim_out: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    ws: usize,
    q_pool: bool,
    qkv_bias: bool,
) -> Result<HirNodeId> {
    // Padding needed to round each spatial extent up to a multiple of `ws`.
    let pad_h = (ws - h % ws) % ws;
    let pad_w = (ws - wd % ws) % ws;
    let hp = h + pad_h;
    let wp = wd + pad_w;
    // Pad by concatenating constant all-zero parameters.
    let x_pad = if pad_h > 0 {
        let z = zero_param(sb, &format!("{lp}.attn._pad_h"), &[1, pad_h, wd, dim_in]);
        sb.m().concat_(vec![x, z], 1)
    } else {
        x
    };
    // Column padding spans the already row-padded height `hp`.
    let x_pad = if pad_w > 0 {
        let z = zero_param(sb, &format!("{lp}.attn._pad_w"), &[1, hp, pad_w, dim_in]);
        sb.m().concat_(vec![x_pad, z], 2)
    } else {
        x_pad
    };
    // Window grid dimensions.
    let nh_w = hp / ws;
    let nw_w = wp / ws;
    let n_win = nh_w * nw_w;
    // Partition: [1, nh_w, ws, nw_w, ws, C] -> [1, nh_w, nw_w, ws, ws, C]
    // -> [n_win, ws, ws, C].
    let rs = sb.m().reshape_(
        x_pad,
        vec![
            1,
            nh_w as i64,
            ws as i64,
            nw_w as i64,
            ws as i64,
            dim_in as i64,
        ],
    );
    let perm = sb.m().transpose_(rs, vec![0, 1, 3, 2, 4, 5]);
    let windowed = sb.m().reshape_(
        perm,
        vec![n_win as i64, ws as i64, ws as i64, dim_in as i64],
    );
    // Attention inside each window; windows act as the batch axis.
    let attn_out = mask_unit_attention(
        sb, w, lp, windowed, n_win, ws, ws, dim_in, dim_out, num_heads, head_dim, scale, q_pool,
        qkv_bias,
    )?;
    // Window side length after optional Q pooling, and the padded output extents.
    let ws_new = if q_pool { ws / SAM2_Q_STRIDE } else { ws };
    let hp_new = nh_w * ws_new;
    let wp_new = nw_w * ws_new;
    // Unpartition: inverse of the reshape/transpose sequence above.
    let r = sb.m().reshape_(
        attn_out,
        vec![
            1,
            nh_w as i64,
            nw_w as i64,
            ws_new as i64,
            ws_new as i64,
            dim_out as i64,
        ],
    );
    let p = sb.m().transpose_(r, vec![0, 1, 3, 2, 4, 5]);
    let unp = sb
        .m()
        .reshape_(p, vec![1, hp_new as i64, wp_new as i64, dim_out as i64]);
    // Crop the (possibly pooled) padding back off.
    let h_new = if q_pool { h / SAM2_Q_STRIDE } else { h };
    let w_new = if q_pool { wd / SAM2_Q_STRIDE } else { wd };
    let out = if hp_new != h_new {
        sb.m().narrow_(unp, 1, 0, h_new)
    } else {
        unp
    };
    let out = if wp_new != w_new {
        sb.m().narrow_(out, 2, 0, w_new)
    } else {
        out
    };
    Ok(out)
}
/// Global (un-windowed) multi-scale attention.
///
/// The whole `h x wd` map is treated as a single attention unit. The call
/// delegates to `mask_unit_attention` with a unit count of one.
#[allow(clippy::too_many_arguments)]
fn multi_scale_attention_global(
    sb: &mut Sam2Builder,
    w: &mut WeightMap,
    lp: &str,
    x: HirNodeId,
    h: usize,
    wd: usize,
    dim_in: usize,
    dim_out: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    q_pool: bool,
    qkv_bias: bool,
) -> Result<HirNodeId> {
    // One "window" spanning the entire feature map.
    let single_unit = 1;
    mask_unit_attention(
        sb,
        w,
        lp,
        x,
        single_unit,
        h,
        wd,
        dim_in,
        dim_out,
        num_heads,
        head_dim,
        scale,
        q_pool,
        qkv_bias,
    )
}
/// Multi-head self-attention over `batch` independent `h x wd` token grids.
/// Each grid is one window, or the whole map for global attention.
///
/// Graph built below:
/// 1. Fused projection: `x @ qkv.weight`, plus `qkv.bias` when `qkv_bias` is
///    set. The result is reshaped to `[batch, h*wd, 3, num_heads, head_dim]`
///    and split into q, k and v, each `[batch, h*wd, num_heads, head_dim]`.
/// 2. If `q_pool` is set, only q is viewed as `[batch, h, wd, dim_out]`,
///    max-pooled 2x2, and flattened back. The output grid shrinks to
///    `h_out x w_out`, while keys and values keep all `h*wd` tokens.
/// 3. For each (batch, head): `softmax((q * scale) @ k^T, last axis) @ v`.
/// 4. Heads are merged back to `[batch, h_out, w_out, dim_out]` and passed
///    through the `attn.proj` linear layer (weight + bias).
///
/// `_dim_in` is not read. The input width only enters through the shape of the
/// loaded `qkv` weight.
///
/// NOTE(review): `h_out` and `w_out` are computed as `h` and `wd` divided by
/// `SAM2_Q_STRIDE`, but `qpool_2x2` always pools by 2. This relies on
/// `SAM2_Q_STRIDE == 2`.
#[allow(clippy::too_many_arguments)]
fn mask_unit_attention(
    sb: &mut Sam2Builder,
    w: &mut WeightMap,
    lp: &str,
    x: HirNodeId,
    batch: usize,
    h: usize,
    wd: usize,
    _dim_in: usize,
    dim_out: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    q_pool: bool,
    qkv_bias: bool,
) -> Result<HirNodeId> {
    // Tokens per unit before pooling; keys and values always use this length.
    let s_kv = h * wd;
    // Fused q/k/v projection, with optional bias.
    let qkv_w = load_p(sb, w, &format!("{lp}.attn.qkv.weight"), true)?;
    let qkv_b = if qkv_bias {
        Some(load_p(sb, w, &format!("{lp}.attn.qkv.bias"), false)?)
    } else {
        None
    };
    let qkv = sb.m().mm(x, qkv_w);
    let qkv = if let Some(bnode) = qkv_b {
        sb.m().add(qkv, bnode)
    } else {
        qkv
    };
    // [batch, s_kv, 3, heads, head_dim] -> [3, batch, s_kv, heads, head_dim], so
    // q, k and v can each be sliced off the leading axis.
    let qkv5 = sb.m().reshape_(
        qkv,
        vec![
            batch as i64,
            s_kv as i64,
            3,
            num_heads as i64,
            head_dim as i64,
        ],
    );
    let qkv_perm = sb.m().transpose_(qkv5, vec![2, 0, 1, 3, 4]);
    // Slice 0 is q.
    let q_full = {
        let s = sb.m().narrow_(qkv_perm, 0, 0, 1);
        sb.m().reshape_(
            s,
            vec![batch as i64, s_kv as i64, num_heads as i64, head_dim as i64],
        )
    };
    // Slice 1 is k.
    let k = {
        let s = sb.m().narrow_(qkv_perm, 0, 1, 1);
        sb.m().reshape_(
            s,
            vec![batch as i64, s_kv as i64, num_heads as i64, head_dim as i64],
        )
    };
    // Slice 2 is v.
    let v = {
        let s = sb.m().narrow_(qkv_perm, 0, 2, 1);
        sb.m().reshape_(
            s,
            vec![batch as i64, s_kv as i64, num_heads as i64, head_dim as i64],
        )
    };
    // Q pooling: only the queries are downsampled. The output grid shrinks, but
    // every query still attends over all s_kv keys.
    let (q, h_out, w_out, s_q) = if q_pool {
        let qspat = sb.m().reshape_(
            q_full,
            vec![batch as i64, h as i64, wd as i64, dim_out as i64],
        );
        let qpooled = qpool_2x2(sb, qspat, batch, h, wd, dim_out);
        let h2 = h / SAM2_Q_STRIDE;
        let w2 = wd / SAM2_Q_STRIDE;
        let s2 = h2 * w2;
        let qflat = sb.m().reshape_(
            qpooled,
            vec![batch as i64, s2 as i64, num_heads as i64, head_dim as i64],
        );
        (qflat, h2, w2, s2)
    } else {
        (q_full, h, wd, s_kv)
    };
    // Move to [batch, heads, tokens, head_dim], then fold heads into the
    // leading axis for 3-D batched matmuls.
    let q_t = sb.m().transpose_(q, vec![0, 2, 1, 3]);
    let k_t = sb.m().transpose_(k, vec![0, 2, 1, 3]);
    let v_t = sb.m().transpose_(v, vec![0, 2, 1, 3]);
    let q_flat = sb.m().reshape_(
        q_t,
        vec![(batch * num_heads) as i64, s_q as i64, head_dim as i64],
    );
    let k_flat = sb.m().reshape_(
        k_t,
        vec![(batch * num_heads) as i64, s_kv as i64, head_dim as i64],
    );
    let v_flat = sb.m().reshape_(
        v_t,
        vec![(batch * num_heads) as i64, s_kv as i64, head_dim as i64],
    );
    // Scale q before q @ k^T; this gives the same result as scaling the scores.
    // Softmax then runs over the key axis (the last axis of the scores).
    let scale_node = scalar_param(sb, &format!("{lp}.attn._scale"), scale);
    let q_scaled = sb.m().mul(q_flat, scale_node);
    let k_for_mm = sb.m().transpose_(k_flat, vec![0, 2, 1]);
    let scores = sb.m().mm(q_scaled, k_for_mm);
    let attn_w = sb.m().sm(scores, -1);
    let attn_v = sb.m().mm(attn_w, v_flat);
    // Unfold heads and merge them back into the channel axis on the output grid.
    let r = sb.m().reshape_(
        attn_v,
        vec![batch as i64, num_heads as i64, s_q as i64, head_dim as i64],
    );
    let r = sb.m().transpose_(r, vec![0, 2, 1, 3]);
    let merged = sb.m().reshape_(
        r,
        vec![batch as i64, h_out as i64, w_out as i64, dim_out as i64],
    );
    // Output projection.
    let proj_w = load_p(sb, w, &format!("{lp}.attn.proj.weight"), true)?;
    let proj_b = load_p(sb, w, &format!("{lp}.attn.proj.bias"), false)?;
    let proj_mm = sb.m().mm(merged, proj_w);
    Ok(sb.m().add(proj_mm, proj_b))
}
/// Declares an F32 parameter named `key` and moves its data out of `weights`.
///
/// When `transpose` is set, the tensor is fetched with `take_transposed`;
/// otherwise it is fetched with `take`. The parameter takes whatever shape the
/// weight map reports, and the data is recorded in `sb.params` under `key`.
///
/// # Errors
///
/// Returns an error naming `key` if the weight map cannot supply it.
fn load_p(
    sb: &mut Sam2Builder,
    weights: &mut WeightMap,
    key: &str,
    transpose: bool,
) -> Result<HirNodeId> {
    let fetched = if transpose {
        weights
            .take_transposed(key)
            .map_err(|e| anyhow!("transpose-load `{key}`: {e}"))
    } else {
        weights.take(key).map_err(|e| anyhow!("load `{key}`: {e}"))
    };
    let (data, shape) = fetched?;
    let node = sb.m().param(key, Shape::new(&shape, DType::F32));
    sb.params.insert(key.to_owned(), data);
    Ok(node)
}
/// Declares a one-element F32 parameter `name` (shape `[1]`) holding `value`.
fn scalar_param(sb: &mut Sam2Builder, name: &str, value: f32) -> HirNodeId {
    let one_elem = Shape::new(&[1], DType::F32);
    let node = sb.m().param(name, one_elem);
    sb.params.insert(name.to_owned(), vec![value]);
    node
}
/// Declares an F32 parameter `name` of the given `shape`, backed by an all-zero
/// buffer with one element per entry of the shape.
fn zero_param(sb: &mut Sam2Builder, name: &str, shape: &[usize]) -> HirNodeId {
    let numel = shape.iter().product::<usize>();
    let zeros = vec![0f32; numel];
    let node = sb.m().param(name, Shape::new(shape, DType::F32));
    sb.params.insert(name.to_owned(), zeros);
    node
}