// diffsl 0.9.4
//
// A compiler for a domain-specific language for ordinary differential equations (ODEs).
// See the crate documentation for details.
use std::{collections::HashMap, marker::PhantomData};

use anyhow::Result;

use super::module::{CodegenModule, CodegenModuleJit};

/// Unsigned integer type used across the external ABI for thread ids, thread counts
/// and dimensions (see `set_constants` and `get_dims` in `define_symbol_module!`).
type UIntType = u32;

// Declares, inside a new module named `$mod_name`, the `extern "C"` entry points of a
// model whose functions are provided by already-linked native code rather than
// generated at runtime. `$ty` is the scalar element type behind every data pointer.
//
// NOTE(review): the symbols are resolved by the linker; presumably they come from an
// object file emitted ahead of time by the diffsl compiler for one model — confirm.
//
// Conventions visible in the signatures:
// * `thread_id` / `thread_dim` are passed to most entry points; presumably they split
//   the work across `thread_dim` cooperating threads — TODO confirm.
// * `d`-prefixed pointers (`du`, `ddata`, `drr`, `dout`, ...) appear to be derivative
//   buffers paired with the primal buffer of the same name.
// * `_grad` variants read input derivatives through `*const` pointers (e.g. `du`) and
//   write output derivatives through `*mut` ones — looks like forward-mode AD.
// * `_rgrad` variants take the input derivative buffers as `*mut` (e.g. `du`) — looks
//   like reverse-mode AD writing adjoints back to the inputs.
// * `_sgrad` / `_srgrad` variants have no `du` parameter and differentiate only
//   through `data`.
macro_rules! define_symbol_module {
    ($mod_name:ident, $ty:ty) => {
        mod $mod_name {
            // `super` is the module that invokes this macro, where `UIntType` lives.
            use super::UIntType;

            // When both `external_f32` and `external_f64` are enabled, the same link
            // names are declared twice (once per invocation) with different scalar
            // types; this allow silences the resulting lint.
            #[allow(clashing_extern_declarations)]
            extern "C" {
                // Setup entry points.
                #[link_name = "barrier_init"]
                pub fn barrier_init();
                #[link_name = "set_constants"]
                pub fn set_constants(thread_id: UIntType, thread_dim: UIntType);
                // Initial state: fills `u` (`data` is also writable here).
                #[link_name = "set_u0"]
                pub fn set_u0(
                    u: *mut $ty,
                    data: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // Right-hand side: reads `u` at `time` and writes `rr`. The four
                // declarations after it are its derivative variants.
                #[link_name = "rhs"]
                pub fn rhs(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    rr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_grad"]
                pub fn rhs_grad(
                    time: $ty,
                    u: *const $ty,
                    du: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_rgrad"]
                pub fn rhs_rgrad(
                    time: $ty,
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_sgrad"]
                pub fn rhs_sgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_srgrad"]
                pub fn rhs_srgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // Mass term: reads `v`, writes `mv` (presumably the mass-matrix
                // product — confirm). Followed by its reverse-mode variant.
                #[link_name = "mass"]
                pub fn mass(
                    time: $ty,
                    v: *const $ty,
                    data: *mut $ty,
                    mv: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "mass_rgrad"]
                pub fn mass_rgrad(
                    time: $ty,
                    v: *const $ty,
                    dv: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    mv: *const $ty,
                    dmv: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // Derivative variants of `set_u0`; here `u` is already computed and
                // read-only, while `du` receives the derivative.
                #[link_name = "set_u0_grad"]
                pub fn set_u0_grad(
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "set_u0_rgrad"]
                pub fn set_u0_rgrad(
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "set_u0_sgrad"]
                pub fn set_u0_sgrad(
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // Outputs: reads `u` at `time` and writes `out`. The four declarations
                // after it are its derivative variants.
                #[link_name = "calc_out"]
                pub fn calc_out(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    out: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_grad"]
                pub fn calc_out_grad(
                    time: $ty,
                    u: *const $ty,
                    du: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_rgrad"]
                pub fn calc_out_rgrad(
                    time: $ty,
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_sgrad"]
                pub fn calc_out_sgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_srgrad"]
                pub fn calc_out_srgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // Stop condition: reads `u` at `time` and writes `root`
                // (presumably root-function values for event detection — confirm).
                #[link_name = "calc_stop"]
                pub fn calc_stop(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    root: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // Fills `id`. NOTE(review): presumably flags each state as
                // differential or algebraic — confirm against the code generator.
                #[link_name = "set_id"]
                pub fn set_id(id: *mut $ty);
                // Reports the model's sizes through six out-pointers; `has_mass` is
                // presumably used as a boolean flag.
                #[link_name = "get_dims"]
                pub fn get_dims(
                    states: *mut UIntType,
                    inputs: *mut UIntType,
                    outputs: *mut UIntType,
                    data: *mut UIntType,
                    stop: *mut UIntType,
                    has_mass: *mut UIntType,
                );
                // Move input values between `inputs` and the model's `data` buffer,
                // followed by the derivative variants of `set_inputs`.
                #[link_name = "set_inputs"]
                pub fn set_inputs(inputs: *const $ty, data: *mut $ty);
                #[link_name = "get_inputs"]
                pub fn get_inputs(inputs: *mut $ty, data: *const $ty);
                #[link_name = "set_inputs_grad"]
                pub fn set_inputs_grad(
                    inputs: *const $ty,
                    dinputs: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                );
                #[link_name = "set_inputs_rgrad"]
                pub fn set_inputs_rgrad(
                    inputs: *const $ty,
                    dinputs: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                );
            }
        }
    };
}

// One symbol module per enabled precision feature.
#[cfg(feature = "external_f32")]
define_symbol_module!(f32_symbols, f32);
#[cfg(feature = "external_f64")]
define_symbol_module!(f64_symbols, f64);

/// A codegen module whose functions are not generated at runtime but taken from
/// symbols already linked into the binary.
///
/// The struct stores no data; the type parameter `T` only selects which symbol
/// table is exposed.
pub struct ExternalModule<T> {
    _marker: PhantomData<T>,
}

impl<T> ExternalModule<T> {
    /// Creates the (zero-sized) module handle; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}

impl<T> Default for ExternalModule<T> {
    /// Builds the handle directly. This is implemented by hand rather than derived so
    /// that `T` is not required to implement `Default`.
    fn default() -> Self {
        ExternalModule {
            _marker: PhantomData,
        }
    }
}

// `ExternalModule` carries no state, and this impl overrides no trait items.
impl<T: Send + Sync + 'static> CodegenModule for ExternalModule<T> {}

/// Supplies the externally linked entry points for one scalar type.
///
/// An implementation inserts one entry per symbol into the map, keyed by the
/// symbol's name and holding the function's address as an untyped pointer.
pub trait ExternSymbols {
    /// Adds this type's `name -> address` pairs to `symbols`.
    fn insert_symbols(symbols: &mut std::collections::HashMap<String, *const u8>);
}

/// "JIT" step for an external module: nothing is compiled. The symbols that
/// `T` registers through [`ExternSymbols`] are collected into a fresh
/// `name -> address` map, and the call always returns `Ok`.
impl<T> CodegenModuleJit for ExternalModule<T>
where
    T: ExternSymbols + Send + Sync + 'static,
{
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let mut table: HashMap<String, *const u8> = HashMap::new();
        <T as ExternSymbols>::insert_symbols(&mut table);
        Ok(table)
    }
}

// Implements `ExternSymbols` for the scalar type `$ty`, registering functions from
// the symbol module `$sym` (a module generated by `define_symbol_module!`).
//
// Two forms are accepted:
// * `impl_extern_symbols!(ty, module, { "name" => function, ... })` registers exactly
//   the listed pairs.
// * `impl_extern_symbols!(ty, module)` registers the full standard table written in
//   the second arm. The table lives in one place, so every precision exposes the same
//   set of entry points. It must still be kept in sync with the declarations in
//   `define_symbol_module!`.
macro_rules! impl_extern_symbols {
    ($ty:ty, $sym:path, { $($name:literal => $func:ident),+ $(,)? }) => {
        impl ExternSymbols for $ty {
            fn insert_symbols(symbols: &mut HashMap<String, *const u8>) {
                use $sym as sym;
                // Casting a function item to `*const u8` yields its address.
                $(symbols.insert($name.to_string(), sym::$func as *const u8);)+
            }
        }
    };
    ($ty:ty, $sym:path) => {
        impl_extern_symbols!($ty, $sym, {
            "barrier_init" => barrier_init,
            "set_constants" => set_constants,
            "set_u0" => set_u0,
            "rhs" => rhs,
            "rhs_grad" => rhs_grad,
            "rhs_rgrad" => rhs_rgrad,
            "rhs_sgrad" => rhs_sgrad,
            "rhs_srgrad" => rhs_srgrad,
            "mass" => mass,
            "mass_rgrad" => mass_rgrad,
            "set_u0_grad" => set_u0_grad,
            "set_u0_rgrad" => set_u0_rgrad,
            "set_u0_sgrad" => set_u0_sgrad,
            "calc_out" => calc_out,
            "calc_out_grad" => calc_out_grad,
            "calc_out_rgrad" => calc_out_rgrad,
            "calc_out_sgrad" => calc_out_sgrad,
            "calc_out_srgrad" => calc_out_srgrad,
            "calc_stop" => calc_stop,
            "set_id" => set_id,
            "get_dims" => get_dims,
            "set_inputs" => set_inputs,
            "get_inputs" => get_inputs,
            "set_inputs_grad" => set_inputs_grad,
            "set_inputs_rgrad" => set_inputs_rgrad,
        });
    };
}

#[cfg(feature = "external_f64")]
impl_extern_symbols!(f64, f64_symbols);

#[cfg(feature = "external_f32")]
impl_extern_symbols!(f32, f32_symbols);