use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::op::Op;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::nn::{embedding_mlp, ActivationKind, MlpSpec};
/// Configuration for the type-embedding table built by `build_type_embedding`.
///
/// Serializable and deserializable through serde; every field except `ntypes`
/// falls back to a default when absent from the input.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TypeEmbedConfig {
/// Number of types; sizes the `ntypes x ntypes` identity constant that is fed
/// to the embedding MLP.
pub ntypes: usize,
/// Passed to the embedding MLP as `MlpSpec::neuron` (presumably per-layer
/// widths). The last entry is the embedding width returned by `tebd_dim`.
/// Defaults to `[8]`.
#[serde(default = "default_neuron")]
pub neuron: Vec<usize>,
/// Forwarded unchanged to `MlpSpec::resnet_dt`. Defaults to `false`.
#[serde(default)]
pub resnet_dt: bool,
/// Activation name, parsed with `ActivationKind::parse`. Defaults to `"tanh"`.
#[serde(default = "default_activation")]
pub activation_function: String,
/// When `true`, one extra all-zero row is appended to the table, so it has
/// `ntypes + 1` rows. Defaults to `false`.
#[serde(default)]
pub padding: bool,
/// NOTE(review): not read anywhere in the visible part of this file; confirm
/// where (or whether) this flag is consumed. Defaults to `false`.
#[serde(default)]
pub use_tebd_bias: bool,
}
/// Serde default for `TypeEmbedConfig::neuron`: a single entry of 8.
fn default_neuron() -> Vec<usize> {
    Vec::from([8])
}
/// Serde default for `TypeEmbedConfig::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}
impl TypeEmbedConfig {
    /// Width of one embedding row: the last entry of `neuron`.
    ///
    /// # Panics
    ///
    /// Panics with `"type_embed: empty neuron list"` when `neuron` is empty.
    pub fn tebd_dim(&self) -> usize {
        let last_width = self.neuron.last().copied();
        last_width.expect("type_embed: empty neuron list")
    }

    /// Number of rows in the embedding table: `ntypes`, plus one when
    /// `padding` is enabled.
    pub fn table_rows(&self) -> usize {
        let padding_rows = usize::from(self.padding);
        self.ntypes + padding_rows
    }
}
pub fn build_type_embedding(
g: &mut Graph,
cfg: &TypeEmbedConfig,
param_prefix: &str,
) -> Result<NodeId> {
let activation = ActivationKind::parse(&cfg.activation_function)?;
let ntypes = cfg.ntypes;
let tebd = cfg.tebd_dim();
let mut id_data = vec![0f32; ntypes * ntypes];
for i in 0..ntypes {
id_data[i * ntypes + i] = 1.0;
}
let id_bytes: Vec<u8> = id_data.iter().flat_map(|v| v.to_le_bytes()).collect();
let id = g.add_node(
Op::Constant { data: id_bytes },
vec![],
Shape::new(&[ntypes, ntypes], DType::F32),
);
let prefix = format!("{param_prefix}.embedding");
let mlp = MlpSpec {
param_prefix: &prefix,
in_dim: ntypes,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let mut t = embedding_mlp(g, &mlp, id);
if cfg.padding {
let zero_bytes: Vec<u8> = vec![0u8; 4 * tebd];
let zero = g.add_node(
Op::Constant { data: zero_bytes },
vec![],
Shape::new(&[1, tebd], DType::F32),
);
let out_shape = Shape::new(&[ntypes + 1, tebd], DType::F32);
t = g.concat(vec![t, zero], 0, out_shape);
}
Ok(t)
}
/// Looks up embedding rows for each entry of `atype`.
///
/// Thin wrapper over `Graph::gather_` that passes `table` as the source,
/// `atype` as the indices and `0` as the final argument (presumably the axis
/// to gather along, i.e. table rows). Returns the node created by that call.
pub fn gather_per_atom_embed(g: &mut Graph, table: NodeId, atype: NodeId) -> NodeId {
    let row_axis = 0;
    g.gather_(table, atype, row_axis)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: builds a padded table for three types, gathers from it with
    /// a `[1, 4]` index input, and checks the result rank and graph size.
    #[test]
    fn type_embed_builds() {
        let mut graph = Graph::new("type_embed");
        let config = TypeEmbedConfig {
            ntypes: 3,
            neuron: vec![16],
            resnet_dt: false,
            activation_function: String::from("tanh"),
            padding: true,
            use_tebd_bias: false,
        };

        // The table is built before the index input is declared, matching the
        // node creation order the graph has always seen in this test.
        let table = build_type_embedding(&mut graph, &config, "type_embed").expect("build");
        let atype_shape = Shape::new(&[1, 4], DType::I32);
        let atype = graph.input("atype", atype_shape);
        let gathered = gather_per_atom_embed(&mut graph, table, atype);

        assert_eq!(graph.shape(gathered).rank(), 3);
        assert!(graph.len() > 5);
    }
}