diffsl 0.11.9

A compiler for a domain-specific language for ordinary differential equations (ODE).
Documentation
use std::{
    path::PathBuf,
    process::Command,
};

#[allow(unused_imports)]
use diffsl::execution::compiler::CompilerMode;
#[allow(unused_imports)]
use diffsl::{Compiler, ExternalDynModule};

/// Builds the fixture crate at `tests/fixtures/<fixture_name>` with Cargo and returns the path
/// to the dynamic library it produces.
///
/// The fixture is compiled into its own target directory,
/// `target/external_dynamic_fixtures/<fixture_name>`, which keeps its artifacts separate from
/// the main crate's. The Cargo binary is taken from the `CARGO` environment variable when it is
/// set, otherwise `cargo` is looked up on `PATH`.
///
/// # Panics
///
/// Panics if Cargo cannot be launched, if the build exits unsuccessfully, or if the expected
/// library file does not exist after the build.
pub fn build_fixture_library(fixture_name: &str) -> PathBuf {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let manifest_path = manifest_dir
        .join("tests")
        .join("fixtures")
        .join(fixture_name)
        .join("Cargo.toml");
    let target_dir = manifest_dir
        .join("target")
        .join("external_dynamic_fixtures")
        .join(fixture_name);

    let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_string());
    let status = Command::new(&cargo)
        .arg("build")
        .arg("--manifest-path")
        .arg(&manifest_path)
        .arg("--quiet")
        .env("CARGO_TARGET_DIR", &target_dir)
        .status()
        // A spawn error means Cargo never ran; report that instead of claiming the build failed.
        .unwrap_or_else(|err| {
            panic!("failed to launch `{cargo}` to build fixture `{fixture_name}`: {err}")
        });
    assert!(
        status.success(),
        "fixture crate build failed for {} ({status})",
        manifest_path.display()
    );

    // No `--release` flag is passed above, so Cargo writes to the `debug` profile directory.
    let library_path = target_dir.join("debug").join(dynamic_library_name(fixture_name));
    assert!(
        library_path.exists(),
        "fixture library was not produced at {}",
        library_path.display()
    );
    library_path
}

/// Returns the platform-specific file name of the dynamic library Cargo builds for `crate_name`.
///
/// Examples: `libfoo.so` on Linux, `libfoo.dylib` on macOS, `foo.dll` on Windows.
///
/// Cargo derives a library target's default name from the package name by replacing `-` with
/// `_`, so a fixture crate called `my-fixture` produces `libmy_fixture.so`. The same
/// substitution is applied here. The prefix and extension come from
/// `std::env::consts::{DLL_PREFIX, DLL_SUFFIX}`, which also covers targets other than Windows
/// and macOS.
fn dynamic_library_name(crate_name: &str) -> String {
    use std::env::consts::{DLL_PREFIX, DLL_SUFFIX};

    let lib_name = crate_name.replace('-', "_");
    format!("{DLL_PREFIX}{lib_name}{DLL_SUFFIX}")
}

/// Defines a `#[test]` function that drives every entry point of a `Compiler` built from a
/// precompiled external dynamic library and checks each result against hard-coded values.
///
/// * `$ty` — scalar type used to instantiate `ExternalDynModule::<$ty>`. It must accept `as`
///   casts from float literals and provide `.abs()`, so presumably `f32` or `f64`.
/// * `$test_name` — name of the generated test function.
/// * `$fixture_name` — fixture crate under `tests/fixtures/`, built via `build_fixture_library`.
///
/// The generated test asserts the following, in order:
///
/// * A fixed model shape: one state, input, output, data slot and stop entry, with both a mass
///   function and a reset present.
/// * The single exported tensor (`in`) and the single exported constant (`rate` = 3.0).
/// * Specific results for the forward calls (`set_u0`, `rhs`, `calc_out`, `calc_stop`, `reset`,
///   `mass`, `set_id`).
/// * Specific results for their `*_grad`, `*_rgrad`, `*_sgrad` and `*_srgrad` variants.
///
/// Those numbers are specific to the fixture library, so they must be kept in sync with it.
///
/// Paths in the expansion (`build_fixture_library`, `ExternalDynModule`, `Compiler`,
/// `CompilerMode`) are resolved where the macro is invoked, so they must be in scope there.
// The lint is allowed presumably because not every compilation unit that includes this file
// invokes the macro.
#[allow(unused_macros)]
macro_rules! define_external_dynamic_test {
    ($ty:ty, $test_name:ident, $fixture_name:literal) => {
        #[test]
        fn $test_name() {
            // Expected dimensions reported by the fixture's `get_dims`.
            const STATES: usize = 1;
            const INPUTS: usize = 1;
            const OUTPUTS: usize = 1;
            const DATA: usize = 1;
            const STOP: usize = 1;

            let fixture_path = build_fixture_library($fixture_name);
            let module = ExternalDynModule::<$ty>::new(&fixture_path)
                .expect("dynamic module should load fixture library");
            let compiler = Compiler::from_codegen_module(module, CompilerMode::SingleThreaded)
                .expect("compiler should build");

            let (n_states, n_inputs, n_outputs, n_data, n_stop, has_mass, has_reset) =
                compiler.get_dims();
            assert_eq!(n_states, STATES);
            assert_eq!(n_inputs, INPUTS);
            assert_eq!(n_outputs, OUTPUTS);
            assert_eq!(n_data, DATA);
            assert_eq!(n_stop, STOP);
            assert!(has_mass);
            assert!(has_reset);

            // Output buffers throughout are pre-filled with sentinel values that differ from the
            // expected results, so a call that leaves its output untouched fails its assertion.
            // `data` starts at -1.0; the `in` tensor check below shows `set_inputs` wrote into it.
            let mut data = vec![-1.0 as $ty; n_data];
            let inputs = vec![1.0 as $ty; n_inputs];
            compiler.set_inputs(&inputs, &mut data, 0);

            // Exported names and values.
            let mut tensors = compiler.get_tensors();
            tensors.sort();
            assert_eq!(tensors, vec!["in".to_string()]);
            assert_eq!(
                compiler.get_tensor_data("in", &data).expect("tensor export should resolve"),
                inputs.as_slice()
            );

            let mut constants = compiler.get_constants();
            constants.sort();
            assert_eq!(constants, vec!["rate".to_string()]);
            assert_eq!(
                compiler
                    .get_constants_data("rate")
                    .expect("constant export should resolve"),
                [3.0 as $ty].as_slice()
            );

            let mut inputs_out = vec![-2.0 as $ty; n_inputs];
            compiler.get_inputs(&mut inputs_out, &data);
            assert_eq!(inputs_out, inputs);

            // Forward evaluation of initial state, outputs and right-hand side.
            let mut u = vec![-2.0 as $ty; n_states];
            compiler.set_u0(&mut u, &mut data);
            assert_eq!(u[0], 1.0 as $ty);

            let mut out = vec![-3.0 as $ty; n_outputs];
            compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out);
            assert_eq!(out[0], u[0]);

            let mut rr = vec![-4.0 as $ty; n_states];
            compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr);
            assert_eq!(rr[0], 0.0 as $ty);

            // Stop function and its derivative variants.
            let mut stop = vec![-5.0 as $ty; n_stop];
            compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop);
            assert_eq!(stop[0], 0.5 as $ty);

            let du_stop = vec![1.0 as $ty; n_states];
            let mut ddata_stop = vec![-5.15 as $ty; n_data];
            let mut dstop = vec![-5.25 as $ty; n_stop];
            compiler.calc_stop_grad(
                0.0 as $ty,
                &u,
                &du_stop,
                &data,
                &mut ddata_stop,
                &stop,
                &mut dstop,
            );
            assert_eq!(dstop[0], 1.0 as $ty);

            let mut du_stop_rev = vec![-5.35 as $ty; n_states];
            let mut ddata_stop_rev = vec![-5.45 as $ty; n_data];
            let mut dstop_rev = vec![1.0 as $ty; n_stop];
            compiler.calc_stop_rgrad(
                0.0 as $ty,
                &u,
                &mut du_stop_rev,
                &data,
                &mut ddata_stop_rev,
                &stop,
                &mut dstop_rev,
            );
            // Expected -4.35 equals the -5.35 seed plus 1.0, i.e. consistent with the reverse
            // call accumulating into `du_stop_rev`. Compared with a tolerance, likely because
            // that sum is not exactly representable in binary floating point.
            assert!((du_stop_rev[0] - (-4.35 as $ty)).abs() < (1e-6 as $ty));

            let mut ddata_stop_s = vec![-5.55 as $ty; n_data];
            let mut dstop_s = vec![-5.65 as $ty; n_stop];
            compiler.calc_stop_sgrad(
                0.0 as $ty,
                &u,
                &data,
                &mut ddata_stop_s,
                &stop,
                &mut dstop_s,
            );
            assert_eq!(dstop_s[0], 0.0 as $ty);

            let mut ddata_stop_sr = vec![-5.75 as $ty; n_data];
            let mut dstop_sr = vec![1.0 as $ty; n_stop];
            compiler.calc_stop_srgrad(
                0.0 as $ty,
                &u,
                &data,
                &mut ddata_stop_sr,
                &stop,
                &mut dstop_sr,
            );
            assert_eq!(dstop_sr[0], 0.0 as $ty);

            // Reset function and its derivative variants.
            let mut reset = vec![-5.5 as $ty; n_states];
            compiler.reset(0.0 as $ty, &u, &mut data, &mut reset);
            assert_eq!(reset[0], 2.0 as $ty);

            // `du` and `ddata` are reused by `rhs_grad`, `calc_out_grad` and `set_inputs_grad`
            // further down, so the order of those calls matters for the expected values.
            let du = vec![1.0 as $ty; n_states];
            let mut ddata = vec![-8.0 as $ty; n_data];
            let mut dreset = vec![-5.75 as $ty; n_states];
            compiler.reset_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &reset, &mut dreset);
            assert_eq!(dreset[0], 2.0 as $ty);

            let mut du_reset_rev = vec![-5.85 as $ty; n_states];
            let mut ddata_reset_rev = vec![-5.95 as $ty; n_data];
            let mut dreset_rev = vec![1.0 as $ty; n_states];
            compiler.reset_rgrad(
                0.0 as $ty,
                &u,
                &mut du_reset_rev,
                &data,
                &mut ddata_reset_rev,
                &reset,
                &mut dreset_rev,
            );
            // Expected -3.85 equals the -5.85 seed plus 2.0; tolerance as for `du_stop_rev`.
            assert!((du_reset_rev[0] - (-3.85 as $ty)).abs() < (1e-6 as $ty));

            let mut ddata_reset_s = vec![-6.05 as $ty; n_data];
            let mut dreset_s = vec![-6.15 as $ty; n_states];
            compiler.reset_sgrad(
                0.0 as $ty,
                &u,
                &data,
                &mut ddata_reset_s,
                &reset,
                &mut dreset_s,
            );
            assert_eq!(dreset_s[0], 0.0 as $ty);

            let mut ddata_reset_sr = vec![-6.25 as $ty; n_data];
            let mut dreset_sr = vec![1.0 as $ty; n_states];
            compiler.reset_srgrad(
                0.0 as $ty,
                &u,
                &data,
                &mut ddata_reset_sr,
                &reset,
                &mut dreset_sr,
            );
            assert_eq!(dreset_sr[0], 0.0 as $ty);

            // Mass action and the per-state id vector.
            let mut mv = vec![-6.0 as $ty; n_states];
            compiler.mass(0.0 as $ty, &u, &mut data, &mut mv);
            assert_eq!(mv[0], 1.0 as $ty);

            let mut id = vec![-7.0 as $ty; n_states];
            compiler.set_id(&mut id);
            assert_eq!(id[0], 42.0 as $ty);

            // `*_grad` calls for rhs, outputs and inputs, all sharing `ddata`.
            let mut drr = vec![-9.0 as $ty; n_states];
            compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr);
            assert_eq!(drr[0], -1.0 as $ty);
            assert_eq!(ddata[0], 0.0 as $ty);

            let mut dout = vec![-10.0 as $ty; n_outputs];
            compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout);
            assert_eq!(dout[0], 1.0 as $ty);
            assert_eq!(ddata[0], 0.0 as $ty);

            let mut dinputs = vec![1.0 as $ty; n_inputs];
            compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata, 0);
            assert_eq!(ddata[0], 1.0 as $ty);

            // `*_rgrad` calls sharing `du_rev` and `ddata_rev`. The expected values depend on
            // call order: `du_rev` goes -11 -> -12 after `rhs_rgrad` and back to -11 after
            // `calc_out_rgrad`. `dinputs` (seeded with 1.0 above) is then updated in place by
            // `set_inputs_rgrad`.
            let mut du_rev = vec![-11.0 as $ty; n_states];
            let mut ddata_rev = vec![-12.0 as $ty; n_data];
            let mut drr_rev = vec![1.0 as $ty; n_states];
            compiler.rhs_rgrad(
                0.0 as $ty,
                &u,
                &mut du_rev,
                &data,
                &mut ddata_rev,
                &rr,
                &mut drr_rev,
            );
            assert_eq!(du_rev[0], -12.0 as $ty);
            assert_eq!(ddata_rev[0], -12.0 as $ty);

            let mut dv = vec![-13.0 as $ty; n_states];
            let mut dmv = vec![1.0 as $ty; n_states];
            compiler.mass_rgrad(0.0 as $ty, &mut dv, &data, &mut ddata_rev, &mut dmv);
            assert_eq!(dv[0], -12.0 as $ty);

            let mut dout_rev = vec![1.0 as $ty; n_outputs];
            compiler.calc_out_rgrad(
                0.0 as $ty,
                &u,
                &mut du_rev,
                &data,
                &mut ddata_rev,
                &out,
                &mut dout_rev,
            );
            assert_eq!(du_rev[0], -11.0 as $ty);

            compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev, 0);
            assert_eq!(dinputs[0], -11.0 as $ty);

            // `*_sgrad` calls for rhs and outputs.
            let mut ddata_s = vec![-14.0 as $ty; n_data];
            let mut drr_s = vec![-15.0 as $ty; n_states];
            compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s);
            assert_eq!(drr_s[0], 0.0 as $ty);
            assert_eq!(ddata_s[0], 0.0 as $ty);

            let mut dout_s = vec![-16.0 as $ty; n_outputs];
            compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s);
            assert_eq!(dout_s[0], 0.0 as $ty);

            // `*_srgrad` calls for rhs and outputs.
            let mut ddata_sr = vec![-17.0 as $ty; n_data];
            let mut drr_sr = vec![-18.0 as $ty; n_states];
            compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr);
            assert_eq!(drr_sr[0], 0.0 as $ty);
            assert_eq!(ddata_sr[0], 0.0 as $ty);

            let mut dout_sr = vec![-19.0 as $ty; n_outputs];
            compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr);
            assert_eq!(dout_sr[0], 0.0 as $ty);
        }
    };
}