use crate::error::XcError;
use crate::families::XcEval;
use crate::io::{XcInput, XcResult};
/// Spin treatment applied when a functional is evaluated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Spin {
    /// Spin-restricted evaluation with a single density channel.
    Unpolarized,
    /// Spin-resolved evaluation with two density channels.
    Polarized,
}

impl Spin {
    /// Number of spin channels for this mode: `1` when unpolarized, `2` when polarized.
    pub fn channels(self) -> usize {
        match self {
            Self::Polarized => 2,
            Self::Unpolarized => 1,
        }
    }
}
/// Family ("rung") of a functional, with separate variants for hybrids that mix in
/// exact exchange. The naming follows the libxc family prefixes used by `FunctionalId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Family {
    /// Local density approximation.
    Lda,
    /// Generalized gradient approximation.
    Gga,
    /// Meta-GGA.
    Mgga,
    /// Hybrid GGA.
    HybGga,
    /// Hybrid meta-GGA.
    HybMgga,
}
/// Which energy contribution a functional models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Kind {
    /// Exchange only.
    Exchange,
    /// Correlation only.
    Correlation,
    /// Combined exchange and correlation.
    ExchangeCorrelation,
    /// Kinetic-energy functional.
    Kinetic,
}
/// Identifier of a supported functional.
///
/// [`FunctionalId::as_u32`] returns the libxc functional number and
/// [`FunctionalId::name`] the matching lowercase libxc name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FunctionalId {
    LdaX,
    LdaCPw,
    LdaCVwn,
    LdaCVwn3,
    LdaCVwnRpa,
    GgaXPbe,
    GgaXB88,
    GgaCPbe,
    GgaCLyp,
    GgaXPbeR,
    GgaXPbeSol,
    GgaXRpbe,
    GgaCPbeSol,
    MggaXTpss,
    MggaCTpss,
    MggaXR2scan,
    MggaCR2scan,
    MggaXM06L,
    MggaCM06L,
    HybGgaXcB3lyp,
    HybGgaXcPbeh,
    HybGgaXcB3lyp5,
}

impl FunctionalId {
    /// Every variant, in declaration order.
    ///
    /// Keep this list in sync with the enum: [`FunctionalId::from_u32`] can only find
    /// ids that appear here.
    pub const ALL: &'static [FunctionalId] = {
        use FunctionalId::*;
        &[
            LdaX,
            LdaCPw,
            LdaCVwn,
            LdaCVwn3,
            LdaCVwnRpa,
            GgaXPbe,
            GgaXB88,
            GgaCPbe,
            GgaCLyp,
            GgaXPbeR,
            GgaXPbeSol,
            GgaXRpbe,
            GgaCPbeSol,
            MggaXTpss,
            MggaCTpss,
            MggaXR2scan,
            MggaCR2scan,
            MggaXM06L,
            MggaCM06L,
            HybGgaXcB3lyp,
            HybGgaXcPbeh,
            HybGgaXcB3lyp5,
        ]
    };

    /// libxc functional number of this id.
    pub fn as_u32(self) -> u32 {
        use FunctionalId::*;
        match self {
            LdaX => 1,
            LdaCPw => 12,
            LdaCVwn => 7,
            LdaCVwn3 => 30,
            LdaCVwnRpa => 8,
            GgaXPbe => 101,
            GgaXB88 => 106,
            GgaCPbe => 130,
            GgaCLyp => 131,
            GgaXPbeR => 102,
            GgaXPbeSol => 116,
            GgaXRpbe => 117,
            GgaCPbeSol => 133,
            MggaXTpss => 202,
            MggaCTpss => 231,
            MggaXR2scan => 497,
            MggaCR2scan => 498,
            MggaXM06L => 203,
            MggaCM06L => 233,
            HybGgaXcB3lyp => 402,
            HybGgaXcPbeh => 406,
            HybGgaXcB3lyp5 => 475,
        }
    }

    /// Inverse of [`FunctionalId::as_u32`]; `None` for numbers not supported here.
    ///
    /// This is a linear scan over the short [`FunctionalId::ALL`] table.
    pub fn from_u32(id: u32) -> Option<Self> {
        FunctionalId::ALL.iter().copied().find(|f| f.as_u32() == id)
    }

    /// Canonical lowercase libxc name of this functional.
    pub fn name(self) -> &'static str {
        use FunctionalId::*;
        match self {
            LdaX => "lda_x",
            LdaCPw => "lda_c_pw",
            LdaCVwn => "lda_c_vwn",
            LdaCVwn3 => "lda_c_vwn_3",
            LdaCVwnRpa => "lda_c_vwn_rpa",
            GgaXPbe => "gga_x_pbe",
            GgaXB88 => "gga_x_b88",
            GgaCPbe => "gga_c_pbe",
            GgaCLyp => "gga_c_lyp",
            GgaXPbeR => "gga_x_pbe_r",
            GgaXPbeSol => "gga_x_pbe_sol",
            GgaXRpbe => "gga_x_rpbe",
            GgaCPbeSol => "gga_c_pbe_sol",
            MggaXTpss => "mgga_x_tpss",
            MggaCTpss => "mgga_c_tpss",
            MggaXR2scan => "mgga_x_r2scan",
            MggaCR2scan => "mgga_c_r2scan",
            MggaXM06L => "mgga_x_m06_l",
            MggaCM06L => "mgga_c_m06_l",
            HybGgaXcB3lyp => "hyb_gga_xc_b3lyp",
            HybGgaXcPbeh => "hyb_gga_xc_pbeh",
            HybGgaXcB3lyp5 => "hyb_gga_xc_b3lyp5",
        }
    }

    /// Parses a functional name, returning `None` if it is not recognised.
    ///
    /// Accepts every [`FunctionalId::name`] plus a few common aliases (`"slater"`,
    /// `"pw92"`, `"pw"`, `"lda_c_vwn_5"`, `"revpbe"`, `"pbe0"`, ...). Like libxc's own
    /// name lookup, matching is ASCII case-insensitive and an optional leading `xc_`
    /// prefix is ignored, so `"PBE0"` and `"XC_GGA_X_PBE"` also resolve.
    pub fn from_name(name: &str) -> Option<Self> {
        use FunctionalId::*;
        let lowered = name.to_ascii_lowercase();
        // libxc spells its constants `XC_<NAME>`; accept that form as well.
        let key = lowered.strip_prefix("xc_").unwrap_or(lowered.as_str());
        Some(match key {
            "lda_x" | "slater" => LdaX,
            "lda_c_pw" | "pw92" | "pw" => LdaCPw,
            "lda_c_vwn" | "lda_c_vwn_5" => LdaCVwn,
            "lda_c_vwn_3" => LdaCVwn3,
            "lda_c_vwn_rpa" => LdaCVwnRpa,
            "gga_x_pbe" => GgaXPbe,
            "gga_x_b88" => GgaXB88,
            "gga_c_pbe" => GgaCPbe,
            "gga_c_lyp" => GgaCLyp,
            "gga_x_pbe_r" | "revpbe" => GgaXPbeR,
            "gga_x_pbe_sol" => GgaXPbeSol,
            "gga_x_rpbe" => GgaXRpbe,
            "gga_c_pbe_sol" => GgaCPbeSol,
            "mgga_x_tpss" => MggaXTpss,
            "mgga_c_tpss" => MggaCTpss,
            "mgga_x_r2scan" => MggaXR2scan,
            "mgga_c_r2scan" => MggaCR2scan,
            "mgga_x_m06_l" | "mgga_x_m06l" => MggaXM06L,
            "mgga_c_m06_l" | "mgga_c_m06l" => MggaCM06L,
            "hyb_gga_xc_b3lyp" => HybGgaXcB3lyp,
            "hyb_gga_xc_pbeh" | "hyb_gga_xc_pbe0" | "pbe0" => HybGgaXcPbeh,
            "hyb_gga_xc_b3lyp5" => HybGgaXcB3lyp5,
            _ => return None,
        })
    }
}
/// Range-separation parameters of a range-separated hybrid (the name suggests the
/// Coulomb-attenuating method, CAM).
///
/// NOTE(review): the convention for `alpha`/`beta` (which coefficient is full-range and
/// which is range-separated) is not established here. It is presumably libxc's
/// `cam_alpha`/`cam_beta`/`cam_omega`; confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub struct CamParams {
    /// Range-separation parameter (units presumably inverse bohr; TODO confirm).
    pub omega: f64,
    /// CAM `alpha` coefficient (see the convention note above).
    pub alpha: f64,
    /// CAM `beta` coefficient (see the convention note above).
    pub beta: f64,
}
/// Parameters of a VV10 nonlocal-correlation term.
///
/// NOTE(review): presumably these are the `b` and `C` constants of the Vydrov–Van Voorhis
/// (2010) kernel; confirm against the functional definitions that set them.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub struct Vv10Params {
    /// VV10 `b` parameter.
    pub b: f64,
    /// VV10 `C` parameter.
    pub c: f64,
}
/// Exact-exchange metadata carried in `FunctionalInfo::hybrid`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub struct HybridInfo {
    /// Global fraction of exact exchange. `Functional::exx_fraction` returns this value,
    /// and mixtures store the weight-summed fractions of their parts.
    pub exx_fraction: f64,
    /// Range-separation parameters, if any. `Functional::mix` does not propagate them.
    pub cam: Option<CamParams>,
    /// VV10 nonlocal-correlation parameters, if any. `Functional::mix` does not propagate them.
    pub vv10: Option<Vv10Params>,
}
/// Static description of a functional, as reported by its evaluator (`XcEval::info`).
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub struct FunctionalInfo {
    /// Identifier, or `None` when there is no single id (e.g. mixtures built by `Functional::mix`).
    pub id: Option<FunctionalId>,
    /// Human-readable name (`"mixed"` for mixtures).
    pub name: &'static str,
    /// Family / rung of the functional.
    pub family: Family,
    /// Which energy contribution the functional models.
    pub kind: Kind,
    /// Whether evaluation needs the `sigma` input (presumably the contracted density gradient).
    pub needs_sigma: bool,
    /// Whether evaluation needs the density Laplacian input.
    pub needs_lapl: bool,
    /// Whether evaluation needs the `tau` input (presumably the kinetic-energy density).
    pub needs_tau: bool,
    /// Density threshold of the evaluator (presumably points below it are screened out;
    /// TODO confirm). Mixtures report the smallest threshold among their parts.
    pub dens_threshold: f64,
    /// Exact-exchange metadata, if any. `Functional::exx_fraction` treats `None` as zero.
    pub hybrid: Option<HybridInfo>,
}
/// A functional bound to a spin mode. All evaluation is delegated to a boxed `XcEval`.
pub struct Functional {
    /// Spin mode passed to every evaluation call.
    spin: Spin,
    /// Evaluator: either one built by `crate::functionals::build`, or a weighted mixture
    /// produced by `Functional::mix`.
    eval: Box<dyn XcEval>,
}
impl Functional {
pub fn new(id: FunctionalId, spin: Spin) -> Result<Self, XcError> {
let eval = crate::functionals::build(id)?;
Ok(Self { spin, eval })
}
pub fn by_name(name: &str, spin: Spin) -> Result<Self, XcError> {
let id = FunctionalId::from_name(name).ok_or(XcError::UnknownFunctional)?;
Self::new(id, spin)
}
pub fn info(&self) -> &FunctionalInfo {
self.eval.info()
}
pub fn spin(&self) -> Spin {
self.spin
}
pub fn exx_fraction(&self) -> f64 {
self.info().hybrid.map_or(0.0, |h| h.exx_fraction)
}
pub fn eval(&self, np: usize, input: &XcInput) -> Result<XcResult, XcError> {
self.eval.eval(self.spin, np, input)
}
pub fn eval_fxc(&self, np: usize, input: &XcInput) -> Result<XcResult, XcError> {
self.eval.eval_fxc(self.spin, np, input)
}
pub fn mix(parts: Vec<(f64, Functional)>) -> Result<Functional, XcError> {
let spin = parts.first().ok_or(XcError::SpinMismatch)?.1.spin;
if parts.iter().any(|(_, f)| f.spin != spin) {
return Err(XcError::SpinMismatch);
}
let exx = parts.iter().map(|(w, f)| w * f.exx_fraction()).sum();
let info = FunctionalInfo {
id: None,
name: "mixed",
family: Family::HybGga,
kind: Kind::ExchangeCorrelation,
needs_sigma: parts.iter().any(|(_, f)| f.info().needs_sigma),
needs_lapl: parts.iter().any(|(_, f)| f.info().needs_lapl),
needs_tau: parts.iter().any(|(_, f)| f.info().needs_tau),
dens_threshold: parts
.iter()
.map(|(_, f)| f.info().dens_threshold)
.fold(f64::INFINITY, f64::min),
hybrid: Some(HybridInfo {
exx_fraction: exx,
cam: None,
vv10: None,
}),
};
let weighted: Vec<(f64, Box<dyn XcEval>)> =
parts.into_iter().map(|(w, f)| (w, f.eval)).collect();
Ok(Functional {
spin,
eval: mixed_eval(weighted, info),
})
}
}
/// Wraps weighted evaluators into one evaluator whose outputs are the weighted sum of
/// the parts' outputs. `info` is reported verbatim as the combined metadata.
pub(crate) fn mixed_eval(
    parts: Vec<(f64, Box<dyn XcEval>)>,
    info: FunctionalInfo,
) -> Box<dyn XcEval> {
    let combined = MixEval { info, parts };
    Box::new(combined)
}
/// Evaluator for a weighted sum of other evaluators (built by `mixed_eval`).
struct MixEval {
    /// `(weight, evaluator)` pairs, evaluated in this order.
    parts: Vec<(f64, Box<dyn XcEval>)>,
    /// Metadata returned from `info()`, supplied by the builder rather than queried from the parts.
    info: FunctionalInfo,
}
impl XcEval for MixEval {
    fn info(&self) -> &FunctionalInfo {
        &self.info
    }

    /// Weighted sum of every part's `eval` output, in part order; the first failing
    /// part aborts the sum and its error is returned.
    fn eval(&self, spin: Spin, np: usize, input: &XcInput) -> Result<XcResult, XcError> {
        self.parts.iter().try_fold(
            XcResult::default(),
            |mut total: XcResult, (weight, part)| -> Result<XcResult, XcError> {
                let piece = part.eval(spin, np, input)?;
                accumulate(&mut total, *weight, &piece);
                Ok(total)
            },
        )
    }

    /// Weighted sum of every part's `eval_fxc` output, in part order; the first failing
    /// part aborts the sum and its error is returned.
    fn eval_fxc(&self, spin: Spin, np: usize, input: &XcInput) -> Result<XcResult, XcError> {
        self.parts.iter().try_fold(
            XcResult::default(),
            |mut total: XcResult, (weight, part)| -> Result<XcResult, XcError> {
                let piece = part.eval_fxc(spin, np, input)?;
                accumulate(&mut total, *weight, &piece);
                Ok(total)
            },
        )
    }
}
/// Adds `w` times each listed output array of `r` onto the matching array of `acc`,
/// using `add_scaled` for every field.
fn accumulate(acc: &mut XcResult, w: f64, r: &XcResult) {
    // Disjoint field borrows: each destination is paired with its source array.
    let mut pairs = [
        (&mut acc.exc, &r.exc[..]),
        (&mut acc.vrho, &r.vrho[..]),
        (&mut acc.vsigma, &r.vsigma[..]),
        (&mut acc.vtau, &r.vtau[..]),
        (&mut acc.vlapl, &r.vlapl[..]),
        (&mut acc.v2rho2, &r.v2rho2[..]),
        (&mut acc.v2rhosigma, &r.v2rhosigma[..]),
        (&mut acc.v2sigma2, &r.v2sigma2[..]),
        (&mut acc.v2rhotau, &r.v2rhotau[..]),
        (&mut acc.v2sigmatau, &r.v2sigmatau[..]),
        (&mut acc.v2tau2, &r.v2tau2[..]),
    ];
    for (dst, src) in &mut pairs {
        add_scaled(dst, w, src);
    }
}
/// Adds `w * src[i]` onto `dst[i]` for every index of `src`.
///
/// `dst` is first zero-extended when it is shorter than `src`, so an empty accumulator
/// takes the length of the first contribution. Entries of `dst` past the end of `src`
/// are left untouched.
fn add_scaled(dst: &mut Vec<f64>, w: f64, src: &[f64]) {
    let wanted = src.len();
    if dst.len() < wanted {
        dst.resize(wanted, 0.0);
    }
    dst.iter_mut()
        .zip(src.iter())
        .for_each(|(slot, &value)| *slot += w * value);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::XcInput;

    /// True when `got` matches `want` within a tolerance relative to `|got|`.
    fn close_rel(got: f64, want: f64, rel: f64) -> bool {
        (got - want).abs() <= rel * got.abs()
    }

    #[test]
    fn id_roundtrips_and_matches_libxc_numbers() {
        FunctionalId::ALL.iter().for_each(|&id| {
            assert_eq!(FunctionalId::from_u32(id.as_u32()), Some(id));
            assert_eq!(FunctionalId::from_name(id.name()), Some(id));
        });
        let spot_checks = [
            (FunctionalId::GgaCPbe, 130),
            (FunctionalId::HybGgaXcB3lyp5, 475),
        ];
        for (id, number) in spot_checks.iter() {
            assert_eq!(id.as_u32(), *number);
        }
        assert_eq!(
            FunctionalId::from_name("pbe0"),
            Some(FunctionalId::HybGgaXcPbeh)
        );
    }

    #[test]
    fn mix_single_weight_one_is_identity() {
        let rho = [0.7_f64];
        let sigma = [0.2_f64];
        let make_pbe = || Functional::new(FunctionalId::GgaXPbe, Spin::Unpolarized).unwrap();
        let reference = make_pbe().eval(1, &XcInput::gga(&rho, &sigma)).unwrap();
        let mixed = Functional::mix(vec![(1.0, make_pbe())]).unwrap();
        let result = mixed.eval(1, &XcInput::gga(&rho, &sigma)).unwrap();
        assert_eq!(result.exc, reference.exc);
        assert_eq!(result.vrho, reference.vrho);
        assert_eq!(result.vsigma, reference.vsigma);
    }

    #[test]
    fn mix_accumulates_linearly_and_matches_fd() {
        let (wa, wb) = (0.25_f64, 0.75_f64);
        let unpolarized = |id| Functional::new(id, Spin::Unpolarized).unwrap();
        let mixed = Functional::mix(vec![
            (wa, unpolarized(FunctionalId::LdaX)),
            (wb, unpolarized(FunctionalId::GgaXPbe)),
        ])
        .unwrap();
        let lda = unpolarized(FunctionalId::LdaX);
        let pbe = unpolarized(FunctionalId::GgaXPbe);
        // Energy per volume of the mixture at a single point (rho = n, sigma = s).
        let energy_density =
            |n: f64, s: f64| n * mixed.eval(1, &XcInput::gga(&[n], &[s])).unwrap().exc[0];

        let samples = [(0.5_f64, 0.1_f64), (2.0, 0.7), (10.0, 5.0)];
        for &(n, s) in samples.iter() {
            let rho = [n];
            let sg = [s];
            let inp = XcInput::gga(&rho, &sg);
            let m = mixed.eval(1, &inp).unwrap();
            let l = lda.eval(1, &XcInput::lda(&rho)).unwrap();
            let p = pbe.eval(1, &inp).unwrap();

            let want_exc = wa * l.exc[0] + wb * p.exc[0];
            let want_vrho = wa * l.vrho[0] + wb * p.vrho[0];
            assert!(close_rel(m.exc[0], want_exc, 1e-14));
            assert!(close_rel(m.vrho[0], want_vrho, 1e-14));
            // Only the GGA part contributes a sigma derivative.
            assert_eq!(m.vsigma.len(), 1);
            assert!(close_rel(m.vsigma[0], wb * p.vsigma[0], 1e-14));

            // Central finite differences of the energy density.
            let hn = 1e-6 * n;
            let hs = 1e-6 * s;
            let fdn = (energy_density(n + hn, s) - energy_density(n - hn, s)) / (2.0 * hn);
            let fds = (energy_density(n, s + hs) - energy_density(n, s - hs)) / (2.0 * hs);
            assert!((m.vrho[0] - fdn).abs() <= 1e-6 * m.vrho[0].abs().max(1.0));
            assert!((m.vsigma[0] - fds).abs() <= 1e-6 * m.vsigma[0].abs().max(1.0));
        }
        assert_eq!(mixed.exx_fraction(), 0.0);
    }

    #[test]
    fn mix_spin_mismatch_errors() {
        let lda_x = |spin| Functional::new(FunctionalId::LdaX, spin).unwrap();
        let outcome = Functional::mix(vec![
            (0.5, lda_x(Spin::Unpolarized)),
            (0.5, lda_x(Spin::Polarized)),
        ]);
        assert!(matches!(outcome, Err(XcError::SpinMismatch)));
    }
}