# bempp 0.2.0
#
# Boundary element method library.
"""
Generate a Rust file containing the quadrature rules
generated by polyquad.
"""

import os
import shutil
import subprocess
import sys

import numpy as np

def parse_rule_filename(rule_file):
    """Return ``(order, npoints)`` parsed from a rule file path.

    The file name (without directory and extension) must have the form
    ``<order>-<npoints>``, both parts being integers, e.g. ``5-7.txt``.

    Raises:
        ValueError: if the file name does not have that form. The message
            names the offending path so a stray ``.txt`` file is easy to spot.
    """
    stem = os.path.splitext(os.path.basename(rule_file))[0]
    parts = stem.split("-")
    if len(parts) == 2:
        try:
            return int(parts[0]), int(parts[1])
        except ValueError:
            pass
    raise ValueError(
        f"Rule file {rule_file!r} is not named '<order>-<npoints>.txt'."
    )


# Collect every .txt file below the current directory.
all_rule_files = []
for dirpath, _dirnames, filenames in os.walk("."):
    all_rule_files.extend(
        os.path.join(dirpath, name) for name in filenames if name.endswith(".txt")
    )

# os.walk yields entries in arbitrary order; sort by directory, then order,
# then number of points so the generated file is the same on every run.
all_rule_files.sort(
    key=lambda path: (os.path.dirname(path), parse_rule_filename(path))
)

# orders[i] and npoints[i] describe all_rule_files[i].
orders = []
npoints = []
for rule_file in all_rule_files:
    rule_order, rule_npoints = parse_rule_filename(rule_file)
    orders.append(rule_order)
    npoints.append(rule_npoints)

def _shift_to_unit(points):
    """Map every coordinate ``x`` to ``0.5 * (1.0 + x)``."""
    return 0.5 * (1.0 + points)


def _pyramid_to_reference(points):
    """Shift the points by 1 and apply the fixed 3x3 pyramid matrix.

    For a point ``(x, y, z)`` this yields
    ``(0.5 * (1 + x) - 0.25 * (1 + z), 0.5 * (1 + y) - 0.25 * (1 + z), 0.5 * (1 + z))``.
    """
    return (1.0 + points) @ np.array(
        [[0.5, 0, 0], [0, 0.5, 0], [-0.25, -0.25, 0.5]], dtype="float64"
    )


# One entry per polyquad rule directory:
# (path prefix, Rust cell type, point transform, weight divisor).
# The transform and divisor reparameterize from the polyquad reference
# element to our reference element.
# NOTE(review): the divisors look like the volume ratio between the two
# reference elements (4 in 2D, 8 in 3D) -- confirm against polyquad.
_CELL_RULES = (
    ("./quad", "ReferenceCellType::Quadrilateral", _shift_to_unit, 4.0),
    ("./tri", "ReferenceCellType::Triangle", _shift_to_unit, 4.0),
    ("./hex", "ReferenceCellType::Hexahedron", _shift_to_unit, 8.0),
    ("./pri", "ReferenceCellType::Prism", _shift_to_unit, 8.0),
    ("./tet", "ReferenceCellType::Tetrahedron", _shift_to_unit, 8.0),
    ("./pyr", "ReferenceCellType::Pyramid", _pyramid_to_reference, 8.0),
)


def _lookup_cell_rule(rule_file):
    """Return ``(identifier, transform, divisor)`` for a rule file path.

    Raises:
        ValueError: if the path matches none of the known prefixes.
    """
    # os.path.join uses the platform separator; compare on "/" everywhere.
    normalized = rule_file.replace(os.sep, "/")
    for prefix, identifier, transform, divisor in _CELL_RULES:
        if normalized.startswith(prefix):
            return identifier, transform, divisor
    raise ValueError("Unknown simplex type.")


def _write_rule(out, identifier, rule_npoints, rule_order, points, weights):
    """Write one ``insert`` statement for a rule to ``out``.

    The entry is keyed by ``rule_npoints`` and stores
    ``(rule_order, points, weights)``; ``points`` and ``weights`` are flat
    sequences of numbers.
    """
    out.write("m.get_mut(&" + identifier + ").unwrap().insert(\n")
    out.write(str(rule_npoints) + ", \n")
    out.write("(" + str(rule_order) + ",vec![")
    out.write("".join(f"{point}," for point in points))
    out.write("],\n")
    out.write("vec![\n")
    out.write("".join(f"{weight}," for weight in weights))
    out.write("]));\n")


with open("simplex_rule_definitions.rs", "w") as f:
    # Fixed preamble of the generated Rust module.
    f.writelines(
        (
            "//! Definition of simplex rules.\n",
            "\n",
            "use std::collections::HashMap;\n",
            "use bempp_traits::cell::ReferenceCellType;\n",
            "\n",
            "type HM = HashMap<usize, (usize, Vec<f64>, Vec<f64>)>;\n",
            "\n",
            "lazy_static! {\n",
            "pub(crate) static ref SIMPLEX_RULE_DEFINITIONS: HashMap<ReferenceCellType, HM> = {\n",
            "let mut m = HashMap::<ReferenceCellType, HM>::new();\n",
        )
    )
    # One empty inner map per cell type.
    for cell_name in (
        "Triangle",
        "Quadrilateral",
        "Hexahedron",
        "Tetrahedron",
        "Prism",
        "Pyramid",
        "Interval",
    ):
        f.write(f"m.insert(ReferenceCellType::{cell_name}, HM::new());\n")

    # Rules read from the polyquad files.
    for rule_file, rule_order, rule_npoints in zip(all_rule_files, orders, npoints):
        # Each row holds the point coordinates followed by the weight;
        # atleast_2d keeps a single-row file two-dimensional.
        arr = np.atleast_2d(np.loadtxt(rule_file))
        identifier, transform, divisor = _lookup_cell_rule(rule_file)

        points = transform(arr[:, :-1])
        weights = arr[:, -1] / divisor

        # Transposing before flattening writes all first coordinates,
        # then all second coordinates, and so on.
        _write_rule(
            f,
            identifier,
            rule_npoints,
            rule_order,
            points.T.flatten(),
            weights.flatten(),
        )

    # Now add the standard Gauss Legendre rules

    nmax = 100
    identifier = "ReferenceCellType::Interval"

    for n in range(1, nmax + 1):
        p, w = np.polynomial.legendre.leggauss(n)
        ascending = np.argsort(p)
        # Map the nodes with x -> 0.5 * (1 + x) and halve the weights.
        points = 0.5 * (1.0 + p[ascending])
        weights = 0.5 * w[ascending]

        # The order written for an n-point rule is 2 * n - 1.
        _write_rule(f, identifier, len(w), 2 * len(w) - 1, points, weights)

    f.write("m };\n}")

# Format the generated file, copy it into the parent directory and remove
# the local copy.
generated_name = "simplex_rule_definitions.rs"

# Formatting is best effort: a missing or failing rustfmt is reported but
# does not stop the file from being copied.
try:
    formatted = subprocess.run(["rustfmt", generated_name], check=False).returncode == 0
except OSError:
    formatted = False
if not formatted:
    print(
        f"warning: rustfmt did not succeed; {generated_name} is copied unformatted",
        file=sys.stderr,
    )

# shutil.copy raises on failure, so the local file is only removed once the
# copy in the parent directory exists.
shutil.copy(generated_name, os.path.join(os.pardir, generated_name))
os.remove(generated_name)