use std::collections::HashMap;
use std::path::PathBuf;
use std::{env, fs};
/// C++ language standard used when compiling the generated wrapper source.
const CXX_STANDARD: &str = "c++17";
/// Include directory of the vendored xsf headers. It is passed to the C++
/// compiler and also emitted as a `cargo:rerun-if-changed` path.
const XSF_INCLUDE: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/xsf/include");
/// Name shared by the generated C++ namespace, the generated `.hpp`/`.cpp`
/// file stems, and the library name passed to `cc::Build::compile`.
const WRAPPER_NAME: &str = "xsf_wrapper";
/// Comment prepended to both generated C++ files.
const WRAPPER_PREAMBLE: &str = "// generated by build.rs -- do not edit\n\n";
/// xsf headers included by the generated `.cpp`.
///
/// Each entry is emitted as `#include "xsf/<entry>"` and resolved against
/// `XSF_INCLUDE`. Every function referenced by `WRAPPER_SPECS` or by the custom
/// C++ snippets must be declared by one of these headers.
const WRAPPER_INCLUDES: &[&str] = &[
    "airy.h",
    "alg.h",
    "bessel.h",
    "beta.h",
    "binom.h",
    "cdflib.h",
    "digamma.h",
    "ellip.h",
    "erf.h",
    "evalpoly.h",
    "exp.h",
    "expint.h",
    "fp_error_metrics.h",
    "fresnel.h",
    "gamma.h",
    "hyp2f1.h",
    "iv_ratio.h",
    "kelvin.h",
    "lambertw.h",
    "legendre.h",
    "log_exp.h",
    "log.h",
    "loggamma.h",
    "mathieu.h",
    "par_cyl.h",
    "sici.h",
    "specfun.h",
    "sph_bessel.h",
    "sph_harm.h",
    "sphd_wave.h",
    "stats.h",
    "struve.h",
    "trig.h",
    "wright_bessel.h",
    "zeta.h",
];
/// Table of xsf functions that get an auto-generated C++ wrapper.
///
/// Each entry is `(name, spec)`:
///
/// - `spec` has the form `inputs->outputs`, with one type code per parameter.
///   The codes are mapped by `get_ctype`: `i` int, `l` long, `f` float,
///   `d` double, `F` cfloat, `D` cdouble, `V` void.
/// - A single output code is the wrapper's return type.
/// - With several output codes the wrapper returns `void`, and the outputs
///   become trailing out-parameters `y0, y1, ...`.
/// - A trailing `*` on `name` declares those out-parameters as pointers
///   instead of references. The `*` itself is stripped from the emitted name.
///
/// Repeated names become distinct C++ functions. The first occurrence keeps
/// `name`, and later ones are named `name_1`, `name_2`, ... in table order.
/// Reordering entries that share a name therefore renames the generated bindings.
const WRAPPER_SPECS: &[(&str, &str)] = &[
    ("airy", "d->dddd"),
    ("airy", "D->DDDD"),
    ("airye", "d->dddd"),
    ("airye", "D->DDDD"),
    ("itairy", "d->dddd"),
    // Trailing `*`: out-parameters are emitted as pointers rather than references.
    ("airyb*", "d->dddd"),
    ("airyzo*", "ii->dddd"),
    ("cbrt", "d->d"),
    ("it1j0y0", "d->dd"),
    ("it2j0y0", "d->dd"),
    ("it1i0k0", "d->dd"),
    ("it2i0k0", "d->dd"),
    ("cyl_bessel_j", "dd->d"),
    ("cyl_bessel_j", "dD->D"),
    ("cyl_bessel_je", "dd->d"),
    ("cyl_bessel_je", "dD->D"),
    ("cyl_bessel_j0", "d->d"),
    ("cyl_bessel_j1", "d->d"),
    ("cyl_bessel_y", "dd->d"),
    ("cyl_bessel_y", "dD->D"),
    ("cyl_bessel_ye", "dd->d"),
    ("cyl_bessel_ye", "dD->D"),
    ("cyl_bessel_y0", "d->d"),
    ("cyl_bessel_y1", "d->d"),
    ("cyl_bessel_i", "dd->d"),
    ("cyl_bessel_i", "dD->D"),
    ("cyl_bessel_ie", "dd->d"),
    ("cyl_bessel_ie", "dD->D"),
    ("cyl_bessel_i0", "d->d"),
    ("cyl_bessel_i0e", "d->d"),
    ("cyl_bessel_i1", "d->d"),
    ("cyl_bessel_i1e", "d->d"),
    ("cyl_bessel_k", "dd->d"),
    ("cyl_bessel_k", "dD->D"),
    ("cyl_bessel_ke", "dd->d"),
    ("cyl_bessel_ke", "dD->D"),
    ("cyl_bessel_k0", "d->d"),
    ("cyl_bessel_k0e", "d->d"),
    ("cyl_bessel_k1", "d->d"),
    ("cyl_bessel_k1e", "d->d"),
    ("cyl_hankel_1", "dD->D"),
    ("cyl_hankel_1e", "dD->D"),
    ("cyl_hankel_2", "dD->D"),
    ("cyl_hankel_2e", "dD->D"),
    ("besselpoly", "ddd->d"),
    ("beta", "dd->d"),
    ("betaln", "dd->d"),
    ("binom", "dd->d"),
    ("gdtrib", "ddd->d"),
    ("digamma", "d->d"),
    ("digamma", "D->D"),
    ("ellipj", "dd->dddd"),
    ("ellipk", "d->d"),
    ("ellipkm1", "d->d"),
    ("ellipkinc", "dd->d"),
    ("ellipe", "d->d"),
    ("ellipeinc", "dd->d"),
    ("erf", "d->d"),
    ("erf", "D->D"),
    ("erfc", "d->d"),
    ("erfc", "D->D"),
    ("erfcx", "d->d"),
    ("erfcx", "D->D"),
    ("erfi", "d->d"),
    ("erfi", "D->D"),
    ("voigt_profile", "ddd->d"),
    ("wofz", "D->D"),
    ("dawsn", "d->d"),
    ("dawsn", "D->D"),
    ("expm1", "d->d"),
    ("expm1", "D->D"),
    ("exp2", "d->d"),
    ("exp10", "d->d"),
    ("exp1", "d->d"),
    ("exp1", "D->D"),
    ("expi", "d->d"),
    ("expi", "D->D"),
    ("scaled_exp1", "d->d"),
    ("extended_absolute_error", "dd->d"),
    ("extended_absolute_error", "DD->d"),
    ("extended_relative_error", "dd->d"),
    ("extended_relative_error", "DD->d"),
    ("fresnel", "d->dd"),
    ("fresnel", "D->DD"),
    ("modified_fresnel_plus", "d->DD"),
    ("modified_fresnel_minus", "d->DD"),
    ("gamma", "d->d"),
    ("gamma", "D->D"),
    ("gammaln", "d->d"),
    ("gammasgn", "d->d"),
    ("gammainc", "dd->d"),
    ("gammaincinv", "dd->d"),
    ("gammaincc", "dd->d"),
    ("gammainccinv", "dd->d"),
    ("hyp2f1", "dddd->d"),
    ("hyp2f1", "dddD->D"),
    ("iv_ratio", "dd->d"),
    ("iv_ratio_c", "dd->d"),
    ("ber", "d->d"),
    ("bei", "d->d"),
    ("ker", "d->d"),
    ("kei", "d->d"),
    ("berp", "d->d"),
    ("beip", "d->d"),
    ("kerp", "d->d"),
    ("keip", "d->d"),
    ("kelvin", "d->DDDD"),
    ("lambertw", "Dld->D"),
    ("legendre_p", "id->d"),
    ("legendre_p", "iD->D"),
    ("sph_legendre_p", "iid->d"),
    ("sph_legendre_p", "iiD->D"),
    ("expit", "d->d"),
    ("exprel", "d->d"),
    ("logit", "d->d"),
    ("log_expit", "d->d"),
    ("log1mexp", "d->d"),
    ("log1p", "d->d"),
    ("log1p", "D->D"),
    ("log1pmx", "d->d"),
    ("xlogy", "dd->d"),
    ("xlogy", "DD->D"),
    ("xlog1py", "dd->d"),
    ("xlog1py", "DD->D"),
    ("loggamma", "d->d"),
    ("loggamma", "D->D"),
    ("rgamma", "d->d"),
    ("rgamma", "D->D"),
    ("cem_cva", "dd->d"),
    ("sem_cva", "dd->d"),
    ("cem", "ddd->dd"),
    ("sem", "ddd->dd"),
    ("mcm1", "ddd->dd"),
    ("msm1", "ddd->dd"),
    ("mcm2", "ddd->dd"),
    ("msm2", "ddd->dd"),
    ("pbwa", "dd->dd"),
    ("pbdv", "dd->dd"),
    ("pbvv", "dd->dd"),
    ("sici", "d->dd"),
    ("sici", "D->DD"),
    ("shichi", "d->dd"),
    ("shichi", "D->DD"),
    ("hyp1f1", "ddd->d"),
    ("hyp1f1", "ddD->D"),
    ("hypu", "ddd->d"),
    ("pmv", "ddd->d"),
    ("sph_bessel_j", "ld->d"),
    ("sph_bessel_j", "lD->D"),
    ("sph_bessel_j_jac", "ld->d"),
    ("sph_bessel_j_jac", "lD->D"),
    ("sph_bessel_y", "ld->d"),
    ("sph_bessel_y", "lD->D"),
    ("sph_bessel_y_jac", "ld->d"),
    ("sph_bessel_y_jac", "lD->D"),
    ("sph_bessel_i", "ld->d"),
    ("sph_bessel_i", "lD->D"),
    ("sph_bessel_i_jac", "ld->d"),
    ("sph_bessel_i_jac", "lD->D"),
    ("sph_bessel_k", "ld->d"),
    ("sph_bessel_k", "lD->D"),
    ("sph_bessel_k_jac", "ld->d"),
    ("sph_bessel_k_jac", "lD->D"),
    ("sph_harm_y", "iidd->D"),
    ("prolate_segv", "ddd->d"),
    ("prolate_aswfa_nocv", "dddd->dd"),
    ("prolate_radial1_nocv", "dddd->dd"),
    ("prolate_radial2_nocv", "dddd->dd"),
    ("prolate_aswfa", "ddddd->dd"),
    ("prolate_radial1", "ddddd->dd"),
    ("prolate_radial2", "ddddd->dd"),
    ("oblate_segv", "ddd->d"),
    ("oblate_aswfa_nocv", "dddd->dd"),
    ("oblate_radial1_nocv", "dddd->dd"),
    ("oblate_radial2_nocv", "dddd->dd"),
    ("oblate_radial1", "ddddd->dd"),
    ("oblate_radial2", "ddddd->dd"),
    ("oblate_aswfa", "ddddd->dd"),
    ("bdtr", "did->d"),
    ("bdtrc", "did->d"),
    ("bdtri", "did->d"),
    ("chdtr", "dd->d"),
    ("chdtrc", "dd->d"),
    ("chdtri", "dd->d"),
    ("fdtr", "ddd->d"),
    ("fdtrc", "ddd->d"),
    ("fdtri", "ddd->d"),
    ("gdtr", "ddd->d"),
    ("gdtrc", "ddd->d"),
    ("kolmogorov", "d->d"),
    ("kolmogc", "d->d"),
    ("kolmogi", "d->d"),
    ("kolmogci", "d->d"),
    ("kolmogp", "d->d"),
    ("ndtr", "d->d"),
    ("ndtr", "D->D"),
    ("ndtri", "d->d"),
    ("log_ndtr", "d->d"),
    ("log_ndtr", "D->D"),
    ("nbdtr", "iid->d"),
    ("nbdtrc", "iid->d"),
    ("nbdtri", "iid->d"),
    ("owens_t", "dd->d"),
    ("pdtr", "dd->d"),
    ("pdtrc", "dd->d"),
    ("pdtri", "id->d"),
    ("smirnov", "id->d"),
    ("smirnovc", "id->d"),
    ("smirnovci", "id->d"),
    ("smirnovi", "id->d"),
    ("smirnovp", "id->d"),
    ("tukeylambdacdf", "dd->d"),
    ("itstruve0", "d->d"),
    ("it2struve0", "d->d"),
    ("itmodstruve0", "d->d"),
    ("struve_h", "dd->d"),
    ("struve_l", "dd->d"),
    ("sinpi", "d->d"),
    ("sinpi", "D->D"),
    ("cospi", "d->d"),
    ("cospi", "D->D"),
    ("sindg", "d->d"),
    ("cosdg", "d->d"),
    ("tandg", "d->d"),
    ("cotdg", "d->d"),
    ("cosm1", "d->d"),
    ("radian", "ddd->d"),
    ("wright_bessel", "ddd->d"),
    ("log_wright_bessel", "ddd->d"),
    ("riemann_zeta", "d->d"),
    ("riemann_zeta", "D->D"),
    ("zetac", "d->d"),
    ("zeta", "dd->d"),
    ("zeta", "Dd->D"),
];
/// A hand-written C++ wrapper snippet that cannot be expressed as a simple
/// `WRAPPER_SPECS` type spec.
struct WrapperSpecCustom {
    /// Regex fragment naming the snippet's functions inside the wrapper
    /// namespace. It is used to build the bindgen allowlist.
    pattern: &'static str,
    /// C++ definitions emitted verbatim into the generated `.cpp`.
    cpp: &'static str,
}

impl WrapperSpecCustom {
    /// Derives forward declarations for the generated header from `self.cpp`.
    ///
    /// A line counts as a function signature when all of the following hold:
    ///
    /// - it is not indented (no leading space or tab), so body lines are skipped;
    /// - it contains `(`;
    /// - after trimming trailing whitespace, it ends with `) {`.
    ///
    /// Each signature becomes `<signature>;`, and the declarations are joined
    /// with `\n`. The snippet must therefore keep its function headers at
    /// column 0 and indent everything else.
    fn to_hpp(&self) -> String {
        self.cpp
            .lines()
            .map(str::trim_end)
            .filter_map(|line| {
                // Any leading whitespace (spaces *or* tabs) marks a body line.
                if line.starts_with(char::is_whitespace) || !line.contains('(') {
                    return None;
                }
                let signature = line.strip_suffix(" {")?;
                signature.ends_with(')').then(|| format!("{signature};"))
            })
            .collect::<Vec<_>>()
            .join("\n")
    }
}
/// C++ helpers for complex numbers:
///
/// - `complex__new` builds a `cdouble` from real and imaginary parts;
/// - `complex__values` splits a `cdouble` back into `re`/`im` via
///   `std::real`/`std::imag`.
///
/// The text is emitted verbatim into the generated `.cpp`. Header declarations
/// are extracted from its unindented `... ) {` lines.
const _CPP_COMPLEX_HELPERS: &str = r#"
cdouble complex__new(double re, double im) {
return cdouble(re, im);
}
void complex__values(cdouble z, double &re, double &im) {
re = std::real(z);
im = std::imag(z);
}"#;
/// C++ wrappers around `xsf::rctj` and `xsf::rcty`.
///
/// Each wrapper views its two output buffers as `std::mdspan`s of `n + 1`
/// elements and returns the `nm` value that xsf writes back. Callers
/// presumably must provide buffers of at least `n + 1` doubles (TODO confirm
/// against xsf).
const _CPP_RCT: &str = r#"
int rctj(size_t n, double x, double *rj, double *dj) {
int nm;
xsf::rctj(x, &nm, std::mdspan(rj, n + 1), std::mdspan(dj, n + 1));
return nm;
}
int rcty(size_t n, double x, double *ry, double *dy) {
int nm;
xsf::rcty(x, &nm, std::mdspan(ry, n + 1), std::mdspan(dy, n + 1));
return nm;
}"#;
/// C++ wrapper forwarding to `xsf::cevalpoly(coeffs, degree, z)`.
///
/// The signature takes real `double` coefficients and a complex argument.
/// It presumably evaluates a polynomial at `z` (TODO confirm coefficient
/// order against xsf).
const _CPP_CEVALPOLY: &str = r#"
cdouble cevalpoly(const double *coeffs, int degree, cdouble z) {
return xsf::cevalpoly(coeffs, degree, z);
}"#;
/// C++ wrapper forwarding to `xsf::fcszo(kf, nt, zo)`.
///
/// `zo` is a caller-provided `cdouble` buffer. Its required length is not
/// visible here; presumably it is `nt` (TODO confirm against xsf).
const _CPP_FCSZO: &str = r#"
void fcszo(int kf, int nt, cdouble *zo) {
xsf::fcszo(kf, nt, zo);
}"#;
/// C++ wrapper forwarding to `xsf::klvnzo(nt, kd, zo)`.
///
/// `zo` is a caller-provided `double` buffer. Its required length is not
/// visible here; presumably it is `nt` (TODO confirm against xsf).
const _CPP_KLVNZO: &str = r#"
void klvnzo(int nt, int kd, double *zo) {
xsf::klvnzo(nt, kd, zo);
}"#;
/// C++ wrappers for `xsf::assoc_legendre_p`, one per normalization.
///
/// - `_0` selects `xsf::assoc_legendre_unnorm`.
/// - `_1` selects `xsf::assoc_legendre_norm`.
/// - A further `_1` suffix marks the complex (`cdouble`) overload.
///
/// The C++ normalization tag cannot cross the FFI boundary, which is why each
/// normalization gets its own named function.
const _CPP_ASSOC_LEGENDRE_P: &str = r#"
double assoc_legendre_p_0(int n, int m, double z, int bc) {
return xsf::assoc_legendre_p(xsf::assoc_legendre_unnorm, n, m, z, bc);
}
double assoc_legendre_p_1(int n, int m, double z, int bc) {
return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc);
}
cdouble assoc_legendre_p_0_1(int n, int m, cdouble z, int bc) {
return xsf::assoc_legendre_p(xsf::assoc_legendre_unnorm, n, m, z, bc);
}
cdouble assoc_legendre_p_1_1(int n, int m, cdouble z, int bc) {
return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc);
}"#;
/// C++ wrappers for `xsf::legendre_p_all`, with real (`legendre_p_all`) and
/// complex (`legendre_p_all_1`) variants.
///
/// The output buffer `pn` is viewed as a 1-D `std::mdspan` of `n + 1` elements.
const _CPP_LEGENDRE_P_ALL: &str = r#"
void legendre_p_all(size_t n, double x, double *pn) {
xsf::legendre_p_all(x, std::mdspan(pn, n + 1));
}
void legendre_p_all_1(size_t n, cdouble z, cdouble *pn) {
xsf::legendre_p_all(z, std::mdspan(pn, n + 1));
}"#;
/// C++ wrappers for `xsf::sph_legendre_p_all`, with real and complex (`_1`)
/// variants.
///
/// The output buffer `pnm` is viewed as a 2-D `std::mdspan` with extents
/// `(n + 1, 2 * m + 1)`.
const _CPP_SPH_LEGENDRE_P_ALL: &str = r#"
void sph_legendre_p_all(size_t n, size_t m, double x, double *pnm) {
xsf::sph_legendre_p_all(x, std::mdspan(pnm, n + 1, 2 * m + 1));
}
void sph_legendre_p_all_1(size_t n, size_t m, cdouble z, cdouble *pnm) {
xsf::sph_legendre_p_all(z, std::mdspan(pnm, n + 1, 2 * m + 1));
}"#;
/// C++ wrappers for `xsf::assoc_legendre_p_all`.
///
/// - `_0` selects `assoc_legendre_unnorm`.
/// - `_1` selects `assoc_legendre_norm`.
/// - A trailing `_1` marks the complex overload.
///
/// The output buffer `pnm` is viewed as a 2-D `std::mdspan` with extents
/// `(n + 1, 2 * m + 1)`.
const _CPP_ASSOC_LEGENDRE_P_ALL: &str = r#"
void assoc_legendre_p_all_0(size_t n, size_t m, double z, int bc, double *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
xsf::assoc_legendre_p_all(xsf::assoc_legendre_unnorm, z, bc, res);
}
void assoc_legendre_p_all_0_1(size_t n, size_t m, cdouble z, int bc, cdouble *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
xsf::assoc_legendre_p_all(xsf::assoc_legendre_unnorm, z, bc, res);
}
void assoc_legendre_p_all_1(size_t n, size_t m, double z, int bc, double *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
xsf::assoc_legendre_p_all(xsf::assoc_legendre_norm, z, bc, res);
}
void assoc_legendre_p_all_1_1(size_t n, size_t m, cdouble z, int bc, cdouble *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
xsf::assoc_legendre_p_all(xsf::assoc_legendre_norm, z, bc, res);
}"#;
/// C++ wrappers for `xsf::lqn`, with real (`lqn`) and complex (`lqn_1`)
/// variants.
///
/// Both output buffers are viewed as 1-D `std::mdspan`s of `n + 1` elements.
const _CPP_LQN: &str = r#"
void lqn(size_t n, double x, double *qn, double *qd) {
xsf::lqn(x, std::mdspan(qn, n + 1), std::mdspan(qd, n + 1));
}
void lqn_1(size_t n, cdouble z, cdouble *cqn, cdouble *cqd) {
xsf::lqn(z, std::mdspan(cqn, n + 1), std::mdspan(cqd, n + 1));
}"#;
/// C++ wrappers for `xsf::lqmn`, with real (`lqmn`) and complex (`lqmn_1`)
/// variants.
///
/// Both output buffers are viewed as 2-D `std::mdspan`s with extents
/// `(m + 1, n + 1)`.
const _CPP_LQMN: &str = r#"
void lqmn(size_t m, size_t n, double x, double *qm, double *qd) {
xsf::lqmn(x, std::mdspan(qm, m + 1, n + 1), std::mdspan(qd, m + 1, n + 1));
}
void lqmn_1(size_t m, size_t n, cdouble z, cdouble *qm, cdouble *qd) {
xsf::lqmn(z, std::mdspan(qm, m + 1, n + 1), std::mdspan(qd, m + 1, n + 1));
}"#;
/// C++ wrapper for `xsf::sph_harm_y_all`.
///
/// The complex output buffer `res` is viewed as a 2-D `std::mdspan` with
/// extents `(n + 1, 2 * m + 1)`.
const _CPP_SPH_HARM_Y_ALL: &str = r#"
void sph_harm_y_all(size_t n, size_t m, double theta, double phi, cdouble *res) {
xsf::sph_harm_y_all(theta, phi, std::mdspan(res, n + 1, 2 * m + 1));
}"#;
/// Hand-written wrappers that the `WRAPPER_SPECS` type-spec format cannot express.
///
/// For each entry:
///
/// - `cpp` is emitted verbatim into the generated `.cpp`, and its declarations
///   are extracted into the `.hpp` by `WrapperSpecCustom::to_hpp`;
/// - `pattern` is placed under the wrapper namespace (with an optional numeric
///   suffix) in the bindgen function allowlist.
///
/// Keep each `pattern` in sync with the function names defined in its `cpp`;
/// otherwise the functions are compiled but no Rust bindings are generated.
const WRAPPER_SPECS_CUSTOM: &[WrapperSpecCustom] = &[
    WrapperSpecCustom {
        pattern: r"complex__(new|values)",
        cpp: _CPP_COMPLEX_HELPERS,
    },
    WrapperSpecCustom {
        pattern: r"rct(j|y)",
        cpp: _CPP_RCT,
    },
    WrapperSpecCustom {
        pattern: r"cevalpoly",
        cpp: _CPP_CEVALPOLY,
    },
    WrapperSpecCustom {
        pattern: r"fcszo",
        cpp: _CPP_FCSZO,
    },
    WrapperSpecCustom {
        pattern: r"klvnzo",
        cpp: _CPP_KLVNZO,
    },
    WrapperSpecCustom {
        pattern: r"assoc_legendre_p_(0|1)",
        cpp: _CPP_ASSOC_LEGENDRE_P,
    },
    WrapperSpecCustom {
        pattern: r"legendre_p_all",
        cpp: _CPP_LEGENDRE_P_ALL,
    },
    WrapperSpecCustom {
        pattern: r"sph_legendre_p_all",
        cpp: _CPP_SPH_LEGENDRE_P_ALL,
    },
    WrapperSpecCustom {
        pattern: r"assoc_legendre_p_all_(0|1)",
        cpp: _CPP_ASSOC_LEGENDRE_P_ALL,
    },
    WrapperSpecCustom {
        pattern: r"lqn",
        cpp: _CPP_LQN,
    },
    WrapperSpecCustom {
        pattern: r"lqmn",
        cpp: _CPP_LQMN,
    },
    WrapperSpecCustom {
        pattern: r"sph_harm_y_all",
        cpp: _CPP_SPH_HARM_Y_ALL,
    },
];
/// Maps a single type code from a spec string to its C++ type name.
///
/// Codes: `i` int, `l` long, `f` float, `d` double, `F` cfloat,
/// `D` cdouble, `V` void. `cfloat`/`cdouble` are aliases declared in the
/// generated header.
///
/// # Panics
/// Panics on any other code. The message names the offending character so
/// that a typo in `WRAPPER_SPECS` is easy to locate.
fn get_ctype(code: char) -> &'static str {
    match code {
        'i' => "int",
        'l' => "long",
        'f' => "float",
        'd' => "double",
        'F' => "cfloat",
        'D' => "cdouble",
        'V' => "void",
        _ => panic!("Unknown parameter type: {code:?}"),
    }
}
/// Splits a type spec `inputs->outputs` into its two halves.
///
/// Either half may be empty; for example, `"->"` yields `("", "")`.
///
/// # Panics
/// Panics, quoting the offending spec, if `spec` does not contain exactly
/// one `->` separator.
fn split_typespec(spec: &str) -> (&str, &str) {
    let (inputs, outputs) = spec
        .split_once("->")
        .unwrap_or_else(|| panic!("type spec {spec:?} is missing `->`"));
    assert!(
        !outputs.contains("->"),
        "type spec {spec:?} contains more than one `->`"
    );
    (inputs, outputs)
}
/// Returns the C++ return type of the wrapper generated for `spec`.
///
/// - Exactly one output code: that type is returned directly.
/// - Several output codes: `void`, because the outputs become out-parameters
///   (see `fmt_params`).
/// - No output codes (e.g. `"d->"`): also `void`, rather than the previous
///   unexplained panic.
///
/// # Panics
/// Panics if `spec` is malformed or its single output code is unknown.
fn fmt_return(spec: &str) -> &'static str {
    let (_, outputs) = split_typespec(spec);
    let mut codes = outputs.chars();
    match (codes.next(), codes.next()) {
        (Some(code), None) => get_ctype(code),
        _ => "void",
    }
}
fn fmt_params(spec: &str, types: bool, do_deref: bool) -> String {
let (inputs, outputs) = split_typespec(spec);
let mut params = inputs
.chars()
.map(get_ctype)
.enumerate()
.map(|(i, ct)| {
if types {
format!("{ct} x{i}")
} else {
format!("x{i}")
}
})
.collect::<Vec<_>>();
if outputs.len() > 1 {
let whence = if do_deref { '*' } else { '&' };
if types {
params.extend(
outputs
.chars()
.map(get_ctype)
.enumerate()
.map(|(i, ct)| format!("{ct} {whence}y{i}")),
);
} else {
params.extend((0..outputs.len()).map(|i| format!("y{i}")));
}
}
params.join(", ")
}
/// Renders the C++ signature `rtype fname(params)` for a wrapper.
///
/// A trailing `*` on `name` is stripped from the emitted name, and it
/// switches multi-output out-parameters to pointers. A non-empty `suffix` is
/// appended as `_<suffix>`, which distinguishes overloads of the same name.
fn fmt_func(name: &str, spec: &str, suffix: &str) -> String {
    let return_type = fmt_return(spec);
    let pointer_outputs = name.ends_with('*');
    let param_list = fmt_params(spec, true, pointer_outputs);
    let base = name.trim_end_matches('*');
    match suffix {
        "" => format!("{return_type} {base}({param_list})"),
        extra => format!("{return_type} {base}_{extra}({param_list})"),
    }
}
fn fmt_call(name: &str, spec: &str) -> String {
let (inputs, outputs) = split_typespec(spec);
let clean_name = name.trim_end_matches('*');
let mut args = (0..inputs.len())
.map(|i| format!("x{i}"))
.collect::<Vec<String>>();
if outputs.len() > 1 {
args.extend((0..outputs.len()).map(|i| format!("y{i}")));
}
format!("xsf::{}({})", clean_name, args.join(", "))
}
/// Appends `line` to `source` with surrounding whitespace removed, followed
/// by a newline.
fn push_line(source: &mut String, line: &str) {
    source.extend([line.trim(), "\n"]);
}
/// Writes `content` to `path` and returns `path` for the caller's convenience.
///
/// # Panics
/// Panics if the file cannot be written. The message names the path and the
/// underlying I/O error, since panicking is how a build script reports
/// failure.
fn write_file(path: String, content: String) -> String {
    if let Err(err) = fs::write(&path, content) {
        panic!("failed to write {path}: {err}");
    }
    path
}
/// Generates `<WRAPPER_NAME>.hpp` in `dir_out` and returns its path.
///
/// Inside `namespace WRAPPER_NAME`, the header declares:
///
/// - one function per `WRAPPER_SPECS` entry; repeated names get `_1`, `_2`,
///   ... suffixes, in the same order as `generate_cpp`;
/// - the signatures extracted from each `WRAPPER_SPECS_CUSTOM` snippet.
///
/// This is the header bindgen parses.
fn generate_hpp(dir_out: &str) -> String {
    let mut source = String::from(WRAPPER_PREAMBLE);
    // The custom declarations use `size_t`, so include its defining header
    // explicitly rather than relying on <complex>/<vector> to pull it in.
    push_line(&mut source, "#include <cstddef>");
    push_line(&mut source, "#include <complex>");
    push_line(&mut source, "#include <vector>");
    push_line(&mut source, "");
    push_line(&mut source, &format!("namespace {WRAPPER_NAME} {{"));
    push_line(&mut source, "using cfloat = std::complex<float>;");
    push_line(&mut source, "using cdouble = std::complex<double>;");
    let mut name_counts: HashMap<&str, usize> = HashMap::new();
    for (name, types) in WRAPPER_SPECS {
        let count = name_counts.entry(*name).or_insert(0);
        // Must stay in step with `generate_cpp` so declarations match definitions.
        let suffix = if *count == 0 {
            String::new()
        } else {
            count.to_string()
        };
        *count += 1;
        let func_decl = fmt_func(name, types, &suffix);
        push_line(&mut source, &format!("{func_decl};"));
    }
    for wrapper_extra in WRAPPER_SPECS_CUSTOM {
        push_line(&mut source, &wrapper_extra.to_hpp());
    }
    push_line(&mut source, "}");
    write_file(format!("{dir_out}/{WRAPPER_NAME}.hpp"), source)
}
/// Generates `<WRAPPER_NAME>.cpp` in `dir_out` and returns its path.
///
/// The file includes the kokkos mdspan header, the generated wrapper header
/// and every header in `WRAPPER_INCLUDES`. It then defines:
///
/// - one forwarding function per `WRAPPER_SPECS` entry, using the same
///   overload suffixes as `generate_hpp`;
/// - the verbatim `WRAPPER_SPECS_CUSTOM` snippets.
fn generate_cpp(dir_out: &str) -> String {
    let mut source = String::from(WRAPPER_PREAMBLE);
    // Defined before mdspan.hpp is included so the macro takes effect there;
    // presumably it enables `operator()` indexing (confirm in kokkos docs).
    push_line(&mut source, "#define MDSPAN_USE_PAREN_OPERATOR 1");
    push_line(
        &mut source,
        r#"#include "xsf/third_party/kokkos/mdspan.hpp""#,
    );
    push_line(&mut source, "");
    push_line(&mut source, &format!(r#"#include "{WRAPPER_NAME}.hpp""#));
    for xsf_header in WRAPPER_INCLUDES {
        push_line(&mut source, &format!(r#"#include "xsf/{xsf_header}""#));
    }
    push_line(&mut source, "");
    push_line(&mut source, &format!("namespace {WRAPPER_NAME} {{"));
    push_line(&mut source, "");
    let mut name_counts: HashMap<&str, usize> = HashMap::new();
    for (name, types) in WRAPPER_SPECS {
        let count = name_counts.entry(*name).or_insert(0);
        // Must stay in step with `generate_hpp` so definitions match declarations.
        let suffix = if *count == 0 {
            String::new()
        } else {
            count.to_string()
        };
        *count += 1;
        let decl = fmt_func(name, types, &suffix);
        let call = fmt_call(name, types);
        // Multi-output wrappers return void and fill their out-parameters.
        let stmt = if fmt_return(types) == "void" {
            call
        } else {
            format!("return {call}")
        };
        push_line(&mut source, &format!("{decl} {{ {stmt}; }}"));
    }
    for wrapper_extra in WRAPPER_SPECS_CUSTOM {
        push_line(&mut source, wrapper_extra.cpp);
    }
    push_line(&mut source, "");
    push_line(&mut source, "}");
    write_file(format!("{dir_out}/{WRAPPER_NAME}.cpp"), source)
}
/// Generates the wrapper `.cpp` in `dir_out` and compiles it with the `cc`
/// crate into a library named `WRAPPER_NAME`, using the xsf include directory.
///
/// MSVC-like compilers receive the language standard as a raw `/std:` flag;
/// all other compilers go through `cc::Build::std`.
fn build_wrapper(dir_out: &str) {
    let cpp_path = generate_cpp(dir_out);
    let mut cc_build = cc::Build::new();
    cc_build.cpp(true);
    cc_build.prefer_clang_cl_over_msvc(true);
    // Warning suppressions; `flag_if_supported` skips any the compiler rejects.
    for warning in ["-Wno-unused-parameter", "-Wno-logical-op-parentheses"] {
        cc_build.flag_if_supported(warning);
    }
    cc_build.include(XSF_INCLUDE).file(cpp_path);
    let is_msvc_like = cc_build.get_compiler().is_like_msvc();
    if is_msvc_like {
        cc_build.flag(format!("/std:{CXX_STANDARD}"));
    } else {
        cc_build.std(CXX_STANDARD);
    }
    cc_build.compile(WRAPPER_NAME);
}
/// Builds the bindgen function-allowlist regex for every generated wrapper.
///
/// The result has one `WRAPPER_NAME::<name>(_\d+)?` alternative per distinct
/// name in `WRAPPER_SPECS` (with any trailing `*` removed) and per
/// `WRAPPER_SPECS_CUSTOM` pattern, joined with `|`.
///
/// The optional suffix admits the `_1`, `_2`, ... overload names produced by
/// the generators; `\d+` also covers overloads numbered 10 and above.
/// Duplicates are removed wherever they occur, keeping first-seen order.
fn get_allowlist() -> String {
    let mut seen = std::collections::HashSet::new();
    WRAPPER_SPECS
        .iter()
        .map(|(name, _)| name.trim_end_matches('*'))
        .chain(WRAPPER_SPECS_CUSTOM.iter().map(|wx| wx.pattern))
        .filter(|name| seen.insert(*name))
        .map(|name| format!(r"{WRAPPER_NAME}::{name}(_\d+)?"))
        .collect::<Vec<_>>()
        .join("|")
}
/// Runs bindgen over the generated `header` and writes `bindings.rs` into
/// `dir_out`.
///
/// Only functions matching `get_allowlist()` are bound. The header is parsed
/// as C++, and C++ namespaces are kept as Rust modules.
///
/// # Panics
/// Panics, naming the header or output path, if binding generation or the
/// write fails.
fn generate_bindings(dir_out: &str, header: &str) {
    let out_path = PathBuf::from(dir_out).join("bindings.rs");
    bindgen::Builder::default()
        .header(header)
        .allowlist_function(get_allowlist())
        .clang_args(["-x", "c++"])
        .enable_cxx_namespaces()
        .dynamic_link_require_all(true)
        .size_t_is_usize(true)
        .sort_semantically(true)
        .derive_copy(false)
        .use_core()
        .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
        .generate()
        .unwrap_or_else(|e| panic!("bindgen failed on {header}: {e}"))
        .write_to_file(&out_path)
        .unwrap_or_else(|e| panic!("failed to write {}: {e}", out_path.display()));
}
/// Build-script entry point. It performs three steps:
///
/// 1. Generates the wrapper header.
/// 2. Compiles the wrapper source.
/// 3. Generates Rust bindings from the header.
fn main() {
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR is set by Cargo when running build scripts");
    // The header must exist before the wrapper .cpp (which #includes it) is compiled.
    let header = generate_hpp(&out_dir);
    println!("cargo:rerun-if-changed={XSF_INCLUDE}");
    build_wrapper(&out_dir);
    generate_bindings(&out_dir, &header);
}