use anyhow::{anyhow, bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use crate::config::SeAConfig;
use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};
/// Graph handles produced by [`build_se_a_descriptor`].
///
/// All fields are plain node ids / sizes, so the struct is cheap to copy.
#[derive(Clone, Copy)]
pub struct SeADescriptor {
    /// Flattened descriptor node, reshaped to `[nf, nloc, dim_out]`.
    pub descriptor: NodeId,
    /// Channels `1..4` (narrowed along axis 3) of the `1/nnei`-scaled `G^T R`
    /// product.
    pub gr: NodeId,
    /// Width of the flattened descriptor: `ng * axis_neuron`.
    pub dim_out: usize,
}
/// Optional extra graph inputs for [`build_se_a_descriptor`].
///
/// The `Default` value enables none of the extras.
#[derive(Clone, Copy, Default)]
pub struct SeAExtraInputs {
    /// Optional per-neighbor exclusion mask. When present it is reshaped to
    /// `[nf, nloc, nnei, 1]` and multiplied into the normalized environment
    /// matrix; `None` leaves the environment matrix unmasked.
    pub exclude_mask: Option<NodeId>,
}
/// Builds the `se_e2_a` descriptor subgraph.
///
/// The raw environment matrix is normalized with the per-type statistics
/// `descriptor.davg` / `descriptor.dstd` (both `[ntypes, nnei, 4]` parameters,
/// gathered along axis 0 by `atype_loc`). It is then optionally multiplied by
/// the exclusion mask and reduced through the embedding networks. The one-side
/// or two-side variant is chosen by `cfg.type_one_side`.
///
/// # Arguments
/// * `env_mat_raw` - un-normalized environment matrix; must be rank 4
///   (expected layout `[nf, nloc, nnei, 4]`).
/// * `atype_loc` - per-atom types used to gather the statistics and, for the
///   two-side variant, to mask contributions per center type.
/// * `nf`, `nloc` - frame / local-atom counts used for the static reshapes.
/// * `extra` - optional extra inputs (exclusion mask).
///
/// # Errors
/// Returns an error in any of these cases:
/// * the activation name cannot be parsed;
/// * `env_mat_raw` is not rank 4;
/// * the cumulative selection has fewer than `ntypes + 1` entries;
/// * `axis_neuron` exceeds the embedding width;
/// * every neighbor-type selection is empty.
pub fn build_se_a_descriptor(
    g: &mut Graph,
    cfg: &SeAConfig,
    env_mat_raw: NodeId,
    atype_loc: NodeId,
    nf: usize,
    nloc: usize,
    extra: SeAExtraInputs,
) -> Result<SeADescriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let ntypes = cfg.ntypes();
    let ng = cfg.ng();
    let m2 = cfg.axis_neuron;
    let nnei = cfg.nnei();
    let sec = cfg.sel_cumsum();

    let rr_rank = g.shape(env_mat_raw).rank();
    if rr_rank != 4 {
        bail!(
            "se_e2_a: env-matrix input must have rank 4 [nf, nloc, nnei, 4], got rank {}",
            rr_rank
        );
    }
    // Both builders read `sec[t]` and `sec[t + 1]` for every type `t`; report a
    // malformed selection as an error instead of panicking on the index.
    if sec.len() < ntypes + 1 {
        bail!(
            "se_e2_a: sel_cumsum has {} entries, expected at least ntypes + 1 = {}",
            sec.len(),
            ntypes + 1
        );
    }
    // `finalize_descriptor` narrows the first `axis_neuron` rows out of the
    // `ng` embedding rows, so `axis_neuron` cannot exceed `ng`.
    if m2 > ng {
        bail!(
            "se_e2_a: axis_neuron ({}) must not exceed the embedding width ng ({})",
            m2,
            ng
        );
    }

    let davg = g.param(
        "descriptor.davg",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    let dstd = g.param(
        "descriptor.dstd",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    let davg_g = g.gather_(davg, atype_loc, 0);
    let dstd_g = g.gather_(dstd, atype_loc, 0);
    let mut rr = g.sub(env_mat_raw, davg_g);
    rr = g.div(rr, dstd_g);

    if let Some(mask) = extra.exclude_mask {
        // Add a trailing unit axis so the per-neighbor mask broadcasts over the
        // 4 environment-matrix channels. The declared shape matches the reshape
        // target exactly (as the two-side center-type mask does).
        // NOTE(review): the mask is assumed to already be F32 - TODO confirm.
        let mask_4d_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
        let mask_4d = g.reshape(
            mask,
            vec![nf as i64, nloc as i64, nnei as i64, 1],
            mask_4d_shape,
        );
        rr = g.mul(rr, mask_4d);
    }

    if cfg.type_one_side {
        build_descriptor_type_one_side(g, cfg, activation, rr, &sec, nf, nloc, nnei, ng, m2)
    } else {
        build_descriptor_type_two_side(
            g, cfg, activation, rr, atype_loc, &sec, nf, nloc, nnei, ng, m2,
        )
    }
}
fn build_descriptor_type_one_side(
g: &mut Graph,
cfg: &SeAConfig,
activation: ActivationKind,
rr: NodeId,
sec: &[usize],
nf: usize,
nloc: usize,
nnei: usize,
ng: usize,
m2: usize,
) -> Result<SeADescriptor> {
let ntypes = cfg.ntypes();
let mut acc: Option<NodeId> = None;
for t in 0..ntypes {
let start = sec[t];
let len = sec[t + 1] - start;
if len == 0 {
continue;
}
let rr_t = g.narrow_(rr, 2, start, len); let ss_t = g.narrow_(rr_t, 3, 0, 1); let prefix = format!("descriptor.embedding.{t}");
let mlp = MlpSpec {
param_prefix: &prefix,
in_dim: 1,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let gg_t = embedding_mlp(g, &mlp, ss_t);
let gg_t_t = g.transpose_(gg_t, vec![0, 1, 3, 2]); let contrib = g.mm(gg_t_t, rr_t);
acc = Some(match acc {
None => contrib,
Some(prev) => g.add(prev, contrib),
});
}
let gr_unscaled =
acc.ok_or_else(|| anyhow!("se_e2_a: empty selection (sum(sel) == 0)"))?;
Ok(finalize_descriptor(g, gr_unscaled, nf, nloc, nnei, ng, m2))
}
fn build_descriptor_type_two_side(
g: &mut Graph,
cfg: &SeAConfig,
activation: ActivationKind,
rr: NodeId,
atype_loc: NodeId,
sec: &[usize],
nf: usize,
nloc: usize,
nnei: usize,
ng: usize,
m2: usize,
) -> Result<SeADescriptor> {
let ntypes = cfg.ntypes();
let mut acc: Option<NodeId> = None;
for ti in 0..ntypes {
let mut sum_j: Option<NodeId> = None;
for tj in 0..ntypes {
let start = sec[tj];
let len = sec[tj + 1] - start;
if len == 0 {
continue;
}
let rr_j = g.narrow_(rr, 2, start, len); let ss_j = g.narrow_(rr_j, 3, 0, 1); let prefix = format!("descriptor.embedding.{ti}.{tj}");
let mlp = MlpSpec {
param_prefix: &prefix,
in_dim: 1,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let gg_ij = embedding_mlp(g, &mlp, ss_j);
let gg_ij_t = g.transpose_(gg_ij, vec![0, 1, 3, 2]);
let contrib = g.mm(gg_ij_t, rr_j); sum_j = Some(match sum_j {
None => contrib,
Some(prev) => g.add(prev, contrib),
});
}
let Some(contrib_ti) = sum_j else { continue };
let ti_const = type_index_const(g, ti as i32, nf, nloc);
let eq = g.binary(
rlx_ir::op::BinaryOp::Sub,
atype_loc,
ti_const,
g.shape(atype_loc).clone(),
);
let eq_f32 = g.cast(eq, DType::F32);
let abs = g.activation(
rlx_ir::op::Activation::Abs,
eq_f32,
g.shape(eq_f32).clone(),
);
let one = scalar_const(g, 1.0);
let clipped = g.binary(
rlx_ir::op::BinaryOp::Min,
abs,
one,
g.shape(abs).clone(),
);
let mask = g.sub(one, clipped);
let mask_4d_shape = Shape::new(&[nf, nloc, 1, 1], DType::F32);
let mask_4d = g.reshape(mask, vec![nf as i64, nloc as i64, 1, 1], mask_4d_shape);
let masked = g.mul(contrib_ti, mask_4d);
acc = Some(match acc {
None => masked,
Some(prev) => g.add(prev, masked),
});
}
let gr_unscaled =
acc.ok_or_else(|| anyhow!("se_e2_a: empty selection (sum(sel) == 0)"))?;
Ok(finalize_descriptor(g, gr_unscaled, nf, nloc, nnei, ng, m2))
}
/// Turns the accumulated (unscaled) `G^T R` product into descriptor outputs.
///
/// Let `gr = gr_unscaled * (1 / nnei)`. The descriptor is
/// `gr × transpose(gr[:, :, 0..m2, :])`, flattened to
/// `[nf, nloc, ng * m2]`. The returned `gr` field holds channels `1..4` of
/// `gr`, narrowed along axis 3.
fn finalize_descriptor(
    g: &mut Graph,
    gr_unscaled: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
    m2: usize,
) -> SeADescriptor {
    let dim_out = ng * m2;

    // Average over neighbors.
    let per_neighbor = 1.0 / nnei as f32;
    let scale = scalar_const(g, per_neighbor);
    let gr = g.mul(gr_unscaled, scale);

    // Contract `gr` with the transpose of its first `m2` rows.
    let leading_rows = g.narrow_(gr, 2, 0, m2);
    let leading_rows_t = g.transpose_(leading_rows, vec![0, 1, 3, 2]);
    let product = g.mm(gr, leading_rows_t);
    let descriptor = g.reshape(
        product,
        vec![nf as i64, nloc as i64, dim_out as i64],
        Shape::new(&[nf, nloc, dim_out], DType::F32),
    );

    // Expose channels 1..4 of the scaled product.
    let gr = g.narrow_(gr, 3, 1, 3);

    SeADescriptor {
        descriptor,
        gr,
        dim_out,
    }
}
/// Adds an `[nf, nloc]` I32 constant node whose every element is `value`.
///
/// The payload is `value` encoded little-endian, repeated `nf * nloc` times.
fn type_index_const(g: &mut Graph, value: i32, nf: usize, nloc: usize) -> NodeId {
    let shape = Shape::new(&[nf, nloc], DType::I32);
    let data = value.to_le_bytes().repeat(nf * nloc);
    g.add_node(rlx_ir::op::Op::Constant { data }, Vec::new(), shape)
}