use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::{Activation, BinaryOp, Op};
use rlx_ir::{DType, Graph, NodeId, Shape};
/// Elementwise activation functions that [`apply_activation`] knows how to
/// lower into graph nodes.
///
/// Build one from a configuration string with [`ActivationKind::parse`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationKind {
    /// Hyperbolic tangent, lowered to the graph's `tanh` op.
    Tanh,
    /// Rectified linear unit, lowered to the graph's `relu` op.
    Relu,
    /// ReLU clipped at 6: `min(max(x, 0), 6)`.
    Relu6,
    /// GELU, lowered to the graph's `gelu_approx` op (`"gelu"` and `"gelu_tf"`
    /// both parse to this variant).
    Gelu,
    /// Logistic sigmoid `1 / (1 + exp(-x))`.
    Sigmoid,
    /// Softplus `log(1 + exp(x))`.
    Softplus,
    /// SiLU, lowered to the graph's `silu` op.
    Silu,
    /// Identity; no node is added and the input node is returned as-is.
    Linear,
}
impl ActivationKind {
pub fn parse(s: &str) -> Result<Self> {
let s = s.to_ascii_lowercase();
Ok(match s.as_str() {
"tanh" => Self::Tanh,
"relu" => Self::Relu,
"relu6" => Self::Relu6,
"gelu" | "gelu_tf" => Self::Gelu,
"sigmoid" => Self::Sigmoid,
"softplus" => Self::Softplus,
"silu" => Self::Silu,
"none" | "linear" => Self::Linear,
_ => bail!("unsupported activation: {s}"),
})
}
}
/// Appends the nodes computing `kind(x)` elementwise to `g` and returns the
/// output node.
///
/// - `Tanh`, `Relu`, `Gelu` (via `gelu_approx`) and `Silu` use the graph's
///   dedicated builders.
/// - `Sigmoid`, `Softplus` and `Relu6` are composed from elementwise primitives
///   and `[1]`-shaped scalar constants.
/// - `Linear` adds no node and returns `x` itself.
pub fn apply_activation(g: &mut Graph, kind: ActivationKind, x: NodeId) -> NodeId {
    match kind {
        ActivationKind::Tanh => g.tanh(x),
        ActivationKind::Relu => g.relu(x),
        ActivationKind::Gelu => g.gelu_approx(x),
        ActivationKind::Sigmoid => {
            // sigmoid(x) = 1 / (1 + exp(-x)); a single `1.0` constant node
            // serves as both the addend and the numerator.
            let neg = g.neg(x);
            let exp = g.exp(neg);
            let one = scalar_const(g, 1.0);
            let denom = g.add(exp, one);
            g.div(one, denom)
        }
        ActivationKind::Softplus => {
            // The naive log(1 + exp(x)) overflows to +inf as soon as exp(x)
            // does (x > ~88 in f32). Shift by m = max(x, 0):
            //     softplus(x) = m + log(exp(-m) + exp(x - m))
            // Both exponent arguments are <= 0, so neither exponential can
            // overflow. The identity holds for every m, so if the graph is
            // differentiated, the contribution through m cancels exactly and
            // the tie in Max at x == 0 does not affect the gradient.
            let s = g.shape(x).clone();
            let zero = scalar_const(g, 0.0);
            let m = g.binary(BinaryOp::Max, x, zero, s);
            let neg_m = g.neg(m);
            let exp_neg_m = g.exp(neg_m);
            let x_minus_m = g.add(x, neg_m);
            let exp_x_minus_m = g.exp(x_minus_m);
            let sum = g.add(exp_neg_m, exp_x_minus_m);
            let log_sum = g.activation(Activation::Log, sum, g.shape(sum).clone());
            g.add(m, log_sum)
        }
        ActivationKind::Silu => g.silu(x),
        ActivationKind::Relu6 => {
            // relu6(x) = min(max(x, 0), 6)
            let zero = scalar_const(g, 0.0);
            let six = scalar_const(g, 6.0);
            let s = g.shape(x).clone();
            let pos = g.binary(BinaryOp::Max, x, zero, s.clone());
            g.binary(BinaryOp::Min, pos, six, s)
        }
        ActivationKind::Linear => x,
    }
}
/// Adds a one-element `F32` constant node (shape `[1]`) holding `v`, stored as
/// its 4 little-endian bytes, and returns the new node.
pub fn scalar_const(g: &mut Graph, v: f32) -> NodeId {
    let shape = Shape::new(&[1], DType::F32);
    let data = Vec::from(v.to_le_bytes());
    g.add_node(Op::Constant { data }, Vec::new(), shape)
}
/// Adds an `F32` constant node with the given `shape` and returns it.
///
/// The values of `data` are stored as little-endian bytes, in slice order.
///
/// # Panics
///
/// In debug builds, panics if `data.len()` differs from the number of elements
/// implied by `shape` (the product of its dimensions; `1` for an empty shape).
/// Such a mismatch would otherwise produce a constant whose byte payload
/// disagrees with its declared shape.
pub fn tensor_const(g: &mut Graph, data: &[f32], shape: &[usize]) -> NodeId {
    debug_assert_eq!(
        data.len(),
        shape.iter().product::<usize>(),
        "tensor_const: {} values do not fill shape {:?}",
        data.len(),
        shape
    );
    let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
    g.add_node(
        Op::Constant { data: bytes },
        vec![],
        Shape::new(shape, DType::F32),
    )
}
/// Configuration for a single layer built by [`dense_layer`].
pub struct DenseLayerSpec<'a> {
    /// Prefix for this layer's parameter names: `"{prefix}.w"`, `"{prefix}.b"`
    /// and `"{prefix}.idt"`.
    pub param_prefix: &'a str,
    /// Number of rows of the weight matrix; expected to match the last axis of
    /// the layer input.
    pub in_dim: usize,
    /// Number of columns of the weight matrix, i.e. the layer's output width.
    pub out_dim: usize,
    /// Activation applied after the affine transform.
    pub activation: ActivationKind,
    /// When true, adds a learned `[out_dim]` bias `"{prefix}.b"`.
    pub use_bias: bool,
    /// When true, multiplies the activated output by a learned `[out_dim]`
    /// vector `"{prefix}.idt"`.
    pub resnet_dt: bool,
    /// When true, adds a residual shortcut if `out_dim == in_dim` (adds the
    /// input) or `out_dim == 2 * in_dim` (adds the input concatenated with
    /// itself); otherwise no shortcut is added.
    pub resnet: bool,
}
/// Builds one fully connected layer on top of `x` and returns its output node.
///
/// Computes `act(x · W [+ b]) [* idt] [+ shortcut]`. It registers the parameter
/// `"{prefix}.w"` with shape `[in_dim, out_dim]`, and optionally `"{prefix}.b"`
/// (when `use_bias`) and `"{prefix}.idt"` (when `resnet_dt`), both shaped
/// `[out_dim]`.
///
/// With `resnet` set, the shortcut is:
/// - `x` itself when `out_dim == in_dim`;
/// - `concat([x, x])` along the last axis when `out_dim == 2 * in_dim`;
/// - omitted for any other width.
pub fn dense_layer(g: &mut Graph, spec: &DenseLayerSpec, x: NodeId) -> NodeId {
    let prefix = spec.param_prefix;

    let weight = g.param(
        format!("{prefix}.w"),
        Shape::new(&[spec.in_dim, spec.out_dim], DType::F32),
    );
    let product = g.mm(x, weight);

    let pre_activation = if !spec.use_bias {
        product
    } else {
        let bias = g.param(format!("{prefix}.b"), Shape::new(&[spec.out_dim], DType::F32));
        g.add(product, bias)
    };

    let activated = apply_activation(g, spec.activation, pre_activation);

    let scaled = if spec.resnet_dt {
        let timestep = g.param(format!("{prefix}.idt"), Shape::new(&[spec.out_dim], DType::F32));
        g.mul(activated, timestep)
    } else {
        activated
    };

    match residual_shortcut(g, spec, x) {
        Some(shortcut) => g.add(scaled, shortcut),
        None => scaled,
    }
}

/// Returns the node to add back onto a dense layer's output, or `None` when
/// residuals are disabled or the input/output widths are incompatible.
fn residual_shortcut(g: &mut Graph, spec: &DenseLayerSpec, x: NodeId) -> Option<NodeId> {
    if !spec.resnet {
        return None;
    }
    if spec.out_dim == spec.in_dim {
        return Some(x);
    }
    if spec.out_dim != 2 * spec.in_dim {
        return None;
    }

    // Width doubles: duplicate the input along its last axis.
    let x_shape = g.shape(x).clone();
    let axis = x_shape.rank() - 1;
    let mut dims: Vec<rlx_ir::Dim> = x_shape.dims().to_vec();
    // NOTE(review): a non-`Static` last dimension is passed through unchanged,
    // so the declared concat shape would not reflect the doubled width —
    // confirm whether symbolic dims can reach this path.
    let doubled_last = match dims[axis] {
        rlx_ir::Dim::Static(n) => rlx_ir::Dim::Static(n * 2),
        passthrough => passthrough,
    };
    dims[axis] = doubled_last;
    let doubled_shape = Shape::from_dims(&dims, x_shape.dtype());
    Some(g.concat(vec![x, x], axis, doubled_shape))
}
/// Configuration for the layer stacks built by [`embedding_mlp`] and
/// [`fitting_mlp`].
pub struct MlpSpec<'a> {
    /// Prefix for parameter names; layer `l` uses `"{prefix}.layer{l}"`.
    pub param_prefix: &'a str,
    /// Width of the last axis of the network input.
    pub in_dim: usize,
    /// Output width of each hidden layer, in order; one layer per entry.
    pub neuron: &'a [usize],
    /// Activation applied in every hidden layer.
    pub activation: ActivationKind,
    /// Forwarded to [`DenseLayerSpec::resnet_dt`] for every hidden layer.
    pub resnet_dt: bool,
}
/// Stacks one [`dense_layer`] per entry of `spec.neuron` on top of `x_in` and
/// returns the final node.
///
/// Every layer uses a bias and residual shortcuts, with `spec.activation` and
/// `spec.resnet_dt`. Layer `l` maps the previous width (initially
/// `spec.in_dim`) to `spec.neuron[l]` and names its parameters under
/// `"{prefix}.layer{l}"`. With an empty `neuron`, `x_in` is returned unchanged.
pub fn embedding_mlp(g: &mut Graph, spec: &MlpSpec, x_in: NodeId) -> NodeId {
    let (output, _final_width) = spec.neuron.iter().copied().enumerate().fold(
        (x_in, spec.in_dim),
        |(hidden, width), (idx, next_width)| {
            let prefix = format!("{}.layer{idx}", spec.param_prefix);
            let layer = DenseLayerSpec {
                param_prefix: prefix.as_str(),
                in_dim: width,
                out_dim: next_width,
                activation: spec.activation,
                use_bias: true,
                resnet_dt: spec.resnet_dt,
                resnet: true,
            };
            (dense_layer(g, &layer, hidden), next_width)
        },
    );
    output
}
/// Builds the hidden stack with [`embedding_mlp`], then a linear output layer.
///
/// The output layer has no activation, no timestep scaling and no residual
/// shortcut. It maps the last hidden width (or `spec.in_dim` when
/// `spec.neuron` is empty) to `final_dim`, with parameters named under
/// `final_prefix`.
pub fn fitting_mlp(
    g: &mut Graph,
    spec: &MlpSpec,
    x_in: NodeId,
    final_dim: usize,
    final_prefix: &str,
) -> NodeId {
    let hidden = embedding_mlp(g, spec, x_in);
    let hidden_width = match spec.neuron.last() {
        Some(&width) => width,
        None => spec.in_dim,
    };
    let head = DenseLayerSpec {
        param_prefix: final_prefix,
        in_dim: hidden_width,
        out_dim: final_dim,
        activation: ActivationKind::Linear,
        use_bias: true,
        resnet_dt: false,
        resnet: false,
    };
    dense_layer(g, &head, hidden)
}