use crate::interface::*;
use serde::Deserialize;
use std::collections::HashMap;
use toml::Table;
/// The DFT-D4 parameter database as TOML text.
///
/// `include_str!` embeds it at compile time from `parameters.toml`, resolved
/// relative to this source file. It is parsed on demand by `load_data_base`.
const PARAMETERS_TOML: &str = include_str!("parameters.toml");
/// One method's entry under `parameter.<method>` in the database.
///
/// Only the `d4` sub-table is deserialized.
#[derive(Debug, Clone, Deserialize, Default)]
struct D4Variants {
    /// The `parameter.<method>.d4` table holding the per-variant parameters.
    d4: D4MethodParams,
}
/// Per-variant parameter tables for one method.
///
/// Each field is optional: a method with no table for a variant deserializes
/// to `None`. Lookups then report that variant as "not found for this method".
#[derive(Debug, Clone, Deserialize, Default)]
struct D4MethodParams {
    /// Table stored under the `bj-eeq-two` key.
    #[serde(rename = "bj-eeq-two")]
    bj_eeq_two: Option<Table>,
    /// Table stored under the `bj-eeq-atm` key.
    #[serde(rename = "bj-eeq-atm")]
    bj_eeq_atm: Option<Table>,
    /// Table stored under the `bj-eeq-mbd` key.
    #[serde(rename = "bj-eeq-mbd")]
    bj_eeq_mbd: Option<Table>,
}
/// Root of the parsed `parameters.toml` database.
#[derive(Debug, Clone, Deserialize)]
struct ParameterDataBase {
    /// The `[default]` section holding the fallback parameter tables.
    default: DefaultSection,
    /// Method name (as written in the TOML key) to that method's variants.
    parameter: HashMap<String, D4Variants>,
}
/// The `[default]` section of the database.
#[derive(Debug, Clone, Deserialize)]
struct DefaultSection {
    // Deserialized so the schema matches the file, but never read in this
    // module; hence the lint allowance.
    // NOTE(review): presumably the list of default variant names — confirm.
    #[allow(dead_code)]
    d4: Vec<String>,
    /// The `[default.parameter]` sub-section with the default tables.
    parameter: DefaultParameterSection,
}
/// The `[default.parameter]` section of the database.
#[derive(Debug, Clone, Deserialize)]
struct DefaultParameterSection {
    /// The `default.parameter.d4` table with one default table per variant.
    d4: D4DefaultParams,
}
/// Default parameter tables for every variant.
///
/// Unlike `D4MethodParams`, all three tables are required; parsing fails if
/// any of them is missing. A method's own table is layered on top of these by
/// `merge_tables`.
#[derive(Debug, Clone, Deserialize)]
struct D4DefaultParams {
    /// Defaults for the `bj-eeq-two` variant.
    #[serde(rename = "bj-eeq-two")]
    bj_eeq_two: Table,
    /// Defaults for the `bj-eeq-atm` variant.
    #[serde(rename = "bj-eeq-atm")]
    bj_eeq_atm: Table,
    /// Defaults for the `bj-eeq-mbd` variant.
    #[serde(rename = "bj-eeq-mbd")]
    bj_eeq_mbd: Table,
}
/// Damping parameters resolved for one method and variant.
///
/// Produced by `convert_to_damping_param` from a method table merged over the
/// variant defaults.
#[derive(Debug, Clone)]
pub struct DFTD4DampingParam {
    /// The deserialized rational-damping parameters.
    pub param: DFTD4RationalDampingParam,
    /// The string value of the merged table's `doi` key, if present.
    pub doi: Option<String>,
}
/// Builds the native parameter object from the wrapped rational-damping
/// parameters. The `doi` metadata plays no part in construction.
impl DFTD4ParamAPI for DFTD4DampingParam {
    fn new_param_f(self) -> Result<DFTD4Param, DFTD4Error> {
        let Self { param, .. } = self;
        param.new_param_f()
    }
}
/// Parse the embedded `PARAMETERS_TOML` text into a [`ParameterDataBase`].
///
/// The database is re-parsed on every call.
///
/// # Errors
/// Returns `DFTD4Error::ParametersError` carrying the TOML parser message if
/// the embedded text does not match the expected schema.
fn load_data_base() -> Result<ParameterDataBase, DFTD4Error> {
    match toml::from_str::<ParameterDataBase>(PARAMETERS_TOML) {
        Ok(db) => Ok(db),
        Err(err) => Err(DFTD4Error::ParametersError(format!(
            "TOML parsing error: {err}"
        ))),
    }
}
/// Find a method's entry, comparing names case- and separator-insensitively.
///
/// An exact hit on the normalized name is tried first. Otherwise every key is
/// normalized and compared, and the first match in map iteration order wins.
///
/// Returns the database key that matched together with its entry.
///
/// # Errors
/// Returns `DFTD4Error::ParametersError` if no key matches.
fn lookup_method<'a>(
    db: &'a ParameterDataBase,
    method: &str,
) -> Result<(String, &'a D4Variants), DFTD4Error> {
    let wanted = normalize_method(method);
    // Fast path: the key is already stored in normalized form.
    let exact_hit = db.parameter.get_key_value(&wanted);
    exact_hit
        // Slow path: normalize each stored key (it may contain case or separators).
        .or_else(|| {
            db.parameter
                .iter()
                .find(|(key, _)| normalize_method(key) == wanted)
        })
        .map(|(key, entry)| (key.clone(), entry))
        .ok_or_else(|| {
            DFTD4Error::ParametersError(format!("Method '{method}' not found in database"))
        })
}
/// Panicking convenience wrapper around [`dftd4_get_damping_param_f`].
///
/// # Panics
/// Panics in the same cases where the fallible version returns an error:
/// - the embedded database cannot be parsed;
/// - `method` is not in the database;
/// - the method has no entry for `version`;
/// - the merged table cannot be converted.
///
/// The panic message names the requested method and version.
pub fn dftd4_get_damping_param(method: &str, version: &str) -> DFTD4DampingParam {
    // `method` and `version` are typically user input, so say which lookup
    // failed instead of a bare `unwrap()` message.
    dftd4_get_damping_param_f(method, version).unwrap_or_else(|err| {
        panic!(
            "failed to get DFT-D4 damping parameters for method '{method}' \
             (version '{version}'): {err:?}"
        )
    })
}
pub fn dftd4_get_damping_param_f(
method: &str,
version: &str,
) -> Result<DFTD4DampingParam, DFTD4Error> {
let db = load_data_base()?;
let version_normalized = normalize_version(version);
let (_, method_entry) = lookup_method(&db, method)?;
let (entry_raw, default_entry) = get_variant_entry(method_entry, &version_normalized, &db)?;
let merged = merge_tables(&entry_raw, &default_entry);
convert_to_damping_param(&merged)
}
/// Return the raw parameter table for `method` and `version`: the method's own
/// table merged over the variant defaults, before deserialization.
///
/// # Errors
/// Same conditions as `dftd4_get_damping_param_f`, except conversion errors,
/// since no conversion is done here.
// May have no caller in some builds, so dead-code warnings are silenced.
#[allow(dead_code)]
pub(crate) fn get_merged_param_table(method: &str, version: &str) -> Result<Table, DFTD4Error> {
    let db = load_data_base()?;
    let variant = normalize_version(version);
    let (_, method_entry) = lookup_method(&db, method)?;
    get_variant_entry(method_entry, &variant, &db)
        .map(|(own_table, defaults)| merge_tables(&own_table, &defaults))
}
/// Canonicalize a method name for comparison.
///
/// Lowercases the whole string, then drops every `-`, `_` and space, so
/// `"B3-LYP"` and `"b3lyp"` compare equal.
pub(crate) fn normalize_method(method: &str) -> String {
    // Lowercase the full string first so `str::to_lowercase`'s
    // context-sensitive rules still apply.
    let lowered = method.to_lowercase();
    lowered
        .chars()
        .filter(|ch| !matches!(ch, '-' | '_' | ' '))
        .collect()
}
/// Return the default parameter table for the variant named by `version`.
///
/// # Errors
/// Returns `DFTD4Error` if the embedded database cannot be parsed or the
/// normalized `version` is not one of the known variants.
// May have no caller in some builds, so dead-code warnings are silenced.
#[allow(dead_code)]
pub(crate) fn get_default_param_table(version: &str) -> Result<Table, DFTD4Error> {
    let db = load_data_base()?;
    get_variant_entry_for_defaults(&normalize_version(version), &db)
        .map(|(_, defaults)| defaults)
}
/// Panicking convenience wrapper around [`dftd4_get_all_damping_params_f`].
///
/// Returns the damping parameters of every method that has an entry for
/// `version` and whose merged table converts successfully. Other methods are
/// skipped.
///
/// # Panics
/// Panics if the embedded database cannot be parsed or if `version` does not
/// name a known variant. The panic message includes the requested version.
pub fn dftd4_get_all_damping_params(version: &str) -> HashMap<String, DFTD4DampingParam> {
    dftd4_get_all_damping_params_f(version).unwrap_or_else(|err| {
        panic!("failed to get DFT-D4 damping parameters for version '{version}': {err:?}")
    })
}
/// Resolve damping parameters for every method that provides the variant
/// named by `version`.
///
/// Methods with no table for the variant, and methods whose merged table fails
/// to deserialize, are left out of the result without an error.
///
/// # Errors
/// Returns `DFTD4Error` if the embedded database cannot be parsed or if
/// `version` does not name a known variant.
pub fn dftd4_get_all_damping_params_f(
    version: &str,
) -> Result<HashMap<String, DFTD4DampingParam>, DFTD4Error> {
    let db = load_data_base()?;
    let variant = normalize_version(version);
    // Unknown variants fail here, before any method is examined.
    let (_, defaults) = get_variant_entry_for_defaults(&variant, &db)?;

    let params: HashMap<String, DFTD4DampingParam> = db
        .parameter
        .iter()
        .filter_map(|(name, method_entry)| {
            // Best-effort: skip methods lacking this variant or failing conversion.
            let (own_table, _) = get_variant_entry(method_entry, &variant, &db).ok()?;
            let merged = merge_tables(&own_table, &defaults);
            let param = convert_to_damping_param(&merged).ok()?;
            Some((name.clone(), param))
        })
        .collect();
    Ok(params)
}
pub fn dftd4_list_methods() -> Vec<String> {
let db = load_data_base().unwrap_or_else(|_| ParameterDataBase {
default: DefaultSection {
d4: vec!["bj-eeq-atm".to_string()],
parameter: DefaultParameterSection {
d4: D4DefaultParams {
bj_eeq_two: Table::new(),
bj_eeq_atm: Table::new(),
bj_eeq_mbd: Table::new(),
},
},
},
parameter: HashMap::new(),
});
db.parameter.keys().cloned().collect()
}
/// Map a user-supplied variant name onto its canonical database key.
///
/// Matching is case-insensitive and ignores `-`, `_` and spaces. The accepted
/// aliases are:
/// - `d4`, `d4bj`, `bj`, `atm`, `bj-eeq-atm` map to `bj-eeq-atm`;
/// - `two`, `bj-eeq-two` map to `bj-eeq-two`;
/// - `mbd`, `bj-eeq-mbd` map to `bj-eeq-mbd`.
///
/// Any other input is returned lowercased with those separators removed.
pub(crate) fn normalize_version(version: &str) -> String {
    let squashed: String = version
        .to_lowercase()
        .chars()
        .filter(|ch| !matches!(ch, '-' | '_' | ' '))
        .collect();
    let canonical: Option<&'static str> = match squashed.as_str() {
        "d4" | "d4bj" | "bj" | "bjeeqatm" | "atm" => Some("bj-eeq-atm"),
        "bjeeqtwo" | "two" => Some("bj-eeq-two"),
        "bjeeqmbd" | "mbd" => Some("bj-eeq-mbd"),
        _ => None,
    };
    match canonical {
        Some(name) => name.to_string(),
        None => squashed,
    }
}
/// Fetch a method's table for the canonical variant `version`, together with
/// the matching default table.
///
/// Returns `(method_table, default_table)`, both cloned out of the database.
///
/// # Errors
/// Returns `DFTD4Error::ParametersError` if `version` is not one of
/// `bj-eeq-two`, `bj-eeq-atm` or `bj-eeq-mbd`, or if the method has no table
/// for that variant.
fn get_variant_entry(
    method_entry: &D4Variants,
    version: &str,
    db: &ParameterDataBase,
) -> Result<(Table, Table), DFTD4Error> {
    let d4_params = &method_entry.d4;
    let d4_defaults = &db.default.parameter.d4;
    let not_found = || {
        DFTD4Error::ParametersError(format!("Variant '{}' not found for this method", version))
    };
    // Select the method table and the default table in one dispatch, so both
    // always refer to the same variant and there is no unreachable fallback arm.
    let (entry, default_entry) = match version {
        "bj-eeq-two" => (&d4_params.bj_eeq_two, &d4_defaults.bj_eeq_two),
        "bj-eeq-atm" => (&d4_params.bj_eeq_atm, &d4_defaults.bj_eeq_atm),
        "bj-eeq-mbd" => (&d4_params.bj_eeq_mbd, &d4_defaults.bj_eeq_mbd),
        _ => return Err(not_found()),
    };
    // Borrow first; clone only once the variant is known to exist.
    let entry = entry.as_ref().ok_or_else(not_found)?;
    Ok((entry.clone(), default_entry.clone()))
}
/// Fetch the default table for the canonical variant `version`.
///
/// The first element of the returned pair is always `None`.
/// NOTE(review): it seems to mirror `get_variant_entry`'s pair shape — confirm
/// whether it is still needed.
///
/// # Errors
/// Returns `DFTD4Error::ParametersError` if `version` is not one of
/// `bj-eeq-two`, `bj-eeq-atm` or `bj-eeq-mbd`.
fn get_variant_entry_for_defaults(
    version: &str,
    db: &ParameterDataBase,
) -> Result<(Option<Table>, Table), DFTD4Error> {
    let d4_defaults = &db.default.parameter.d4;
    let selected = match version {
        "bj-eeq-two" => Some(&d4_defaults.bj_eeq_two),
        "bj-eeq-atm" => Some(&d4_defaults.bj_eeq_atm),
        "bj-eeq-mbd" => Some(&d4_defaults.bj_eeq_mbd),
        _ => None,
    };
    let table = selected.ok_or_else(|| {
        DFTD4Error::ParametersError(format!("Variant '{version}' not found in defaults"))
    })?;
    Ok((None, table.clone()))
}
/// Layer `entry` on top of `defaults` and return the combined table.
///
/// The merge is shallow. A top-level key present in `entry` replaces the
/// default value for that key outright, even when both values are nested
/// tables. Keys only in `defaults` are kept. Neither input is modified.
pub(crate) fn merge_tables(entry: &Table, defaults: &Table) -> Table {
    entry
        .iter()
        .fold(defaults.clone(), |mut combined, (key, value)| {
            combined.insert(key.clone(), value.clone());
            combined
        })
}
/// Return the `doi` entry of `table` as an owned string.
///
/// Returns `None` when the key is absent or its value is not a TOML string.
fn extract_doi(table: &Table) -> Option<String> {
    let raw = table.get("doi")?;
    raw.as_str().map(String::from)
}
/// Turn a merged parameter table into a [`DFTD4DampingParam`].
///
/// The optional `doi` string is read first. Then the whole table is
/// deserialized into the rational-damping parameter type.
///
/// # Errors
/// Propagates the deserialization error from `deserialize_table`.
pub(crate) fn convert_to_damping_param(merged: &Table) -> Result<DFTD4DampingParam, DFTD4Error> {
    let doi = extract_doi(merged);
    deserialize_table(merged).map(|param| DFTD4DampingParam { param, doi })
}
/// Deserialize an owned copy of `table` into any `T` that serde can build.
///
/// # Errors
/// Converts the TOML deserializer error into `DFTD4Error` through its
/// `From<toml::de::Error>` impl.
pub(crate) fn deserialize_table<T: for<'de> Deserialize<'de>>(
    table: &Table,
) -> Result<T, DFTD4Error> {
    // The deserializer consumes its input, so work on a copy.
    let owned = table.clone();
    let value = T::deserialize(owned)?;
    Ok(value)
}
impl From<toml::de::Error> for DFTD4Error {
fn from(e: toml::de::Error) -> Self {
DFTD4Error::ParametersError(format!("Deserialization error: {e}"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_data_base() {
        // Put the parse error in the panic message itself; a separate
        // println! is only visible with --nocapture.
        let db = load_data_base().unwrap_or_else(|e| panic!("TOML parsing failed: {:?}", e));
        for method in ["b3lyp", "pbe", "r2scan"] {
            assert!(db.parameter.contains_key(method), "missing method '{method}'");
        }
    }

    #[test]
    fn test_normalize_version() {
        assert_eq!(normalize_version("d4"), "bj-eeq-atm");
        assert_eq!(normalize_version("d4bj"), "bj-eeq-atm");
        assert_eq!(normalize_version("bj"), "bj-eeq-atm");
        assert_eq!(normalize_version("atm"), "bj-eeq-atm");
        assert_eq!(normalize_version("two"), "bj-eeq-two");
        assert_eq!(normalize_version("mbd"), "bj-eeq-mbd");
        assert_eq!(normalize_version("bj-eeq-atm"), "bj-eeq-atm");
        // Case and separators are ignored.
        assert_eq!(normalize_version("BJ_EEQ_TWO"), "bj-eeq-two");
        assert_eq!(normalize_version("D4 BJ"), "bj-eeq-atm");
        // Unknown names pass through lowercased with separators stripped.
        assert_eq!(normalize_version("Foo-Bar"), "foobar");
    }

    #[test]
    fn test_normalize_method() {
        assert_eq!(normalize_method("B3LYP"), "b3lyp");
        assert_eq!(normalize_method("r2-SCAN"), "r2scan");
        assert_eq!(normalize_method("B3-LYP_D 4"), "b3lypd4");
        assert_eq!(normalize_method(""), "");
    }

    #[test]
    fn test_lookup_method() {
        let db = load_data_base().expect("embedded parameter database should parse");
        let (_, upper) = lookup_method(&db, "B3LYP").expect("B3LYP should resolve");
        let (_, lower) = lookup_method(&db, "b3lyp").expect("b3lyp should resolve");
        assert!(std::ptr::eq(upper, lower));
        assert!(lookup_method(&db, "definitely-not-a-functional").is_err());
    }

    #[test]
    fn test_unknown_variant_is_error() {
        assert!(get_merged_param_table("b3lyp", "not-a-variant").is_err());
        assert!(get_default_param_table("not-a-variant").is_err());
    }

    #[test]
    fn test_list_methods_contains_known_method() {
        assert!(dftd4_list_methods().iter().any(|m| m == "b3lyp"));
    }

    #[test]
    fn test_merge_tables_entry_overrides_defaults() {
        let mut defaults = toml::Table::new();
        defaults.insert("s6".to_string(), toml::Value::Float(1.0));
        defaults.insert("a1".to_string(), toml::Value::Float(0.4));
        let mut entry = toml::Table::new();
        entry.insert("a1".to_string(), toml::Value::Float(0.5));

        let merged = merge_tables(&entry, &defaults);

        assert_eq!(merged.len(), 2);
        assert_eq!(merged.get("s6").and_then(toml::Value::as_float), Some(1.0));
        assert_eq!(merged.get("a1").and_then(toml::Value::as_float), Some(0.5));
        // Inputs are left untouched.
        assert_eq!(defaults.get("a1").and_then(toml::Value::as_float), Some(0.4));
    }
}