use anyhow::Result;
use rlx_ir::Graph;
use crate::descriptor_dpa1::{
build_dpa1_descriptor, DPA1Config, DPA1Descriptor, DPA1Inputs,
};
use crate::descriptor_se_t_tebd::TebdInputMode;
use serde::{Deserialize, Serialize};
/// Configuration for the `se_atten_v2` descriptor.
///
/// This type holds only the settings a user can tune. [`SeAttenV2Config::to_dpa1`]
/// maps it onto a [`DPA1Config`], pinning `tebd_input_mode` to
/// [`TebdInputMode::Strip`] and `smooth` to `true`; those two knobs are not
/// exposed here.
///
/// When deserializing, `rcut`, `rcut_smth`, `sel` and `ntypes` are required
/// (they carry no serde default). Every other field is optional and falls back
/// to the default noted on it.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeAttenV2Config {
    /// Forwarded as `DPA1Config::rcut`. Required.
    /// NOTE(review): presumably the neighbor cutoff radius — confirm units.
    pub rcut: f64,
    /// Forwarded as `DPA1Config::rcut_smth`. Required.
    pub rcut_smth: f64,
    /// Forwarded (cloned) as `DPA1Config::sel`. Required.
    pub sel: Vec<usize>,
    /// Forwarded as `DPA1Config::ntypes`. Required.
    /// NOTE(review): presumably the number of atom types — confirm.
    pub ntypes: usize,
    /// Forwarded (cloned) as `DPA1Config::neuron`. Default: `[25, 50, 100]`.
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// Forwarded as `DPA1Config::axis_neuron`. Default: `8`.
    #[serde(default = "default_axis_neuron")]
    pub axis_neuron: usize,
    /// Forwarded as `DPA1Config::tebd_dim`. Default: `8`.
    #[serde(default = "default_tebd_dim")]
    pub tebd_dim: usize,
    /// Forwarded as `DPA1Config::resnet_dt`. Default: `false` (`bool::default()`).
    #[serde(default)]
    pub resnet_dt: bool,
    /// Forwarded as `DPA1Config::attn`. Default: `128`.
    #[serde(default = "default_attn")]
    pub attn: usize,
    /// Forwarded as `DPA1Config::attn_layer`. Default: `2`.
    #[serde(default = "default_attn_layer")]
    pub attn_layer: usize,
    /// Forwarded as `DPA1Config::attn_dotr`. Default: `true`.
    #[serde(default = "default_true")]
    pub attn_dotr: bool,
    /// Forwarded as `DPA1Config::num_heads`. Default: `1`.
    #[serde(default = "default_num_heads")]
    pub num_heads: usize,
    /// Forwarded as `DPA1Config::normalize`. Default: `true`.
    #[serde(default = "default_true")]
    pub normalize: bool,
    /// Forwarded as `DPA1Config::scaling_factor`. Default: `1.0`.
    #[serde(default = "default_scaling_factor")]
    pub scaling_factor: f32,
    /// Forwarded as `DPA1Config::ln_eps`. Default: `1e-5`.
    #[serde(default = "default_ln_eps")]
    pub ln_eps: f32,
    /// Forwarded as `DPA1Config::type_one_side`. Default: `false` (`bool::default()`).
    #[serde(default)]
    pub type_one_side: bool,
    /// Forwarded (cloned) as `DPA1Config::activation_function`. Default: `"tanh"`.
    /// The name is not validated here; whether it is accepted is up to the DPA-1 builder.
    #[serde(default = "default_activation")]
    pub activation_function: String,
    /// Forwarded as `DPA1Config::concat_output_tebd`. Default: `true`.
    #[serde(default = "default_true")]
    pub concat_output_tebd: bool,
}
/// Serde fallback for `SeAttenV2Config::neuron`: the widths `[25, 50, 100]`.
fn default_neuron() -> Vec<usize> {
    Vec::from([25, 50, 100])
}
/// Serde fallback for `SeAttenV2Config::axis_neuron`.
fn default_axis_neuron() -> usize {
    8_usize
}
/// Serde fallback for `SeAttenV2Config::tebd_dim`.
fn default_tebd_dim() -> usize {
    8_usize
}
/// Serde fallback for `SeAttenV2Config::attn`.
fn default_attn() -> usize {
    128_usize
}
/// Serde fallback for `SeAttenV2Config::attn_layer`.
fn default_attn_layer() -> usize {
    2_usize
}
/// Serde fallback for `SeAttenV2Config::num_heads`.
fn default_num_heads() -> usize {
    1_usize
}
/// Serde fallback for `SeAttenV2Config::ln_eps`.
fn default_ln_eps() -> f32 {
    1.0e-5_f32
}
/// Shared serde fallback for boolean fields that are on unless disabled
/// (`attn_dotr`, `normalize`, `concat_output_tebd`).
fn default_true() -> bool {
    true
}
/// Serde fallback for `SeAttenV2Config::scaling_factor`.
fn default_scaling_factor() -> f32 {
    1.0_f32
}
/// Serde fallback for `SeAttenV2Config::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}
impl SeAttenV2Config {
    /// Builds the equivalent [`DPA1Config`] for this `se_atten_v2` configuration.
    ///
    /// Every field of `self` is carried over unchanged. `sel`, `neuron` and
    /// `activation_function` are cloned; the rest are copied. The two DPA-1
    /// settings this type does not expose are pinned:
    /// `tebd_input_mode = TebdInputMode::Strip` and `smooth = true`.
    pub fn to_dpa1(&self) -> DPA1Config {
        // Exhaustive destructuring (no `..`): if a field is added to
        // `SeAttenV2Config` and not forwarded below, this stops compiling
        // instead of silently dropping the setting.
        let Self {
            rcut,
            rcut_smth,
            sel,
            ntypes,
            neuron,
            axis_neuron,
            tebd_dim,
            resnet_dt,
            attn,
            attn_layer,
            attn_dotr,
            num_heads,
            normalize,
            scaling_factor,
            ln_eps,
            type_one_side,
            activation_function,
            concat_output_tebd,
        } = self;

        DPA1Config {
            rcut: *rcut,
            rcut_smth: *rcut_smth,
            sel: sel.clone(),
            ntypes: *ntypes,
            neuron: neuron.clone(),
            axis_neuron: *axis_neuron,
            tebd_dim: *tebd_dim,
            // Fixed for se_atten_v2; not configurable through this type.
            tebd_input_mode: TebdInputMode::Strip,
            resnet_dt: *resnet_dt,
            attn: *attn,
            attn_layer: *attn_layer,
            attn_dotr: *attn_dotr,
            num_heads: *num_heads,
            normalize: *normalize,
            scaling_factor: *scaling_factor,
            ln_eps: *ln_eps,
            // Fixed for se_atten_v2; not configurable through this type.
            smooth: true,
            type_one_side: *type_one_side,
            activation_function: activation_function.clone(),
            concat_output_tebd: *concat_output_tebd,
        }
    }
}
/// Adds an `se_atten_v2` descriptor to the graph `g`.
///
/// This is a thin adapter. `cfg` is converted with [`SeAttenV2Config::to_dpa1`],
/// and the result is handed to [`build_dpa1_descriptor`] together with
/// `inputs`, `nf` and `nloc`, which are passed through untouched.
///
/// # Errors
///
/// Returns whatever error [`build_dpa1_descriptor`] returns. This function
/// introduces no failure modes of its own.
pub fn build_se_atten_v2_descriptor(
    g: &mut Graph,
    cfg: &SeAttenV2Config,
    inputs: DPA1Inputs,
    nf: usize,
    nloc: usize,
) -> Result<DPA1Descriptor> {
    build_dpa1_descriptor(g, &cfg.to_dpa1(), inputs, nf, nloc)
}