diffsl 0.11.1

A compiler for a domain-specific language for ordinary differential equations (ODEs).
Documentation
use std::{collections::HashMap, marker::PhantomData};

use anyhow::Result;

use super::module::{CodegenModule, CodegenModuleJit};

type UIntType = u32;

/// Declares a private module named `$mod_name` that holds `extern "C"` bindings
/// to the model entry points. Every scalar pointer/value is typed as `$ty`, and
/// every integer argument is `UIntType` (`u32`).
///
/// Each binding carries an explicit `#[link_name]`, so it resolves to the
/// unmangled symbol of exactly that name at link time. The symbols must be
/// provided by whatever is linked into the final binary; this macro only
/// declares them.
///
/// Conventions readable from the signatures below:
/// - most functions end with `thread_id: UIntType, thread_dim: UIntType`
///   (presumably the calling thread's index and the total thread count —
///   TODO confirm against the code generator);
/// - pointers passed as `*mut` are presumably written by the callee, and
///   `*const` ones are only read (NOTE(review): Rust cannot enforce this
///   across the FFI boundary);
/// - derivative variants pair each primal pointer with a `d`-prefixed shadow
///   pointer (`u`/`du`, `data`/`ddata`, `rr`/`drr`, ...).
macro_rules! define_symbol_module {
    ($mod_name:ident, $ty:ty) => {
        mod $mod_name {
            use super::UIntType;

            // The f32 and f64 invocations of this macro in this file declare the
            // same link names with different argument types. Without this allow,
            // rustc would warn through the `clashing_extern_declarations` lint.
            // NOTE(review): with both `external_f32` and `external_f64` enabled,
            // the two modules bind to the *same* symbol for each name. Only the
            // declaration whose `$ty` matches the linked library is sound to
            // call, so confirm the features are meant to be mutually exclusive.
            #[allow(clashing_extern_declarations)]
            extern "C" {
                // --- Setup: no scalar arguments. ---
                #[link_name = "barrier_init"]
                pub fn barrier_init();
                #[link_name = "set_constants"]
                pub fn set_constants(thread_id: UIntType, thread_dim: UIntType);
                // Initial state: receives `u` and `data` as mutable pointers.
                #[link_name = "set_u0"]
                pub fn set_u0(
                    u: *mut $ty,
                    data: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- `reset` family ---
                // Shared suffix pattern for the reset / rhs / calc_out / calc_stop
                // families:
                //   `_grad`   : `du` is `*const`, output shadow is `*mut`
                //   `_rgrad`  : `du` is `*mut`, output shadow is `*mut`
                //   `_sgrad` / `_srgrad` : no `du` argument; only `data`/`ddata`
                // In all variants the primal output (`reset`, `rr`, `out`, `root`)
                // is `*const`, whereas the primal function takes it as `*mut`.
                // Presumably: forward-mode, reverse-mode, and data-only (sensitivity)
                // forward/reverse derivatives — TODO confirm.
                #[link_name = "reset"]
                pub fn reset(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    reset: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "reset_grad"]
                pub fn reset_grad(
                    time: $ty,
                    u: *const $ty,
                    du: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    reset: *const $ty,
                    dreset: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "reset_rgrad"]
                pub fn reset_rgrad(
                    time: $ty,
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    reset: *const $ty,
                    dreset: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "reset_sgrad"]
                pub fn reset_sgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    reset: *const $ty,
                    dreset: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "reset_srgrad"]
                pub fn reset_srgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    reset: *const $ty,
                    dreset: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- `rhs` family: primal output `rr`, shadow `drr`. ---
                #[link_name = "rhs"]
                pub fn rhs(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    rr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_grad"]
                pub fn rhs_grad(
                    time: $ty,
                    u: *const $ty,
                    du: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_rgrad"]
                pub fn rhs_rgrad(
                    time: $ty,
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_sgrad"]
                pub fn rhs_sgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "rhs_srgrad"]
                pub fn rhs_srgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    rr: *const $ty,
                    drr: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- `mass` family: input `v`, output `mv`. Only a reverse
                // (`_rgrad`) derivative variant is declared here. ---
                #[link_name = "mass"]
                pub fn mass(
                    time: $ty,
                    v: *const $ty,
                    data: *mut $ty,
                    mv: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "mass_rgrad"]
                pub fn mass_rgrad(
                    time: $ty,
                    v: *const $ty,
                    dv: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    mv: *const $ty,
                    dmv: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- Derivatives of `set_u0`. Unlike the families above, `u`
                // plays the output role here, so every variant takes `u` as
                // `*const` and `du` as `*mut`. ---
                #[link_name = "set_u0_grad"]
                pub fn set_u0_grad(
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "set_u0_rgrad"]
                pub fn set_u0_rgrad(
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "set_u0_sgrad"]
                pub fn set_u0_sgrad(
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- `calc_out` family: primal output `out`, shadow `dout`. ---
                #[link_name = "calc_out"]
                pub fn calc_out(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    out: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_grad"]
                pub fn calc_out_grad(
                    time: $ty,
                    u: *const $ty,
                    du: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_rgrad"]
                pub fn calc_out_rgrad(
                    time: $ty,
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_sgrad"]
                pub fn calc_out_sgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_out_srgrad"]
                pub fn calc_out_srgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    out: *const $ty,
                    dout: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- `calc_stop` family: primal output `root`, shadow `droot`. ---
                #[link_name = "calc_stop"]
                pub fn calc_stop(
                    time: $ty,
                    u: *const $ty,
                    data: *mut $ty,
                    root: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_stop_grad"]
                pub fn calc_stop_grad(
                    time: $ty,
                    u: *const $ty,
                    du: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    root: *const $ty,
                    droot: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_stop_rgrad"]
                pub fn calc_stop_rgrad(
                    time: $ty,
                    u: *const $ty,
                    du: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    root: *const $ty,
                    droot: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_stop_sgrad"]
                pub fn calc_stop_sgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    root: *const $ty,
                    droot: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                #[link_name = "calc_stop_srgrad"]
                pub fn calc_stop_srgrad(
                    time: $ty,
                    u: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    root: *const $ty,
                    droot: *mut $ty,
                    thread_id: UIntType,
                    thread_dim: UIntType,
                );
                // --- Metadata queries: no thread arguments. `set_id` and
                // `get_dims` take only mutable pointers, presumably filled in
                // by the callee. ---
                #[link_name = "set_id"]
                pub fn set_id(id: *mut $ty);
                #[link_name = "get_dims"]
                pub fn get_dims(
                    states: *mut UIntType,
                    inputs: *mut UIntType,
                    outputs: *mut UIntType,
                    data: *mut UIntType,
                    stop: *mut UIntType,
                    has_mass: *mut UIntType,
                    has_reset: *mut UIntType,
                );
                // --- Input handling, selected by `model_index` rather than by
                // thread arguments. `_grad` reads `dinputs` (`*const`);
                // `_rgrad` takes it as `*mut`. ---
                #[link_name = "set_inputs"]
                pub fn set_inputs(inputs: *const $ty, data: *mut $ty, model_index: UIntType);
                #[link_name = "get_inputs"]
                pub fn get_inputs(inputs: *mut $ty, data: *const $ty);
                #[link_name = "set_inputs_grad"]
                pub fn set_inputs_grad(
                    inputs: *const $ty,
                    dinputs: *const $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    model_index: UIntType,
                );
                #[link_name = "set_inputs_rgrad"]
                pub fn set_inputs_rgrad(
                    inputs: *const $ty,
                    dinputs: *mut $ty,
                    data: *const $ty,
                    ddata: *mut $ty,
                    model_index: UIntType,
                );
            }
        }
    };
}

// Emit one extern-binding module per enabled scalar type. Each module is only
// compiled when its matching cargo feature is on.
#[cfg(feature = "external_f64")]
define_symbol_module!(f64_symbols, f64);
#[cfg(feature = "external_f32")]
define_symbol_module!(f32_symbols, f32);

/// Code-generation module whose `jit` step compiles nothing. It returns
/// pointers to functions that are already linked in, as supplied by the
/// scalar type `T` through `ExternSymbols`.
///
/// The struct stores no value of type `T`; `T` only selects which set of
/// symbols is used.
pub struct ExternalModule<T> {
    _marker: PhantomData<T>,
}

impl<T> ExternalModule<T> {
    /// Creates a module handle. It carries no state beyond the type marker.
    pub fn new() -> Self {
        let _marker = PhantomData;
        Self { _marker }
    }
}

impl<T> Default for ExternalModule<T> {
    /// Same as [`ExternalModule::new`].
    fn default() -> Self {
        ExternalModule::<T>::new()
    }
}

/// Source of the externally linked symbol table for one scalar type.
pub trait ExternSymbols {
    /// Inserts this type's entries into `symbols`. Each entry maps a `String`
    /// key to a raw `*const u8` pointer. The keys are presumably symbol names;
    /// confirm in `insert_external_symbols!`.
    fn insert_symbols(symbols: &mut HashMap<String, *const u8>);
}

impl<T: Send + Sync + 'static> CodegenModule for ExternalModule<T> {}

impl<T> CodegenModuleJit for ExternalModule<T>
where
    T: ExternSymbols + Send + Sync + 'static,
{
    /// Performs no compilation. Builds a fresh map, lets `T` populate it via
    /// `ExternSymbols::insert_symbols`, and always returns it as `Ok`.
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let mut table = HashMap::new();
        <T as ExternSymbols>::insert_symbols(&mut table);
        Ok(table)
    }
}

/// Implements `ExternSymbols` for the scalar type `$scalar`. The generated
/// method forwards the map and the binding module path `$module` to
/// `insert_external_symbols!`.
macro_rules! impl_extern_symbols {
    ($scalar:ty, $module:path) => {
        impl ExternSymbols for $scalar {
            fn insert_symbols(table: &mut HashMap<String, *const u8>) {
                crate::execution::external_interface::insert_external_symbols!(
                    table, $module
                );
            }
        }
    };
}

// Connect each enabled scalar type to the binding module generated for it.
#[cfg(feature = "external_f32")]
impl_extern_symbols!(f32, f32_symbols);

#[cfg(feature = "external_f64")]
impl_extern_symbols!(f64, f64_symbols);