use core::panic;
use std::{cmp::min, collections::HashMap, num::NonZero};
use super::{
interface::{
JitFunctions, JitGetTensorFunctions, JitGradFunctions, JitGradRFunctions,
JitSensGradFunctions, JitSensRevGradFunctions,
},
module::{CodegenModule, CodegenModuleCompile, CodegenModuleJit, CodegenModuleLink},
};
use crate::{discretise::DiscreteModel, execution::scalar::Scalar, parser::parse_ds_string};
use anyhow::{anyhow, Result};
#[cfg(feature = "rayon")]
use rayon::{ThreadPool, ThreadPoolBuilder};
use uid::Id;
/// A JIT-compiled model together with the metadata needed to call its
/// generated functions safely from Rust.
pub struct Compiler<M: CodegenModule, T: Scalar> {
    // Required function-pointer tables resolved from the JIT symbol map.
    jit_functions: JitFunctions<T>,
    jit_grad_functions: JitGradFunctions<T>,
    // Optional tables: `None` when the module was compiled without the
    // corresponding autodiff support (see `Compiler::new`).
    jit_grad_r_functions: Option<JitGradRFunctions<T>>,
    jit_sens_grad_functions: Option<JitSensGradFunctions<T>>,
    jit_sens_rev_grad_functions: Option<JitSensRevGradFunctions<T>>,
    jit_get_tensor_functions: JitGetTensorFunctions<T>,
    // Model dimensions cached from the generated `get_dims` function in `new`.
    number_of_states: usize,
    number_of_parameters: usize,
    number_of_outputs: usize,
    number_of_stop: usize,
    data_size: usize,
    has_mass: bool,
    has_reset: bool,
    // Thread pool used by `with_threading`; only built when the requested
    // thread count is > 1 (requires the "rayon" feature).
    #[cfg(feature = "rayon")]
    thread_pool: Option<ThreadPool>,
    #[cfg(not(feature = "rayon"))]
    thread_pool: Option<()>,
    // Serializes multi-threaded calls: the generated code uses a single
    // shared barrier (see `with_threading`).
    thread_lock: Option<std::sync::Mutex<()>>,
    // Holds on to the module that produced the JITed symbols.
    _module: M,
}
/// Threading mode used when executing the JIT-compiled model functions.
#[derive(Default, Clone, Copy)]
pub enum CompilerMode {
    /// Use a thread pool with the given number of threads, or an
    /// automatically determined count when `None`.
    MultiThreaded(Option<usize>),
    /// Run everything on the calling thread (the default).
    #[default]
    SingleThreaded,
}
impl CompilerMode {
    /// number of threads to use
    /// prefer the number of threads specified by the user (RAYON_NUM_THREADS)
    /// if not specified, use the number of available threads
    /// don't use more threads than the number of states
    pub fn thread_dim(&self, number_of_states: usize) -> usize {
        match self {
            // Explicit count: never exceed the number of states.
            CompilerMode::MultiThreaded(Some(n)) => min(*n, number_of_states),
            CompilerMode::MultiThreaded(None) => {
                let num_cpus = std::thread::available_parallelism()
                    .unwrap_or(NonZero::new(1).unwrap())
                    .get();
                // Fix: a malformed RAYON_NUM_THREADS value used to panic via
                // `.unwrap()`; fall back to the detected CPU count instead.
                let thread_dim = std::env::var("RAYON_NUM_THREADS")
                    .ok()
                    .and_then(|v| v.parse::<usize>().ok())
                    .unwrap_or(num_cpus)
                    .max(1);
                // Heuristic: at most one thread per 10 states.
                let max_threads = (number_of_states / 10).max(1);
                min(thread_dim, max_threads)
            }
            _ => 1,
        }
    }
}
/// Options controlling how a model is compiled.
#[derive(Default)]
pub struct CompilerOptions {
    /// Threading mode (defaults to single-threaded).
    pub mode: CompilerMode,
    /// Debug flag forwarded to the codegen backend; exact effect depends on
    /// the backend -- confirm against the `CodegenModuleCompile` impls.
    pub debug: bool,
}
impl<M: CodegenModule, T: Scalar> Compiler<M, T> {
    /// Builds a `Compiler` from a module and the symbol map produced by
    /// JIT-compiling it.
    ///
    /// The reverse/sensitivity function tables are optional: lookup failures
    /// become `None`, so modules compiled without those features still load.
    ///
    /// # Errors
    ///
    /// Returns an error if a required symbol is missing, if the thread pool
    /// cannot be built, or if threading is requested without the "rayon"
    /// feature.
    pub fn new(
        module: M,
        symbol_map: HashMap<String, *const u8>,
        mode: CompilerMode,
    ) -> Result<Self> {
        let jit_functions = JitFunctions::<T>::new(&symbol_map)?;
        let jit_grad_functions = JitGradFunctions::<T>::new(&symbol_map)?;
        let jit_get_tensor_functions = JitGetTensorFunctions::<T>::new(&symbol_map)?;
        // Optional capability tables: absence means "unsupported", not error.
        let jit_grad_r_functions = JitGradRFunctions::<T>::new(&symbol_map).ok();
        let jit_sens_grad_functions = JitSensGradFunctions::<T>::new(&symbol_map).ok();
        let jit_sens_rev_grad_functions = JitSensRevGradFunctions::<T>::new(&symbol_map).ok();
        // Construct with zeroed dimensions first: `get_dims` needs a `Self`
        // to call the generated function, then the real values are filled in.
        let mut ret = Self {
            jit_functions,
            jit_grad_functions,
            jit_grad_r_functions,
            jit_sens_grad_functions,
            jit_sens_rev_grad_functions,
            jit_get_tensor_functions,
            number_of_states: 0,
            number_of_parameters: 0,
            number_of_outputs: 0,
            number_of_stop: 0,
            data_size: 0,
            has_mass: false,
            has_reset: false,
            thread_pool: None,
            thread_lock: None,
            _module: module,
        };
        let (
            number_of_states,
            number_of_parameters,
            number_of_outputs,
            data_size,
            number_of_stops,
            has_mass,
            has_reset,
        ) = ret.get_dims();
        ret.number_of_states = number_of_states;
        ret.number_of_parameters = number_of_parameters;
        ret.number_of_outputs = number_of_outputs;
        ret.number_of_stop = number_of_stops;
        ret.has_mass = has_mass;
        ret.has_reset = has_reset;
        ret.data_size = data_size;
        let thread_dim = mode.thread_dim(number_of_states);
        // A pool (and the lock serializing pool use) only exists when more
        // than one thread was requested.
        #[cfg(feature = "rayon")]
        let (thread_pool, thread_lock) = if thread_dim > 1 {
            (
                Some(ThreadPoolBuilder::new().num_threads(thread_dim).build()?),
                Some(std::sync::Mutex::new(())),
            )
        } else {
            (None, None)
        };
        #[cfg(not(feature = "rayon"))]
        let (thread_pool, thread_lock) = if thread_dim > 1 {
            return Err(anyhow!(
                "Threading is not supported in this build, please enable the 'rayon' feature"
            ));
        } else {
            (None, None)
        };
        ret.thread_pool = thread_pool;
        ret.thread_lock = thread_lock;
        // all done, can set constants now
        ret.set_constants();
        Ok(ret)
    }
    /// JIT-compiles `module` and wraps it in a `Compiler`.
    pub fn from_codegen_module(mut module: M, mode: CompilerMode) -> Result<Self>
    where
        M: CodegenModuleJit,
    {
        let symbol_map = module.jit()?;
        Self::new(module, symbol_map, mode)
    }
    /// Links a previously compiled object file and wraps it in a `Compiler`.
    pub fn from_object_file(buffer: Vec<u8>, mode: CompilerMode) -> Result<Self>
    where
        M: CodegenModuleLink + CodegenModuleJit,
    {
        let mut module = M::from_object(buffer.as_slice())?;
        let symbol_map = module.jit()?;
        Self::new(module, symbol_map, mode)
    }
    /// Parses DiffSL source text, builds the discrete model, and compiles it.
    ///
    /// # Errors
    ///
    /// Returns parse/build errors rendered against `code`.
    pub fn from_discrete_str(code: &str, options: CompilerOptions) -> Result<Self>
    where
        M: CodegenModuleCompile + CodegenModuleJit,
    {
        // Unique module name so repeated compilations don't collide.
        let uid = Id::<u32>::new();
        let name = format!("diffsl_{uid}");
        let model = parse_ds_string(code).map_err(|e| anyhow!(e.to_string()))?;
        let model = DiscreteModel::build(name.as_str(), &model)
            .map_err(|e| anyhow!(e.as_error_message(code)))?;
        Self::from_discrete_model(&model, options, Some(code))
    }
    /// Compiles an already-built `DiscreteModel`; `code` (when given) is used
    /// for error reporting.
    pub fn from_discrete_model(
        model: &DiscreteModel,
        options: CompilerOptions,
        code: Option<&str>,
    ) -> Result<Self>
    where
        M: CodegenModuleCompile + CodegenModuleJit,
    {
        let mode = options.mode;
        let mut module = M::from_discrete_model(model, options, None, T::as_real_type(), code)?;
        let symbol_map = module.jit()?;
        Self::new(module, symbol_map, mode)
    }
    /// `true` if the module was compiled with reverse-mode autodiff support.
    pub fn supports_reverse_autodiff(&self) -> bool {
        self.jit_grad_r_functions.is_some()
    }
    /// Borrows the underlying codegen module.
    pub fn module(&self) -> &M {
        &self._module
    }
    /// Consumes the compiler, returning the underlying module.
    pub fn to_module(self) -> M {
        self._module
    }
pub fn get_tensors(&self) -> Vec<String> {
self.jit_get_tensor_functions
.data_map
.keys()
.cloned()
.collect()
}
    /// Returns a view of the named tensor inside `data`, or `None` if no
    /// tensor with that name exists.
    ///
    /// The generated lookup function fills in a pointer/length pair; the
    /// returned slice borrows from `data` for lifetime `'a`.
    pub fn get_tensor_data<'a>(&self, name: &str, data: &'a [T]) -> Option<&'a [T]> {
        if let Some(get_func) = self.jit_get_tensor_functions.data_map.get(name) {
            let mut tensor_data: *mut T = std::ptr::null_mut();
            let mut tensor_size: u32 = 0;
            unsafe {
                (get_func)(
                    data.as_ptr(),
                    &mut tensor_data as *mut *mut T,
                    &mut tensor_size as *mut u32,
                )
            };
            // SAFETY: assumes the JIT function reports a valid pointer/length
            // pair inside `data` -- TODO confirm against the codegen contract.
            Some(unsafe {
                std::slice::from_raw_parts(tensor_data, usize::try_from(tensor_size).unwrap())
            })
        } else {
            None
        }
    }
pub fn get_constants(&self) -> Vec<String> {
self.jit_get_tensor_functions
.constant_map
.keys()
.cloned()
.collect()
}
pub fn get_constants_data(&self, name: &str) -> Option<&[T]> {
if let Some(get_func) = self.jit_get_tensor_functions.constant_map.get(name) {
let mut tensor_data: *const T = std::ptr::null_mut();
let mut tensor_size: u32 = 0;
unsafe {
(get_func)(
&mut tensor_data as *mut *const T,
&mut tensor_size as *mut u32,
)
};
Some(unsafe {
std::slice::from_raw_parts(tensor_data, usize::try_from(tensor_size).unwrap())
})
} else {
None
}
}
pub fn get_tensor_data_mut<'a>(&self, name: &str, data: &'a mut [T]) -> Option<&'a mut [T]> {
if let Some(get_func) = self.jit_get_tensor_functions.data_map.get(name) {
let mut tensor_data: *mut T = std::ptr::null_mut();
let mut tensor_size: u32 = 0;
unsafe {
(get_func)(
data.as_ptr(),
&mut tensor_data as *mut *mut T,
&mut tensor_size as *mut u32,
)
};
Some(unsafe {
std::slice::from_raw_parts_mut(tensor_data, usize::try_from(tensor_size).unwrap())
})
} else {
None
}
}
#[cfg(feature = "rayon")]
fn with_threading<F>(&self, f: F)
where
F: Fn(u32, u32) + Sync + Send,
{
#[cfg(feature = "rayon")]
if let (Some(thread_pool), Some(thread_lock)) = (&self.thread_pool, &self.thread_lock) {
let _lock = thread_lock.lock().unwrap();
let dim = thread_pool.current_num_threads();
unsafe {
(self.jit_functions.barrier_init.unwrap())();
}
thread_pool.broadcast(|ctx| {
let idx = ctx.index() as u32;
let dim = dim as u32;
// internal barriers in f use active spin-locks, so all threads
// must be available so the spin-locks can be released
f(idx, dim);
});
} else {
f(0, 1);
}
}
    /// Single-threaded fallback when the "rayon" feature is disabled:
    /// always runs `f(0, 1)` on the calling thread.
    #[cfg(not(feature = "rayon"))]
    fn with_threading<F>(&self, f: F)
    where
        F: Fn(u32, u32),
    {
        f(0, 1);
    }
fn check_arg_len(&self, arg: &[T], expected_len: usize, name: &str) {
if arg.len() != expected_len {
panic!(
"Expected arg {} has len {}, got {}",
name,
expected_len,
arg.len()
);
}
}
fn check_state_len(&self, state: &[T], name: &str) {
self.check_arg_len(state, self.number_of_states, name);
}
fn check_stop_len(&self, stop: &[T], name: &str) {
self.check_arg_len(stop, self.number_of_stop, name);
}
fn check_data_len(&self, data: &[T], name: &str) {
self.check_arg_len(data, self.data_len(), name);
}
fn check_out_len(&self, out: &[T], name: &str) {
self.check_arg_len(out, self.number_of_outputs, name);
}
fn check_inputs_len(&self, inputs: &[T], name: &str) {
if inputs.len() != self.number_of_parameters {
panic!(
"Expected arg {} has len {}, got {}",
name,
self.number_of_parameters,
inputs.len()
);
}
}
    /// Evaluates the model's constant tensors via the generated
    /// `set_constants` function (called once, at the end of `new`).
    fn set_constants(&mut self) {
        self.with_threading(|i, dim| unsafe {
            (self.jit_functions.set_constants)(i, dim);
        });
    }
pub fn set_u0(&self, yy: &mut [T], data: &mut [T]) {
self.check_state_len(yy, "yy");
self.with_threading(|i, dim| unsafe {
(self.jit_functions.set_u0)(yy.as_ptr() as *mut T, data.as_ptr() as *mut T, i, dim);
});
}
    /// Forward sensitivity gradient of `set_u0`: propagates tangents through
    /// the initial-state computation into `dyy`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module was compiled without
    /// sensitivity autodiff support.
    pub fn set_u0_sgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T]) {
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_sens_grad_functions
                .as_ref()
                .expect("module does not support sens autograd")
                .set_u0_sgrad)(
                yy.as_ptr(),
                // Cast rather than as_mut_ptr: the `Fn` closure only holds
                // shared borrows of the &mut parameters.
                dyy.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                i,
                dim,
            );
        });
    }
    /// Reverse-mode gradient of `set_u0`: accumulates adjoints from
    /// `dyy` back into `ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module was compiled without
    /// reverse autodiff support.
    pub fn set_u0_rgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T]) {
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .set_u0_rgrad)(
                yy.as_ptr(),
                dyy.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                i,
                dim,
            );
        });
    }
pub fn set_u0_grad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T]) {
self.check_state_len(yy, "yy");
self.check_state_len(dyy, "dyy");
self.check_data_len(data, "data");
self.check_data_len(ddata, "ddata");
self.with_threading(|i, dim| {
unsafe {
(self.jit_grad_functions.set_u0_grad)(
yy.as_ptr(),
dyy.as_ptr() as *mut T,
data.as_ptr(),
ddata.as_ptr() as *mut T,
i,
dim,
)
};
})
}
    /// Evaluates the model's stop (root-finding) function into `stop`.
    ///
    /// # Panics
    ///
    /// Panics if the model has no stop function or on wrong buffer lengths.
    pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T]) {
        if self.number_of_stop == 0 {
            panic!("Model does not have a stop function");
        }
        self.check_state_len(yy, "yy");
        self.check_stop_len(stop, "stop");
        self.check_data_len(data, "data");
        self.with_threading(|i, dim| unsafe {
            (self.jit_functions.calc_stop)(
                t,
                yy.as_ptr(),
                // Cast rather than as_mut_ptr: the `Fn` closure only holds
                // shared borrows of the &mut parameters.
                data.as_ptr() as *mut T,
                stop.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Forward-mode gradient of the stop function into `dstop`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics if the model has no stop function or on wrong buffer lengths.
    #[allow(clippy::too_many_arguments)]
    pub fn calc_stop_grad(
        &self,
        t: T,
        yy: &[T],
        dyy: &[T],
        data: &[T],
        ddata: &mut [T],
        stop: &[T],
        dstop: &mut [T],
    ) {
        if self.number_of_stop == 0 {
            panic!("Model does not have a stop function");
        }
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_stop_len(stop, "stop");
        self.check_stop_len(dstop, "dstop");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self.jit_grad_functions.stop_grad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                stop.as_ptr(),
                dstop.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Reverse-mode gradient of the stop function: accumulates adjoints
    /// from `dstop` into `dyy`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics if the model has no stop function, on wrong buffer lengths,
    /// or if the module lacks reverse autodiff support.
    #[allow(clippy::too_many_arguments)]
    pub fn calc_stop_rgrad(
        &self,
        t: T,
        yy: &[T],
        dyy: &mut [T],
        data: &[T],
        ddata: &mut [T],
        stop: &[T],
        dstop: &mut [T],
    ) {
        if self.number_of_stop == 0 {
            panic!("Model does not have a stop function");
        }
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_stop_len(stop, "stop");
        self.check_stop_len(dstop, "dstop");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .stop_rgrad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                stop.as_ptr(),
                dstop.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Sensitivity (forward) gradient of the stop function; tangents come
    /// from `ddata` (no `dyy` argument).
    ///
    /// # Panics
    ///
    /// Panics if the model has no stop function, on wrong buffer lengths,
    /// or if the module lacks sensitivity autodiff support.
    pub fn calc_stop_sgrad(
        &self,
        t: T,
        yy: &[T],
        data: &[T],
        ddata: &mut [T],
        stop: &[T],
        dstop: &mut [T],
    ) {
        if self.number_of_stop == 0 {
            panic!("Model does not have a stop function");
        }
        self.check_state_len(yy, "yy");
        self.check_stop_len(stop, "stop");
        self.check_stop_len(dstop, "dstop");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_sens_grad_functions
                .as_ref()
                .expect("module does not support sens autograd")
                .stop_sgrad)(
                t,
                yy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                stop.as_ptr(),
                dstop.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
pub fn calc_stop_srgrad(
&self,
t: T,
yy: &[T],
data: &[T],
ddata: &mut [T],
stop: &[T],
dstop: &mut [T],
) {
if self.number_of_stop == 0 {
panic!("Model does not have a stop function");
}
self.check_state_len(yy, "yy");
self.check_stop_len(stop, "stop");
self.check_stop_len(dstop, "dstop");
self.check_data_len(data, "data");
self.check_data_len(ddata, "ddata");
self.with_threading(|i, dim| unsafe {
(self
.jit_sens_rev_grad_functions
.as_ref()
.expect("module does not support sens autograd")
.stop_rgrad)(
t,
yy.as_ptr(),
data.as_ptr(),
ddata.as_ptr() as *mut T,
stop.as_ptr(),
dstop.as_ptr() as *mut T,
i,
dim,
)
});
}
    /// Evaluates the model's reset function into `reset`.
    ///
    /// A no-op when `reset` is empty (models without a reset tensor).
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths (the reset tensor has one entry per
    /// state).
    pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T]) {
        if reset.is_empty() {
            return;
        }
        self.check_state_len(yy, "yy");
        self.check_state_len(reset, "reset");
        self.check_data_len(data, "data");
        self.with_threading(|i, dim| unsafe {
            (self.jit_functions.reset)(
                t,
                yy.as_ptr(),
                // Cast rather than as_mut_ptr: the `Fn` closure only holds
                // shared borrows of the &mut parameters.
                data.as_ptr() as *mut T,
                reset.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Forward-mode gradient of `reset` into `dreset`/`ddata`.
    /// A no-op when `dreset` is empty.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    #[allow(clippy::too_many_arguments)]
    pub fn reset_grad(
        &self,
        t: T,
        yy: &[T],
        dyy: &[T],
        data: &[T],
        ddata: &mut [T],
        reset: &[T],
        dreset: &mut [T],
    ) {
        if dreset.is_empty() {
            return;
        }
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_state_len(reset, "reset");
        self.check_state_len(dreset, "dreset");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self.jit_grad_functions.reset_grad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                reset.as_ptr(),
                dreset.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Reverse-mode gradient of `reset`: accumulates adjoints from `dreset`
    /// into `dyy`/`ddata`. A no-op when `dreset` is empty.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks reverse
    /// autodiff support.
    #[allow(clippy::too_many_arguments)]
    pub fn reset_rgrad(
        &self,
        t: T,
        yy: &[T],
        dyy: &mut [T],
        data: &[T],
        ddata: &mut [T],
        reset: &[T],
        dreset: &mut [T],
    ) {
        if dreset.is_empty() {
            return;
        }
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_state_len(reset, "reset");
        self.check_state_len(dreset, "dreset");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .reset_rgrad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                reset.as_ptr(),
                dreset.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Sensitivity (forward) gradient of `reset`; tangents come from
    /// `ddata`. A no-op when `dreset` is empty.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks sensitivity
    /// autodiff support.
    pub fn reset_sgrad(
        &self,
        t: T,
        yy: &[T],
        data: &[T],
        ddata: &mut [T],
        reset: &[T],
        dreset: &mut [T],
    ) {
        if dreset.is_empty() {
            return;
        }
        self.check_state_len(yy, "yy");
        self.check_state_len(reset, "reset");
        self.check_state_len(dreset, "dreset");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_sens_grad_functions
                .as_ref()
                .expect("module does not support sens autograd")
                .reset_sgrad)(
                t,
                yy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                reset.as_ptr(),
                dreset.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
pub fn reset_srgrad(
&self,
t: T,
yy: &[T],
data: &[T],
ddata: &mut [T],
reset: &[T],
dreset: &mut [T],
) {
if dreset.is_empty() {
return;
}
self.check_state_len(yy, "yy");
self.check_state_len(reset, "reset");
self.check_state_len(dreset, "dreset");
self.check_data_len(data, "data");
self.check_data_len(ddata, "ddata");
self.with_threading(|i, dim| unsafe {
(self
.jit_sens_rev_grad_functions
.as_ref()
.expect("module does not support sens autograd")
.reset_rgrad)(
t,
yy.as_ptr(),
data.as_ptr(),
ddata.as_ptr() as *mut T,
reset.as_ptr(),
dreset.as_ptr() as *mut T,
i,
dim,
)
});
}
    /// Evaluates the model's right-hand side `F` at time `t` into `rr`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T]) {
        self.check_state_len(yy, "yy");
        self.check_state_len(rr, "rr");
        self.check_data_len(data, "data");
        self.with_threading(|i, dim| unsafe {
            (self.jit_functions.rhs)(
                t,
                yy.as_ptr(),
                // Cast rather than as_mut_ptr: the `Fn` closure only holds
                // shared borrows of the &mut parameters.
                data.as_ptr() as *mut T,
                rr.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// `true` if the model defines a mass matrix `M`.
    pub fn has_mass(&self) -> bool {
        self.has_mass
    }
    /// `true` if the model defines a reset function.
    pub fn has_reset(&self) -> bool {
        self.has_reset
    }
    /// Computes the mass-matrix/vector product `mv = M v`.
    ///
    /// # Panics
    ///
    /// Panics if the model has no mass function or on wrong buffer lengths.
    pub fn mass(&self, t: T, v: &[T], data: &mut [T], mv: &mut [T]) {
        if !self.has_mass {
            panic!("Model does not have a mass function");
        }
        self.check_state_len(v, "v");
        self.check_state_len(mv, "mv");
        self.check_data_len(data, "data");
        self.with_threading(|i, dim| unsafe {
            (self.jit_functions.mass)(
                t,
                v.as_ptr(),
                data.as_ptr() as *mut T,
                mv.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Required length of the model's data buffer.
    pub fn data_len(&self) -> usize {
        self.data_size
    }
    /// Allocates a zero-initialized data buffer of the correct length.
    pub fn get_new_data(&self) -> Vec<T> {
        vec![T::zero(); self.data_len()]
    }
    /// Forward-mode gradient of `rhs` into `drr`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    #[allow(clippy::too_many_arguments)]
    pub fn rhs_grad(
        &self,
        t: T,
        yy: &[T],
        dyy: &[T],
        data: &[T],
        ddata: &mut [T],
        rr: &[T],
        drr: &mut [T],
    ) {
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_state_len(rr, "rr");
        self.check_state_len(drr, "drr");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self.jit_grad_functions.rhs_grad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                rr.as_ptr(),
                drr.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Reverse-mode gradient of `rhs`: accumulates adjoints from `drr`
    /// into `dyy`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks reverse
    /// autodiff support.
    #[allow(clippy::too_many_arguments)]
    pub fn rhs_rgrad(
        &self,
        t: T,
        yy: &[T],
        dyy: &mut [T],
        data: &[T],
        ddata: &mut [T],
        rr: &[T],
        drr: &mut [T],
    ) {
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_state_len(rr, "rr");
        self.check_state_len(drr, "drr");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .rhs_rgrad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                rr.as_ptr(),
                drr.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Reverse-mode gradient of the mass product: accumulates adjoints from
    /// `dmv` into `dv`/`ddata`.
    ///
    /// Null is passed for the primal `v`/`mv` arguments, which the generated
    /// `mass_rgrad` evidently does not read -- confirm against the codegen.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks reverse
    /// autodiff support.
    pub fn mass_rgrad(&self, t: T, dv: &mut [T], data: &[T], ddata: &mut [T], dmv: &mut [T]) {
        self.check_state_len(dv, "dv");
        self.check_state_len(dmv, "dmv");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .mass_rgrad)(
                t,
                std::ptr::null(),
                dv.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                std::ptr::null(),
                dmv.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Sensitivity (forward) gradient of `rhs`; tangents come from `ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks sensitivity
    /// autodiff support.
    pub fn rhs_sgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) {
        self.check_state_len(yy, "yy");
        self.check_state_len(rr, "rr");
        self.check_state_len(drr, "drr");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_sens_grad_functions
                .as_ref()
                .expect("module does not support sens autograd")
                .rhs_sgrad)(
                t,
                yy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                rr.as_ptr(),
                drr.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
pub fn rhs_srgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) {
self.check_state_len(yy, "yy");
self.check_state_len(rr, "rr");
self.check_state_len(drr, "drr");
self.check_data_len(data, "data");
self.check_data_len(ddata, "ddata");
self.with_threading(|i, dim| unsafe {
(self
.jit_sens_rev_grad_functions
.as_ref()
.expect("module does not support sens autograd")
.rhs_rgrad)(
t,
yy.as_ptr(),
data.as_ptr(),
ddata.as_ptr() as *mut T,
rr.as_ptr(),
drr.as_ptr() as *mut T,
i,
dim,
)
});
}
    /// Evaluates the model's output function into `out`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T]) {
        self.check_state_len(yy, "yy");
        self.check_data_len(data, "data");
        self.check_out_len(out, "out");
        self.with_threading(|i, dim| unsafe {
            (self.jit_functions.calc_out)(
                t,
                yy.as_ptr(),
                // Cast rather than as_mut_ptr: the `Fn` closure only holds
                // shared borrows of the &mut parameters.
                data.as_ptr() as *mut T,
                out.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Forward-mode gradient of `calc_out` into `dout`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    #[allow(clippy::too_many_arguments)]
    pub fn calc_out_grad(
        &self,
        t: T,
        yy: &[T],
        dyy: &[T],
        data: &[T],
        ddata: &mut [T],
        out: &[T],
        dout: &mut [T],
    ) {
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.check_out_len(out, "out");
        self.check_out_len(dout, "dout");
        self.with_threading(|i, dim| unsafe {
            (self.jit_grad_functions.calc_out_grad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                out.as_ptr(),
                dout.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Reverse-mode gradient of `calc_out`: accumulates adjoints from `dout`
    /// into `dyy`/`ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks reverse
    /// autodiff support.
    #[allow(clippy::too_many_arguments)]
    pub fn calc_out_rgrad(
        &self,
        t: T,
        yy: &[T],
        dyy: &mut [T],
        data: &[T],
        ddata: &mut [T],
        out: &[T],
        dout: &mut [T],
    ) {
        self.check_state_len(yy, "yy");
        self.check_state_len(dyy, "dyy");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.check_out_len(out, "out");
        self.check_out_len(dout, "dout");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .calc_out_rgrad)(
                t,
                yy.as_ptr(),
                dyy.as_ptr() as *mut T,
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                out.as_ptr(),
                dout.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
    /// Sensitivity (forward) gradient of `calc_out`; tangents come from
    /// `ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks sensitivity
    /// autodiff support.
    pub fn calc_out_sgrad(
        &self,
        t: T,
        yy: &[T],
        data: &[T],
        ddata: &mut [T],
        out: &[T],
        dout: &mut [T],
    ) {
        self.check_state_len(yy, "yy");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        self.check_out_len(out, "out");
        self.check_out_len(dout, "dout");
        self.with_threading(|i, dim| unsafe {
            (self
                .jit_sens_grad_functions
                .as_ref()
                .expect("module does not support sens autograd")
                .calc_out_sgrad)(
                t,
                yy.as_ptr(),
                data.as_ptr(),
                ddata.as_ptr() as *mut T,
                out.as_ptr(),
                dout.as_ptr() as *mut T,
                i,
                dim,
            )
        });
    }
pub fn calc_out_srgrad(
&self,
t: T,
yy: &[T],
data: &[T],
ddata: &mut [T],
out: &[T],
dout: &mut [T],
) {
self.check_state_len(yy, "yy");
self.check_data_len(data, "data");
self.check_data_len(ddata, "ddata");
self.check_out_len(out, "out");
self.check_out_len(dout, "dout");
self.with_threading(|i, dim| unsafe {
(self
.jit_sens_rev_grad_functions
.as_ref()
.expect("module does not support sens autograd")
.calc_out_rgrad)(
t,
yy.as_ptr(),
data.as_ptr(),
ddata.as_ptr() as *mut T,
out.as_ptr(),
dout.as_ptr() as *mut T,
i,
dim,
)
});
}
    /// Get various dimensions of the model
    ///
    /// # Returns
    ///
    /// A tuple of the form `(n_states, n_inputs, n_outputs, n_data, n_stop, has_mass, has_reset)`
    pub fn get_dims(&self) -> (usize, usize, usize, usize, usize, bool, bool) {
        // Out-params filled in by the generated `get_dims` function.
        let mut n_states = 0u32;
        let mut n_inputs = 0u32;
        let mut n_outputs = 0u32;
        let mut n_data = 0u32;
        let mut n_stop = 0u32;
        let mut has_mass = 0u32;
        let mut has_reset = 0u32;
        unsafe {
            (self.jit_functions.get_dims)(
                &mut n_states,
                &mut n_inputs,
                &mut n_outputs,
                &mut n_data,
                &mut n_stop,
                &mut has_mass,
                &mut has_reset,
            )
        };
        (
            n_states as usize,
            n_inputs as usize,
            n_outputs as usize,
            n_data as usize,
            n_stop as usize,
            // The booleans are reported as 0/1 integers by the FFI.
            has_mass != 0,
            has_reset != 0,
        )
    }
    /// Writes the input (parameter) values into the data buffer.
    /// `model_index` selects the model instance -- confirm exact semantics
    /// against the codegen.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    pub fn set_inputs(&self, inputs: &[T], data: &mut [T], model_index: u32) {
        self.check_inputs_len(inputs, "inputs");
        self.check_data_len(data, "data");
        unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr(), model_index) };
    }
    /// Reads the current input (parameter) values out of the data buffer.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    pub fn get_inputs(&self, inputs: &mut [T], data: &[T]) {
        self.check_inputs_len(inputs, "inputs");
        self.check_data_len(data, "data");
        unsafe { (self.jit_functions.get_inputs)(inputs.as_mut_ptr(), data.as_ptr()) };
    }
    /// Forward-mode gradient of `set_inputs`: writes input tangents
    /// `dinputs` into `ddata`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths.
    pub fn set_inputs_grad(
        &self,
        inputs: &[T],
        dinputs: &[T],
        data: &[T],
        ddata: &mut [T],
        model_index: u32,
    ) {
        self.check_inputs_len(inputs, "inputs");
        self.check_inputs_len(dinputs, "dinputs");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        unsafe {
            (self.jit_grad_functions.set_inputs_grad)(
                inputs.as_ptr(),
                dinputs.as_ptr(),
                data.as_ptr(),
                ddata.as_mut_ptr(),
                model_index,
            )
        };
    }
    /// Reverse-mode gradient of `set_inputs`: accumulates adjoints from
    /// `ddata` back into `dinputs`.
    ///
    /// # Panics
    ///
    /// Panics on wrong buffer lengths or if the module lacks reverse
    /// autodiff support.
    pub fn set_inputs_rgrad(
        &self,
        inputs: &[T],
        dinputs: &mut [T],
        data: &[T],
        ddata: &mut [T],
        model_index: u32,
    ) {
        self.check_inputs_len(inputs, "inputs");
        self.check_inputs_len(dinputs, "dinputs");
        self.check_data_len(data, "data");
        self.check_data_len(ddata, "ddata");
        unsafe {
            (self
                .jit_grad_r_functions
                .as_ref()
                .expect("module does not support reverse autograd")
                .set_inputs_rgrad)(
                inputs.as_ptr(),
                dinputs.as_mut_ptr(),
                data.as_ptr(),
                ddata.as_mut_ptr(),
                model_index,
            )
        };
    }
pub fn set_id(&self, id: &mut [T]) {
let (n_states, _, _, _, _, _, _) = self.get_dims();
if n_states != id.len() {
panic!("Expected {} states, got {}", n_states, id.len());
}
unsafe { (self.jit_functions.set_id)(id.as_mut_ptr()) };
}
    /// Number of state variables in the model.
    pub fn number_of_states(&self) -> usize {
        self.number_of_states
    }
    /// Number of input parameters in the model.
    pub fn number_of_parameters(&self) -> usize {
        self.number_of_parameters
    }
    /// Number of output values in the model.
    pub fn number_of_outputs(&self) -> usize {
        self.number_of_outputs
    }
}
#[cfg(test)]
mod tests {
use std::{sync::Arc, thread};
use super::CompilerMode;
use crate::{
discretise::DiscreteModel,
execution::{
compiler::CompilerOptions,
module::{CodegenModule, CodegenModuleCompile, CodegenModuleJit},
scalar::Scalar,
},
parser::parse_ds_string,
Compiler,
};
use approx::{assert_relative_eq, RelativeEq};
use num_traits::ToPrimitive;
use paste::paste;
/// Macro to generate test functions for all combinations of backend (cranelift/llvm) and scalar type (f32/f64)
///
/// Usage: `generate_tests!(test_name, generic_test_function);`
///
/// This will generate 4 test functions:
/// - {test_name}_cranelift_f64
/// - {test_name}_cranelift_f32
/// - {test_name}_llvm_f64
/// - {test_name}_llvm_f32
///
/// Example:
/// ```
/// fn my_test<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar>() { ... }
/// generate_tests!(my_test);
/// ```
macro_rules! generate_tests {
($test_fn:ident) => {
generate_tests!(@impl $test_fn, cranelift_f64, crate::CraneliftJitModule, f64, "cranelift");
generate_tests!(@impl $test_fn, cranelift_f32, crate::CraneliftJitModule, f32, "cranelift");
generate_tests!(@impl $test_fn, llvm_f64, crate::LlvmModule, f64, "llvm");
generate_tests!(@impl $test_fn, llvm_f32, crate::LlvmModule, f32, "llvm");
};
(@impl $test_fn:ident, $variant:ident, $module:ty, $scalar:ty, $feature:literal) => {
paste! {
#[cfg(feature = $feature)]
#[test]
fn [<$test_fn _ $variant>]() {
$test_fn::<$module, $scalar>();
}
}
};
}
    generate_tests!(test_constants);
    /// Constant tensors (`b`, `b2`) must be evaluated at compile time, while
    /// input-dependent tensors (`a`, `a2`) stay zero until the inputs and
    /// `u0` have been set.
    #[allow(dead_code)]
    fn test_constants<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
        let full_text = "
in { a = 1 }
b { 2 }
a2 { a * a }
b2 { b * b }
u_i { y = 1 }
F_i { y * b }
";
        let model = parse_ds_string(full_text).unwrap();
        let discrete_model = DiscreteModel::build("$name", &model).unwrap();
        let compiler = Compiler::<M, T>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap();
        // b and b2 should already be set
        let mut data = compiler.get_new_data();
        let b = compiler.get_constants_data("b").unwrap();
        let b2 = compiler.get_constants_data("b2").unwrap();
        assert_relative_eq!(b[0], T::from_f64(2.0).unwrap());
        assert_relative_eq!(b2[0], T::from_f64(4.0).unwrap());
        // a and a2 should not be set (be 0)
        let a = compiler.get_tensor_data("in", &data).unwrap();
        let a2 = compiler.get_tensor_data("a2", &data).unwrap();
        assert_relative_eq!(a[0], T::zero());
        assert_relative_eq!(a2[0], T::zero());
        // set the inputs and u0
        let inputs = vec![T::one()];
        compiler.set_inputs(&inputs, data.as_mut_slice(), 0);
        let mut u0 = vec![T::zero()];
        compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
        // now a and a2 should be set
        let a = compiler.get_tensor_data("in", &data).unwrap();
        let a2 = compiler.get_tensor_data("a2", &data).unwrap();
        assert_relative_eq!(a[0], T::one());
        assert_relative_eq!(a2[0], T::one());
    }
    generate_tests!(test_from_discrete_str_common);
    /// Two minimal models (with and without an intermediate tensor) compile
    /// from source, report the expected dimensions, and evaluate u0/rhs.
    #[allow(dead_code)]
    fn test_from_discrete_str_common<
        M: CodegenModuleCompile + CodegenModuleJit,
        T: Scalar + RelativeEq,
    >() {
        let text1 = "
u { y = 1 }
F { -y }
out { y }
";
        let text2 = "
p { 1 }
u { p }
F { -u }
out { u }
";
        for text in [text2, text1] {
            let compiler = Compiler::<M, T>::from_discrete_str(text, Default::default()).unwrap();
            let (n_states, n_inputs, n_outputs, _n_data, n_stop, has_mass, has_reset) =
                compiler.get_dims();
            assert_eq!(n_states, 1);
            assert_eq!(n_inputs, 0);
            assert_eq!(n_outputs, 1);
            assert_eq!(n_stop, 0);
            assert!(!has_mass);
            assert!(!has_reset);
            let mut u0 = vec![T::zero()];
            let mut res = vec![T::zero()];
            let mut data = compiler.get_new_data();
            compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
            assert_relative_eq!(u0.as_slice(), vec![T::one()].as_slice());
            compiler.rhs(
                T::zero(),
                u0.as_slice(),
                data.as_mut_slice(),
                res.as_mut_slice(),
            );
            assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice());
        }
    }
    generate_tests!(test_stop);
    generate_tests!(test_stop_gradients);
    generate_tests!(test_reset);
    generate_tests!(test_reset_gradients);
    generate_tests!(test_reset_without_reset_tensor_is_noop);
    /// The stop tensor `y - 0.5` evaluates to 0.5 at the initial state y=1.
    #[allow(dead_code)]
    fn test_stop<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
        let full_text = "
u_i {
y = 1,
}
dudt_i {
dydt = 0,
}
M_i {
dydt,
}
F_i {
y * (1 - y),
}
stop_i {
y - 0.5,
}
out {
y,
}
";
        let model = parse_ds_string(full_text).unwrap();
        let discrete_model = DiscreteModel::build("$name", &model).unwrap();
        let compiler = Compiler::<M, T>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap();
        let mut u0 = vec![T::one()];
        let mut res = vec![T::zero()];
        let mut stop = vec![T::zero()];
        let mut data = compiler.get_new_data();
        compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
        compiler.rhs(
            T::zero(),
            u0.as_slice(),
            data.as_mut_slice(),
            res.as_mut_slice(),
        );
        compiler.calc_stop(
            T::zero(),
            u0.as_slice(),
            data.as_mut_slice(),
            stop.as_mut_slice(),
        );
        assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap());
        assert_eq!(stop.len(), 1);
    }
    /// Exercises forward, reverse, sensitivity and sensitivity-reverse
    /// gradients of the stop function `stop = (2y + a, z + a)` with a=3,
    /// checking each against hand-computed derivatives.
    #[allow(dead_code)]
    fn test_stop_gradients<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
        let full_text = "
in {
a = 1,
}
u_i {
y = a,
z = 2,
}
F_i {
y,
z,
}
stop_i {
2 * y + a,
z + a,
}
out_i {
y,
z,
}
";
        let model = parse_ds_string(full_text).unwrap();
        let discrete_model = DiscreteModel::build("$name", &model).unwrap();
        let compiler = Compiler::<M, T>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap();
        let mut data = compiler.get_new_data();
        let inputs = vec![T::from_f64(3.0).unwrap()];
        compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0);
        let mut yy = vec![T::zero(), T::zero()];
        compiler.set_u0(yy.as_mut_slice(), data.as_mut_slice());
        let mut stop = vec![T::zero(), T::zero()];
        compiler.calc_stop(
            T::zero(),
            yy.as_slice(),
            data.as_mut_slice(),
            stop.as_mut_slice(),
        );
        // stop = (2*3 + 3, 2 + 3) = (9, 5)
        assert_relative_eq!(stop[0], T::from_f64(9.0).unwrap());
        assert_relative_eq!(stop[1], T::from_f64(5.0).unwrap());
        let mut ddata = compiler.get_new_data();
        let dinputs = vec![T::one()];
        compiler.set_inputs_grad(
            inputs.as_slice(),
            dinputs.as_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            0,
        );
        let dyy = vec![T::one(), T::zero()];
        let mut dstop = vec![T::zero(), T::zero()];
        compiler.calc_stop_grad(
            T::zero(),
            yy.as_slice(),
            dyy.as_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            stop.as_slice(),
            dstop.as_mut_slice(),
        );
        // d(stop)/d(y,a)*(1,1) = (2*1 + 1, 0 + 1) = (3, 1)
        assert_relative_eq!(dstop[0], T::from_f64(3.0).unwrap());
        assert_relative_eq!(dstop[1], T::from_f64(1.0).unwrap());
        if compiler.supports_reverse_autodiff() {
            let mut dyy_rev = vec![T::zero(), T::zero()];
            let mut ddata_rev = compiler.get_new_data();
            let mut dstop_rev = vec![T::one(), T::one()];
            compiler.calc_stop_rgrad(
                T::zero(),
                yy.as_slice(),
                dyy_rev.as_mut_slice(),
                data.as_slice(),
                ddata_rev.as_mut_slice(),
                stop.as_slice(),
                dstop_rev.as_mut_slice(),
            );
            // adjoints: d(stop0+stop1)/dy = 2, d(stop0+stop1)/dz = 1
            assert_relative_eq!(dyy_rev[0], T::from_f64(2.0).unwrap());
            assert_relative_eq!(dyy_rev[1], T::one());
            let mut dinputs_rev = vec![T::zero(); inputs.len()];
            compiler.set_inputs_rgrad(
                inputs.as_slice(),
                dinputs_rev.as_mut_slice(),
                data.as_slice(),
                ddata_rev.as_mut_slice(),
                0,
            );
            // d(stop0+stop1)/da = 2
            assert_relative_eq!(dinputs_rev[0], T::from_f64(2.0).unwrap());
            let mut ddata_s = compiler.get_new_data();
            let dinputs_s = vec![T::one(); inputs.len()];
            compiler.set_inputs(dinputs_s.as_slice(), ddata_s.as_mut_slice(), 0);
            let mut dstop_s = vec![T::zero(), T::zero()];
            compiler.calc_stop_sgrad(
                T::zero(),
                yy.as_slice(),
                data.as_slice(),
                ddata_s.as_mut_slice(),
                stop.as_slice(),
                dstop_s.as_mut_slice(),
            );
            // sensitivity wrt a only (state tangents zero): (1, 1)
            assert_relative_eq!(dstop_s[0], T::one());
            assert_relative_eq!(dstop_s[1], T::one());
            let mut ddata_sr = compiler.get_new_data();
            let mut dstop_sr = vec![T::one(), T::one()];
            compiler.calc_stop_srgrad(
                T::zero(),
                yy.as_slice(),
                data.as_slice(),
                ddata_sr.as_mut_slice(),
                stop.as_slice(),
                dstop_sr.as_mut_slice(),
            );
            let mut dinputs_sr = vec![T::zero(); inputs.len()];
            compiler.set_inputs_rgrad(
                inputs.as_slice(),
                dinputs_sr.as_mut_slice(),
                data.as_slice(),
                ddata_sr.as_mut_slice(),
                0,
            );
            assert_relative_eq!(dinputs_sr[0], T::from_f64(2.0).unwrap());
        }
    }
#[allow(dead_code)]
fn test_reset<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
    // Model with a mass matrix, a reset tensor mapping (y, z) -> (2*y, z + 10)
    // and a stop condition. Checks that the compiler reports `has_reset` and
    // that `reset()` evaluates the reset tensor without disturbing the state.
    let full_text = "
        u_i {
            y = 1,
            z = 2,
        }
        dudt_i {
            dydt = 0,
            dzdt = 0,
        }
        M_i {
            dydt,
            dzdt,
        }
        F_i {
            y,
            z,
        }
        reset_i {
            2 * y,
            z + 10,
        }
        stop_i {
            y - 0.5,
        }
        out_i {
            y,
            z,
        }
    ";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    // The model declares a reset_i tensor, so get_dims must report it.
    let (_n_states, _n_inputs, _n_outputs, _n_data, _n_stop, _has_mass, has_reset) =
        compiler.get_dims();
    assert!(has_reset);
    let mut data = compiler.get_new_data();
    let mut state = vec![T::zero(); 2];
    let mut reset_out = vec![T::zero(); 2];
    compiler.set_u0(state.as_mut_slice(), data.as_mut_slice());
    compiler.reset(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        reset_out.as_mut_slice(),
    );
    // The initial conditions must be left untouched by reset().
    assert_relative_eq!(state[0], T::from_f64(1.0).unwrap());
    assert_relative_eq!(state[1], T::from_f64(2.0).unwrap());
    // reset = (2*y, z + 10) evaluated at (y, z) = (1, 2).
    assert_relative_eq!(reset_out[0], T::from_f64(2.0).unwrap());
    assert_relative_eq!(reset_out[1], T::from_f64(12.0).unwrap());
    assert_eq!(reset_out.len(), 2);
}
#[allow(dead_code)]
fn test_reset_gradients<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
    // Exercises all four autodiff variants of the reset function for
    //   u0 = (a, 2),  reset = (2*y + a, z + a),  input a = 3.
    // Expected values follow from reset(3, 2) = (9, 5) and the partials
    //   d(reset[0]) = 2*dy + da,  d(reset[1]) = dz + da.
    // The s/sr ("sensitivity") variants differentiate wrt the inputs only.
    let full_text = "
        in {
            a = 1,
        }
        u_i {
            y = a,
            z = 2,
        }
        F_i {
            y,
            z,
        }
        reset_i {
            2 * y + a,
            z + a,
        }
        stop_i {
            y - 0.5,
            z - 1,
        }
        out_i {
            y,
            z,
        }
    ";
    let model = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &model).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    // Primal evaluation with a = 3, so u0 = (3, 2).
    let mut data = compiler.get_new_data();
    let inputs = vec![T::from_f64(3.0).unwrap()];
    compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0);
    let mut yy = vec![T::zero(), T::zero()];
    compiler.set_u0(yy.as_mut_slice(), data.as_mut_slice());
    let mut reset = vec![T::zero(), T::zero()];
    compiler.reset(
        T::zero(),
        yy.as_slice(),
        data.as_mut_slice(),
        reset.as_mut_slice(),
    );
    // reset(y=3, z=2; a=3) = (2*3 + 3, 2 + 3) = (9, 5).
    assert_relative_eq!(reset[0], T::from_f64(9.0).unwrap());
    assert_relative_eq!(reset[1], T::from_f64(5.0).unwrap());
    // Forward mode: tangents da = 1 and (dy, dz) = (1, 0).
    let mut ddata = compiler.get_new_data();
    let dinputs = vec![T::one()];
    compiler.set_inputs_grad(
        inputs.as_slice(),
        dinputs.as_slice(),
        data.as_slice(),
        ddata.as_mut_slice(),
        0,
    );
    let dyy = vec![T::one(), T::zero()];
    let mut dreset = vec![T::zero(), T::zero()];
    compiler.reset_grad(
        T::zero(),
        yy.as_slice(),
        dyy.as_slice(),
        data.as_slice(),
        ddata.as_mut_slice(),
        reset.as_slice(),
        dreset.as_mut_slice(),
    );
    // dreset = (2*dy + da, dz + da) = (3, 1).
    assert_relative_eq!(dreset[0], T::from_f64(3.0).unwrap());
    assert_relative_eq!(dreset[1], T::from_f64(1.0).unwrap());
    if compiler.supports_reverse_autodiff() {
        // Reverse mode: seed the reset cotangents with (1, 1); expect
        // dyy_rev = (2, 1) and input gradient da = 1 + 1 = 2.
        let mut dyy_rev = vec![T::zero(), T::zero()];
        let mut ddata_rev = compiler.get_new_data();
        let mut dreset_rev = vec![T::one(), T::one()];
        compiler.reset_rgrad(
            T::zero(),
            yy.as_slice(),
            dyy_rev.as_mut_slice(),
            data.as_slice(),
            ddata_rev.as_mut_slice(),
            reset.as_slice(),
            dreset_rev.as_mut_slice(),
        );
        assert_relative_eq!(dyy_rev[0], T::from_f64(2.0).unwrap());
        assert_relative_eq!(dyy_rev[1], T::one());
        // Pull the accumulated input gradient out of ddata_rev.
        let mut dinputs_rev = vec![T::zero(); inputs.len()];
        compiler.set_inputs_rgrad(
            inputs.as_slice(),
            dinputs_rev.as_mut_slice(),
            data.as_slice(),
            ddata_rev.as_mut_slice(),
            0,
        );
        assert_relative_eq!(dinputs_rev[0], T::from_f64(2.0).unwrap());
        // Forward sensitivity (inputs only): da = 1 gives dreset = (1, 1).
        let mut ddata_s = compiler.get_new_data();
        let dinputs_s = vec![T::one(); inputs.len()];
        compiler.set_inputs(dinputs_s.as_slice(), ddata_s.as_mut_slice(), 0);
        let mut dreset_s = vec![T::zero(), T::zero()];
        compiler.reset_sgrad(
            T::zero(),
            yy.as_slice(),
            data.as_slice(),
            ddata_s.as_mut_slice(),
            reset.as_slice(),
            dreset_s.as_mut_slice(),
        );
        assert_relative_eq!(dreset_s[0], T::one());
        assert_relative_eq!(dreset_s[1], T::one());
        // Reverse sensitivity: cotangents (1, 1) accumulate da = 2.
        let mut ddata_sr = compiler.get_new_data();
        let mut dreset_sr = vec![T::one(), T::one()];
        compiler.reset_srgrad(
            T::zero(),
            yy.as_slice(),
            data.as_slice(),
            ddata_sr.as_mut_slice(),
            reset.as_slice(),
            dreset_sr.as_mut_slice(),
        );
        let mut dinputs_sr = vec![T::zero(); inputs.len()];
        compiler.set_inputs_rgrad(
            inputs.as_slice(),
            dinputs_sr.as_mut_slice(),
            data.as_slice(),
            ddata_sr.as_mut_slice(),
            0,
        );
        assert_relative_eq!(dinputs_sr[0], T::from_f64(2.0).unwrap());
    }
}
#[allow(dead_code)]
fn test_reset_without_reset_tensor_is_noop<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
>() {
    // A model that declares no `reset_i` tensor: calling reset() must be a
    // no-op on the zero-length output slice and must not panic.
    let full_text = "
        u_i {
            y = 1,
        }
        F_i {
            y,
        }
        out_i {
            y,
        }
    ";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    let mut data = compiler.get_new_data();
    let mut y0 = vec![T::zero()];
    let mut reset_buf: Vec<T> = Vec::new();
    compiler.set_u0(y0.as_mut_slice(), data.as_mut_slice());
    compiler.reset(
        T::zero(),
        y0.as_slice(),
        data.as_mut_slice(),
        reset_buf.as_mut_slice(),
    );
    assert_eq!(reset_buf.len(), 0);
}
// Instantiate each generic test below once per enabled backend / scalar type.
generate_tests!(test_out_depends_on_internal_tensor);
generate_tests!(test_model_index_n_depends_on_model_index);
generate_tests!(test_model_index_n_dynamic_index_grad);
generate_tests!(test_model_index_n_dynamic_range_width_const);
#[allow(dead_code)]
fn test_model_index_n_depends_on_model_index<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
>() {
    // RED test for issue #112:
    // - `N` is a reserved model index
    // - `%` is supported in expressions
    // - tensor indexing can use expression indices, e.g. amp_i[N % 2]
    //
    // Expected behavior once implemented:
    // - N is taken from model_index.
    // - model_index = 0 => N % 2 = 0
    // - model_index = 1 => N % 2 = 1
    let full_text = "
        amp_i { 0, 10 }
        dur_i { 10, 5 }
        u_i { x = 1, tclock = 0 }
        F_i { amp_i[N % 2] - x, 1 }
        stop_i { dur_i[N % 2] - tclock }
        out_i { x, tclock }
    ";
    let model = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &model).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    let mut u0 = vec![T::zero(); 2];
    let mut rr0 = vec![T::zero(); 2];
    let mut rr1 = vec![T::zero(); 2];
    let mut stop0 = vec![T::zero(); 1];
    let mut stop1 = vec![T::zero(); 1];
    let mut data = compiler.get_new_data();
    // Model index 0: N % 2 = 0 selects amp[0] = 0 and dur[0] = 10.
    compiler.set_inputs(&[], data.as_mut_slice(), 0);
    compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
    compiler.rhs(
        T::zero(),
        u0.as_slice(),
        data.as_mut_slice(),
        rr0.as_mut_slice(),
    );
    compiler.calc_stop(
        T::zero(),
        u0.as_slice(),
        data.as_mut_slice(),
        stop0.as_mut_slice(),
    );
    // Model index 1: N % 2 = 1 selects amp[1] = 10 and dur[1] = 5.
    compiler.set_inputs(&[], data.as_mut_slice(), 1);
    compiler.rhs(
        T::zero(),
        u0.as_slice(),
        data.as_mut_slice(),
        rr1.as_mut_slice(),
    );
    compiler.calc_stop(
        T::zero(),
        u0.as_slice(),
        data.as_mut_slice(),
        stop1.as_mut_slice(),
    );
    assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap());
    assert_relative_eq!(u0[1], T::from_f64(0.0).unwrap());
    // F = (amp[N % 2] - x, 1) at x = 1; stop = dur[N % 2] - tclock at tclock = 0.
    assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap());
    assert_relative_eq!(rr0[1], T::from_f64(1.0).unwrap());
    assert_relative_eq!(stop0[0], T::from_f64(10.0).unwrap());
    assert_relative_eq!(rr1[0], T::from_f64(9.0).unwrap());
    assert_relative_eq!(rr1[1], T::from_f64(1.0).unwrap());
    assert_relative_eq!(stop1[0], T::from_f64(5.0).unwrap());
    // The two model indices must genuinely produce different values.
    assert_ne!(rr0[0], rr1[0]);
    assert_ne!(stop0[0], stop1[0]);
}
#[allow(dead_code)]
fn test_model_index_n_dynamic_index_grad<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
>() {
    // F_i selects a different state via the model index N: model 0 reads y,
    // model 1 reads z. Forward-mode gradients must follow the same selection.
    let full_text = "
        u_i { y = 3, z = 5 }
        F_i { u_i[N % 2], 0 }
        out_i { y, z }
    ";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    let mut state = vec![T::zero(); 2];
    let mut rhs_m0 = vec![T::zero(); 2];
    let mut rhs_m1 = vec![T::zero(); 2];
    let mut drhs_m0 = vec![T::zero(); 2];
    let mut drhs_m1 = vec![T::zero(); 2];
    let mut data = compiler.get_new_data();
    let mut ddata_m0 = compiler.get_new_data();
    let mut ddata_m1 = compiler.get_new_data();
    // Primal evaluation under model index 0: F[0] = y = 3.
    compiler.set_inputs(&[], data.as_mut_slice(), 0);
    compiler.set_u0(state.as_mut_slice(), data.as_mut_slice());
    compiler.rhs(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        rhs_m0.as_mut_slice(),
    );
    // Primal evaluation under model index 1: F[0] = z = 5.
    compiler.set_inputs(&[], data.as_mut_slice(), 1);
    compiler.rhs(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        rhs_m1.as_mut_slice(),
    );
    assert_relative_eq!(rhs_m0[0], T::from_f64(3.0).unwrap());
    assert_relative_eq!(rhs_m1[0], T::from_f64(5.0).unwrap());
    // Tangent along y for model 0 and along z for model 1: both should yield
    // dF[0] = 1, since each model reads exactly that state component.
    let dy_only = vec![T::one(), T::zero()];
    let dz_only = vec![T::zero(), T::one()];
    compiler.set_inputs(&[], data.as_mut_slice(), 0);
    compiler.rhs_grad(
        T::zero(),
        state.as_slice(),
        dy_only.as_slice(),
        data.as_slice(),
        ddata_m0.as_mut_slice(),
        rhs_m0.as_slice(),
        drhs_m0.as_mut_slice(),
    );
    compiler.set_inputs(&[], data.as_mut_slice(), 1);
    compiler.rhs_grad(
        T::zero(),
        state.as_slice(),
        dz_only.as_slice(),
        data.as_slice(),
        ddata_m1.as_mut_slice(),
        rhs_m1.as_slice(),
        drhs_m1.as_mut_slice(),
    );
    assert_relative_eq!(drhs_m0[0], T::one());
    assert_relative_eq!(drhs_m0[1], T::zero());
    assert_relative_eq!(drhs_m1[0], T::one());
    assert_relative_eq!(drhs_m1[1], T::zero());
}
#[allow(dead_code)]
fn test_model_index_n_dynamic_range_width_const<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
>() {
    // A constant-width range index N:N+2 slides with the model index:
    // model 0 sees amp[0..2] = (0, 10); model 1 sees amp[1..3] = (10, 20).
    let full_text = "
        amp_i { 0, 10, 20 }
        u_i { x = 1, y = 1 }
        F_i { amp_i[N:N+2] - u_i }
        out_i { x, y }
    ";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    let mut data = compiler.get_new_data();
    let mut state = vec![T::zero(); 2];
    let mut rhs_m0 = vec![T::zero(); 2];
    let mut rhs_m1 = vec![T::zero(); 2];
    compiler.set_inputs(&[], data.as_mut_slice(), 0);
    compiler.set_u0(state.as_mut_slice(), data.as_mut_slice());
    compiler.rhs(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        rhs_m0.as_mut_slice(),
    );
    compiler.set_inputs(&[], data.as_mut_slice(), 1);
    compiler.rhs(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        rhs_m1.as_mut_slice(),
    );
    // F = amp[N..N+2] - (1, 1).
    assert_relative_eq!(rhs_m0[0], T::from_f64(-1.0).unwrap());
    assert_relative_eq!(rhs_m0[1], T::from_f64(9.0).unwrap());
    assert_relative_eq!(rhs_m1[0], T::from_f64(9.0).unwrap());
    assert_relative_eq!(rhs_m1[1], T::from_f64(19.0).unwrap());
}
#[allow(dead_code)]
fn test_out_depends_on_internal_tensor<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
>() {
    // `out` and `stop` both read the intermediate tensor twoy = 2*y, so
    // calc_out/calc_stop must re-evaluate it from the current state vector.
    let full_text = "
        u_i { y = 1 }
        twoy_i { 2 * y }
        F_i { y * (1 - y), }
        out_i { twoy_i }
        stop_i { twoy_i - 0.5 }
    ";
    let parsed = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("$name", &parsed).unwrap();
    let compiler = Compiler::<M, T>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    let mut state = vec![T::one()];
    let mut data = compiler.get_new_data();
    // need this to set the constants
    compiler.set_u0(state.as_mut_slice(), data.as_mut_slice());
    let mut out = vec![T::zero()];
    compiler.calc_out(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        out.as_mut_slice(),
    );
    // out = twoy = 2*y at y = 1.
    assert_relative_eq!(out[0], T::from_f64(2.).unwrap());
    // Change the state and confirm out tracks the intermediate tensor.
    state[0] = T::from_f64(2.).unwrap();
    compiler.calc_out(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        out.as_mut_slice(),
    );
    assert_relative_eq!(out[0], T::from_f64(4.).unwrap());
    let mut stop = vec![T::zero()];
    compiler.calc_stop(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        stop.as_mut_slice(),
    );
    // stop = 2*y - 0.5 at y = 2.
    assert_relative_eq!(stop[0], T::from_f64(3.5).unwrap());
    state[0] = T::from_f64(0.5).unwrap();
    compiler.calc_stop(
        T::zero(),
        state.as_slice(),
        data.as_mut_slice(),
        stop.as_mut_slice(),
    );
    assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap());
}
#[cfg(all(feature = "cranelift", not(target_arch = "wasm32")))]
#[test]
fn test_vector_add_scalar_cranelift() {
    // Builds a tiny model `F_i { u_i + 1.0 }` with n generated states and
    // checks that it compiles under the cranelift backend; successful
    // compilation is the assertion.
    let n = 1;
    let u = vec![1.0; n];
    let full_text = format!(
        "
        u_i {{
            {}
        }}
        F_i {{
            u_i + 1.0,
        }}
        out_i {{
            u_i
        }}
        ",
        (0..n)
            .map(|i| format!("x{} = {},", i, u[i]))
            .collect::<Vec<_>>()
            .join("\n"),
    );
    let model = parse_ds_string(&full_text).unwrap();
    let name = "$name";
    let discrete_model = DiscreteModel::build(name, &model).unwrap();
    // BUGFIX: try_init() returns Err when a global logger is already set
    // (e.g. another test in this process initialised one first), so the old
    // `.unwrap()` made this test order-dependent and flaky. A failed init is
    // harmless here — just discard the Result.
    let _ = env_logger::builder().is_test(true).try_init();
    let _compiler = Compiler::<crate::CraneliftJitModule, f64>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(&full_text),
    )
    .unwrap();
}
#[allow(dead_code)]
fn tensor_test_common<
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq + ToPrimitive,
>(
    discrete_model: &DiscreteModel,
    tensor_name: &str,
    mode: CompilerMode,
) -> Vec<Vec<f64>> {
    // Compile the model with the requested threading mode, run the shared
    // tensor-evaluation harness, and widen every result row to f64 so that
    // callers can compare against f64 expectations regardless of T.
    let compiler = Compiler::<M, T>::from_discrete_model(
        discrete_model,
        CompilerOptions { mode, debug: false },
        None,
    )
    .unwrap();
    let rows = tensor_test_common_impl(compiler, tensor_name);
    let mut widened: Vec<Vec<f64>> = Vec::with_capacity(rows.len());
    for row in rows {
        let converted = row
            .into_iter()
            .map(|x: T| x.to_f64().unwrap())
            .collect::<Vec<f64>>();
        widened.push(converted);
    }
    widened
}
#[allow(dead_code)]
fn tensor_test_common_impl<M: CodegenModule, T: Scalar + RelativeEq>(
    compiler: Compiler<M, T>,
    tensor_name: &str,
) -> Vec<Vec<T>> {
    // Evaluates `tensor_name` and all of its derivative variants with every
    // state and input set to one. Row order of the returned Vec:
    //   [0] tensor value after set_inputs / set_u0 / rhs / calc_out
    //   [1] forward-mode tangent of the tensor (all input tangents = 1)
    //   [2] reverse-mode gradient wrt the inputs (tensor cotangents = 1)
    //   [3] forward input-sensitivity of the tensor via rhs_sgrad
    //   [4] forward input-sensitivity of the tensor via calc_out_sgrad
    //   [5] reverse input-sensitivity via rhs_srgrad + set_inputs_rgrad
    //   [6] reverse input-sensitivity via calc_out_srgrad + set_inputs_rgrad
    // Rows 2..=6 are zero-filled placeholders when the backend has no reverse
    // autodiff or the tensor is a compile-time constant.
    let (n_states, n_inputs, n_outputs, _n_data, _n_stop, _has_mass, _has_reset) =
        compiler.get_dims();
    let mut u0 = vec![T::one(); n_states];
    let mut res = vec![T::zero(); n_states];
    let mut data = compiler.get_new_data();
    let mut results = Vec::new();
    let inputs = vec![T::one(); n_inputs];
    let mut out = vec![T::zero(); n_outputs];
    // Primal pass: populate `data` so the tensor has a defined value.
    compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0);
    compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
    compiler.rhs(
        T::zero(),
        u0.as_slice(),
        data.as_mut_slice(),
        res.as_mut_slice(),
    );
    compiler.calc_out(
        T::zero(),
        u0.as_slice(),
        data.as_mut_slice(),
        out.as_mut_slice(),
    );
    // The tensor lives either in the runtime `data` buffer or in the constants
    // segment; remember which, since constants carry no gradients.
    let (tensor_len, tensor_is_constant) =
        if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, data.as_slice()) {
            results.push(tensor_data.to_vec());
            (tensor_data.len(), false)
        } else if let Some(tensor_data) = compiler.get_constants_data(tensor_name) {
            results.push(tensor_data.to_vec());
            (tensor_data.len(), true)
        } else {
            panic!(
                "{} is not a valid tensor name, tensors are {:?} and constants are {:?}",
                tensor_name,
                compiler.get_tensors(),
                compiler.get_constants()
            );
        };
    // forward mode: seed every input tangent with one and propagate through
    // set_u0_grad / rhs_grad / calc_out_grad so ddata holds the tensor tangent.
    let mut dinputs = vec![T::zero(); n_inputs];
    dinputs.fill(T::one());
    let mut ddata = compiler.get_new_data();
    let mut du0 = vec![T::zero(); n_states];
    let mut dres = vec![T::zero(); n_states];
    let mut dout = vec![T::zero(); n_outputs];
    compiler.set_inputs_grad(
        inputs.as_slice(),
        dinputs.as_slice(),
        data.as_mut_slice(),
        ddata.as_mut_slice(),
        0,
    );
    compiler.set_u0_grad(
        u0.as_mut_slice(),
        du0.as_mut_slice(),
        data.as_mut_slice(),
        ddata.as_mut_slice(),
    );
    compiler.rhs_grad(
        T::zero(),
        u0.as_slice(),
        du0.as_slice(),
        data.as_mut_slice(),
        ddata.as_mut_slice(),
        res.as_mut_slice(),
        dres.as_mut_slice(),
    );
    compiler.calc_out_grad(
        T::zero(),
        u0.as_slice(),
        du0.as_slice(),
        data.as_mut_slice(),
        ddata.as_mut_slice(),
        out.as_slice(),
        dout.as_mut_slice(),
    );
    // Constant tensors have no slot in ddata; report a zero tangent instead.
    if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, ddata.as_slice()) {
        results.push(tensor_data.to_vec());
    } else {
        results.push(vec![T::zero(); tensor_len]);
    }
    // reverse-mode
    if compiler.supports_reverse_autodiff() && !tensor_is_constant {
        // Seed the tensor's cotangents with one, then run the reverse sweep in
        // the opposite order of the primal: calc_out, rhs, set_u0, inputs.
        let mut ddata = compiler.get_new_data();
        let dtensor = compiler
            .get_tensor_data_mut(tensor_name, ddata.as_mut_slice())
            .unwrap();
        dtensor.fill(T::one());
        let mut du0 = vec![T::zero(); n_states];
        let mut dres = vec![T::zero(); n_states];
        let mut dinputs = vec![T::zero(); n_inputs];
        let mut dout = vec![T::zero(); n_outputs];
        // reverse pass (already done the forward pass)
        compiler.calc_out_rgrad(
            T::zero(),
            u0.as_slice(),
            du0.as_mut_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            out.as_slice(),
            dout.as_mut_slice(),
        );
        compiler.rhs_rgrad(
            T::zero(),
            u0.as_slice(),
            du0.as_mut_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            res.as_slice(),
            dres.as_mut_slice(),
        );
        compiler.set_u0_rgrad(
            u0.as_mut_slice(),
            du0.as_mut_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
        );
        // The accumulated input gradient now sits in ddata's input slots.
        compiler.get_inputs(dinputs.as_mut_slice(), ddata.as_slice());
        results.push(dinputs.to_vec());
        // forward mode sens (rhs)
        let mut ddata = compiler.get_new_data();
        let mut dres = vec![T::zero(); n_states];
        let dinputs = vec![T::one(); n_inputs];
        compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0);
        compiler.rhs_sgrad(
            T::zero(),
            u0.as_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            res.as_slice(),
            dres.as_mut_slice(),
        );
        results.push(
            compiler
                .get_tensor_data(tensor_name, ddata.as_slice())
                .unwrap()
                .to_vec(),
        );
        // forward mode sens (calc_out)
        let mut ddata = compiler.get_new_data();
        let dinputs = vec![T::one(); n_inputs];
        compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0);
        compiler.calc_out_sgrad(
            T::zero(),
            u0.as_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            out.as_slice(),
            dout.as_mut_slice(),
        );
        results.push(
            compiler
                .get_tensor_data(tensor_name, ddata.as_slice())
                .unwrap()
                .to_vec(),
        );
        // reverse mode sens (rhs)
        let mut ddata = compiler.get_new_data();
        let dtensor = compiler
            .get_tensor_data_mut(tensor_name, ddata.as_mut_slice())
            .unwrap();
        dtensor.fill(T::one());
        let mut dres = vec![T::zero(); n_states];
        let mut dinputs = vec![T::zero(); n_inputs];
        compiler.rhs_srgrad(
            T::zero(),
            u0.as_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            res.as_slice(),
            dres.as_mut_slice(),
        );
        compiler.set_inputs_rgrad(
            inputs.as_slice(),
            dinputs.as_mut_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            0,
        );
        results.push(dinputs.to_vec());
        // reverse mode sens (calc_out)
        let mut ddata = compiler.get_new_data();
        let dtensor = compiler
            .get_tensor_data_mut(tensor_name, ddata.as_mut_slice())
            .unwrap();
        dtensor.fill(T::one());
        let mut dinputs = vec![T::zero(); n_inputs];
        compiler.calc_out_srgrad(
            T::zero(),
            u0.as_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            out.as_slice(),
            dout.as_mut_slice(),
        );
        compiler.set_inputs_rgrad(
            inputs.as_slice(),
            dinputs.as_mut_slice(),
            data.as_slice(),
            ddata.as_mut_slice(),
            0,
        );
        results.push(dinputs.to_vec());
    } else {
        // Keep the row layout stable for callers: zero placeholders whose
        // lengths match the rows the reverse/sensitivity path would produce.
        results.push(vec![T::zero(); n_inputs]);
        results.push(vec![T::zero(); tensor_len]);
        results.push(vec![T::zero(); tensor_len]);
        results.push(vec![T::zero(); n_inputs]);
        results.push(vec![T::zero(); n_inputs]);
    }
    results
}
// Expands each `name: text expect tensor expected,` case into both an f64 and
// an f32 test (names suffixed `_f64` / `_f32`) by delegating to
// `tensor_test_typed!`; `paste!` concatenates the test-fn identifiers.
macro_rules! tensor_test {
    ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr,)*) => {
        paste! {
            tensor_test_typed! {
                $([<$name _f64>] : $text expect $tensor_name $expected_value ; f64,)*
                $([<$name _f32>] : $text expect $tensor_name $expected_value ; f32,)*
            }
        }
    }
}
// Generates one #[test] fn per case for the given scalar type: wraps $text in
// a minimal model (state y with F_i { y }), builds it, and asserts that the
// named tensor's primal value (row 0 of tensor_test_common) matches
// $expected_value on every enabled backend, single- and multi-threaded.
macro_rules! tensor_test_typed {
    ($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr ; $scalar_type:ty ,)*) => {
        $(
            #[test]
            fn $name() {
                let full_text = format!("
                    {}
                    u_i {{
                        y = 1,
                    }}
                    F_i {{
                        y,
                    }}
                    out_i {{
                        y,
                    }}
                    ", $text);
                let model = parse_ds_string(full_text.as_str()).unwrap();
                #[allow(unused_variables)]
                let discrete_model = match DiscreteModel::build("$name", &model) {
                    Ok(model) => model,
                    Err(e) => {
                        // Surface DSL errors with source context instead of a
                        // bare Debug dump.
                        panic!("{}", e.as_error_message(full_text.as_str()));
                    }
                };
                #[cfg(feature = "llvm")]
                {
                    use crate::execution::llvm::codegen::LlvmModule;
                    let results = tensor_test_common::<LlvmModule, $scalar_type>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
                    assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
                }
                #[cfg(feature = "cranelift")]
                {
                    let results = tensor_test_common::<crate::CraneliftJitModule, $scalar_type>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
                    assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
                }
                #[cfg(not(feature = "cranelift"))]
                {
                    // No cranelift backend: still parse and build the model so
                    // a malformed case fails even when nothing can be JITed.
                    let model = parse_ds_string(full_text.as_str()).unwrap();
                    match DiscreteModel::build("$name", &model) {
                        Ok(model) => model,
                        Err(e) => {
                            panic!("{}", e.as_error_message(full_text.as_str()));
                        }
                    };
                }
                #[cfg(feature = "rayon")]
                {
                    // Re-run the same checks in multi-threaded mode.
                    #[cfg(feature = "cranelift")]
                    {
                        let results = tensor_test_common::<crate::CraneliftJitModule, $scalar_type>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
                        assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
                    }
                    #[cfg(feature = "llvm")]
                    {
                        use crate::execution::llvm::codegen::LlvmModule;
                        let results = tensor_test_common::<LlvmModule, $scalar_type>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
                        assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
                    }
                }
            }
        )*
    }
}
// NOTE(review): these slice/range-indexing cases are gated off under the
// inkwell-181/191 LLVM bindings — presumably they misbehave on those LLVM
// versions; confirm before removing the cfg.
#[cfg(not(any(feature = "inkwell-191", feature = "inkwell-181")))]
tensor_test! {
    indexing2: "a_i { 0.0, 1.0, 2.0, 3.0 } r_i { a_i[1:3] }" expect "r" vec![1.0, 2.0],
    indexing3: "a_i { 0.0, 1.0, 2.0, 3.0 } r_i { a_i[1..3] }" expect "r" vec![1.0, 2.0],
    indexing_mat_mul: "a_ij { (0:2, 0:2): 1.0 } b_j { 1, 2, 3 } r_i { a_ij * b_j[1:3] }" expect "r" vec![5.0, 5.0],
}
// Per-scalar-type (f32/f64) tests for the maths intrinsics and some large
// pybamm-derived expressions. Each entry is
//   `name: model-text expect tensor-name expected-values ; scalar-type`;
// expected values are computed with the matching Rust std function, with f32
// results widened for comparison via `.into()`.
tensor_test_typed! {
arccosh_function_f64: "r { arccosh(2) }" expect "r" vec![f64::acosh(2.0)] ; f64,
// renamed from `arcosh_function_f32` to match the f64 variant and the
// `arccosh` intrinsic under test
arccosh_function_f32: "r { arccosh(2) }" expect "r" vec![f32::acosh(2.0_f32).into()] ; f32,
exp_function_f64: "r { exp(2) }" expect "r" vec![f64::exp(2.0)] ; f64,
exp_function_f32: "r { exp(2) }" expect "r" vec![f32::exp(2.0_f32).into()] ; f32,
pow_function_f64: "r { pow(4.3245, 0.5) }" expect "r" vec![f64::powf(4.3245, 0.5)] ; f64,
pow_function_f32: "r { pow(4.3245, 0.5) }" expect "r" vec![f32::powf(4.3245_f32, 0.5).into()] ; f32,
tan_function_f64: "r { tan(0.234) }" expect "r" vec![f64::tan(0.234)] ; f64,
arcsinh_function_f64: "r { arcsinh(0.5) }" expect "r" vec![f64::asinh(0.5)] ; f64,
arcsinh_function_f32: "r { arcsinh(0.5) }" expect "r" vec![f32::asinh(0.5_f32).into()] ; f32,
tanh_function_f64: "r { tanh(0.5) }" expect "r" vec![f64::tanh(0.5)] ; f64,
// todo: why does this fail?
//tanh_function_f32: "r { tanh(0.5) }" expect "r" vec![f32::tanh(0.5_f32).into()] ; f32,
sinh_function_f64: "r { sinh(0.5) }" expect "r" vec![f64::sinh(0.5)] ; f64,
sinh_function_f32: "r { sinh(0.5) }" expect "r" vec![f32::sinh(0.5_f32).into()] ; f32,
cosh_function_f64: "r { cosh(0.5) }" expect "r" vec![f64::cosh(0.5)] ; f64,
cosh_function_f32: "r { cosh(0.5) }" expect "r" vec![f32::cosh(0.5_f32).into()] ; f32,
// `t` is evaluated at time 0 here, so exp(t) == exp(0)
exp_function_time: "r { exp(t) }" expect "r" vec![f64::exp(0.0)] ; f64,
exp_function_time_f32: "r { exp(t) }" expect "r" vec![f32::exp(0.0_f32).into()] ; f32,
sigmoid_function_f64: "r { sigmoid(0.1) }" expect "r" vec![1.0 / (1.0 + f64::exp(-0.1))] ; f64,
sigmoid_function_f32: "r { sigmoid(0.1) }" expect "r" vec![ (1.0_f32 / (1.0_f32 + f32::exp(-0.1_f32))).into()] ; f32,
expression: "r_i {2 + 3, 3 * 2, arcsinh(1.2 + 1.0 / max(1.2, 1.0) * 2.0 + tanh(2.0))}" expect "r" vec![5., 6., f64::asinh(1.2 + 1.0 / f64::max(1.2, 1.0) * 2.0 + f64::tanh(2.0))] ; f64,
// Large expression extracted from a pybamm battery model: exercises deeply
// nested arithmetic with arcsinh/tanh/exp/pow and min/max clamping, fed by
// sparse matrix-vector products over the concentration vectors.
pybamm_expression: "
constant0_i { (0:19): 0.0, (19:20): 0.0006810238128045524,}
constant1_i { (0:19): 0.0, (19:20): -0.0011634665332403958,}
constant2_ij { (0,18): -25608.96286546366, (0,19): 76826.88859639116,}
constant3_ij {(0,18): -0.4999999999999983, (0,19): 1.4999999999999984,}
constant4_ij {(0,18): -0.4999999999999983, (0,19): 1.4999999999999982,}
constant7_ij { (0,18): -12491.630996921805, (0,19): 37474.892990765504,}
xaveragednegativeparticleconcentrationmolm3_i { 0.245049, 0.244694, 0.243985, 0.242921, 0.241503, 0.239730, 0.237603, 0.235121, 0.232284, 0.229093, 0.225547, 0.221647, 0.217392, 0.212783, 0.207819, 0.202500, 0.196827, 0.190799, 0.184417, 0.177680, }
xaveragedpositiveparticleconcentrationmolm3_i { 0.939986, 0.940066, 0.940228, 0.940471, 0.940795, 0.941200, 0.941685, 0.942252, 0.942899, 0.943628, 0.944437, 0.945328, 0.946299, 0.947351, 0.948485, 0.949699, 0.950994, 0.952370, 0.953827, 0.955365, }
varying2_i {(constant2_ij * xaveragedpositiveparticleconcentrationmolm3_j),}
varying3_i {(constant4_ij * xaveragedpositiveparticleconcentrationmolm3_j),}
varying4_i {(constant7_ij * xaveragednegativeparticleconcentrationmolm3_j),}
varying5_i {(constant3_ij * xaveragednegativeparticleconcentrationmolm3_j),}
r_i {(((0.05138515824298745 * arcsinh((-0.7999999999999998 / ((1.8973665961010275e-05 * pow(max(min(varying2_i, 51217.92521874824), 0.000512179257309275), 0.5)) * pow((51217.9257309275 - max(min(varying2_i, 51217.92521874824), 0.000512179257309275)), 0.5))))) + (((((((2.16216 + (0.07645 * tanh((30.834 - (57.858397200000006 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (2.1581 * tanh((52.294 - (53.412228 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.14169 * tanh((11.0923 - (21.0852666 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2051 * tanh((1.4684 - (5.829105600000001 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2531 * tanh((4.291641337386018 - (8.069908814589667 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.02167 * tanh((-87.5 + (177.0 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying3_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying3_i, 0.9999999999), 1e-10))))))) - ((0.05138515824298745 * arcsinh((0.6666666666666666 / ((0.0006324555320336759 * pow(max(min(varying4_i, 24983.261744011077), 0.000249832619938437), 0.5)) * pow((24983.2619938437 - max(min(varying4_i, 24983.261744011077), 0.000249832619938437)), 0.5))))) + ((((((((((0.194 + (1.5 * exp((-120.0 * max(min(varying5_i, 0.9999999999), 1e-10))))) + (0.0351 * tanh((-3.44578313253012 + (12.048192771084336 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0045 * tanh((-7.1344537815126055 + (8.403361344537815 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.035 * tanh((-18.466 + (20.0 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0147 * tanh((-14.705882352941176 + (29.41176470588235 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.102 * tanh((-1.3661971830985917 + (7.042253521126761 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.022 * tanh((-54.8780487804878 + (60.975609756097555 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.011 * tanh((-5.486725663716814 + (44.24778761061947 * 
max(min(varying5_i, 0.9999999999), 1e-10)))))) + (0.0155 * tanh((-3.6206896551724133 + (34.48275862068965 * max(min(varying5_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying5_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying5_i, 0.9999999999), 1e-10)))))))),}
" expect "r" vec![3.191533267340602] ; f64,
// Smaller slices of the pybamm expression above, checked in isolation so a
// failure localises to one sub-term.
pybamm_subexpression: "
constant2_ij { (0,18): -25608.96286546366, (0,19): 76826.88859639116,}
st_i { (0:20): xaveragednegativeparticleconcentrationmolm3 = 0.8000000000000016, (20:40): xaveragedpositiveparticleconcentrationmolm3 = 0.6000000000000001, }
varying2_i {(constant2_ij * xaveragedpositiveparticleconcentrationmolm3_j),}
" expect "varying2" vec![-25608.96286546366 * 0.6000000000000001 + 76826.88859639116 * 0.6000000000000001] ; f64,
pybamm_subexpression2: "
constant4_ij {(0,18): -0.4999999999999983, (0,19): 1.4999999999999982,}
st_i { (0:20): xaveragednegativeparticleconcentrationmolm3 = 0.8000000000000016, (20:40): xaveragedpositiveparticleconcentrationmolm3 = 0.6000000000000001, }
varying3_i {(constant4_ij * xaveragedpositiveparticleconcentrationmolm3_j),}
" expect "varying3" vec![-0.4999999999999983 * 0.6000000000000001 + 1.4999999999999982 * 0.6000000000000001] ; f64,
pybamm_subexpression3: "
constant7_ij { (0,18): -12491.630996921805, (0,19): 37474.892990765504,}
st_i { (0:20): xaveragednegativeparticleconcentrationmolm3 = 0.8000000000000016, (20:40): xaveragedpositiveparticleconcentrationmolm3 = 0.6000000000000001, }
varying4_i {(constant7_ij * xaveragednegativeparticleconcentrationmolm3_j),}
" expect "varying4" vec![-12491.630996921805 * 0.8000000000000016 + 37474.892990765504 * 0.8000000000000016] ; f64,
// Same full expression as `pybamm_expression` but with the varying* inputs
// pinned to constants, and the expected value spelt out with f64 std calls.
pybamm_subexpression4: "
varying2_i {30730.7554386,}
varying3_i {0.6,}
varying4_i {19986.6095951,}
varying5_i {0.8,}
r_i {(((0.05138515824298745 * arcsinh((-0.7999999999999998 / ((1.8973665961010275e-05 * pow(max(min(varying2_i, 51217.92521874824), 0.000512179257309275), 0.5)) * pow((51217.9257309275 - max(min(varying2_i, 51217.92521874824), 0.000512179257309275)), 0.5))))) + (((((((2.16216 + (0.07645 * tanh((30.834 - (57.858397200000006 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (2.1581 * tanh((52.294 - (53.412228 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.14169 * tanh((11.0923 - (21.0852666 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2051 * tanh((1.4684 - (5.829105600000001 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (0.2531 * tanh((4.291641337386018 - (8.069908814589667 * max(min(varying3_i, 0.9999999999), 1e-10)))))) - (0.02167 * tanh((-87.5 + (177.0 * max(min(varying3_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying3_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying3_i, 0.9999999999), 1e-10))))))) - ((0.05138515824298745 * arcsinh((0.6666666666666666 / ((0.0006324555320336759 * pow(max(min(varying4_i, 24983.261744011077), 0.000249832619938437), 0.5)) * pow((24983.2619938437 - max(min(varying4_i, 24983.261744011077), 0.000249832619938437)), 0.5))))) + ((((((((((0.194 + (1.5 * exp((-120.0 * max(min(varying5_i, 0.9999999999), 1e-10))))) + (0.0351 * tanh((-3.44578313253012 + (12.048192771084336 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0045 * tanh((-7.1344537815126055 + (8.403361344537815 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.035 * tanh((-18.466 + (20.0 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.0147 * tanh((-14.705882352941176 + (29.41176470588235 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.102 * tanh((-1.3661971830985917 + (7.042253521126761 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.022 * tanh((-54.8780487804878 + (60.975609756097555 * max(min(varying5_i, 0.9999999999), 1e-10)))))) - (0.011 * tanh((-5.486725663716814 + (44.24778761061947 * 
max(min(varying5_i, 0.9999999999), 1e-10)))))) + (0.0155 * tanh((-3.6206896551724133 + (34.48275862068965 * max(min(varying5_i, 0.9999999999), 1e-10)))))) + (1e-06 * ((1.0 / max(min(varying5_i, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(varying5_i, 0.9999999999), 1e-10)))))))),}
" expect "r" vec![(((0.05138515824298745 * f64::asinh(-0.7999999999999998 / ((1.897_366_596_101_027_5e-5 * f64::powf(f64::max(f64::min(30730.7554386, 51217.92521874824), 0.000512179257309275), 0.5)) * f64::powf(51217.9257309275 - f64::max(f64::min(30730.7554386, 51217.92521874824), 0.000512179257309275), 0.5)))) + (((((((2.16216 + (0.07645 * f64::tanh(30.834 - (57.858397200000006 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (2.1581 * f64::tanh(52.294 - (53.412228 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) - (0.14169 * f64::tanh(11.0923 - (21.0852666 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (0.2051 * f64::tanh(1.4684 - (5.829105600000001 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (0.2531 * f64::tanh(4.291641337386018 - (8.069908814589667 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) - (0.02167 * f64::tanh(-87.5 + (177.0 * f64::max(f64::min(0.6, 0.9999999999), 1e-10))))) + (1e-06 * ((1.0 / f64::max(f64::min(0.6, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + f64::max(f64::min(0.6, 0.9999999999), 1e-10))))))) - ((0.05138515824298745 * f64::asinh(0.6666666666666666 / ((0.0006324555320336759 * f64::powf(f64::max(f64::min(19986.6095951, 24983.261744011077), 0.000249832619938437), 0.5)) * f64::powf(24983.2619938437 - f64::max(f64::min(19986.6095951, 24983.261744011077), 0.000249832619938437), 0.5)))) + ((((((((((0.194 + (1.5 * f64::exp(-120.0 * f64::max(f64::min(0.8, 0.9999999999), 1e-10)))) + (0.0351 * f64::tanh(-3.44578313253012 + (12.048192771084336 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.0045 * f64::tanh(-7.1344537815126055 + (8.403361344537815 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.035 * f64::tanh(-18.466 + (20.0 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.0147 * f64::tanh(-14.705882352941176 + (29.41176470588235 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.102 * f64::tanh(-1.3661971830985917 + (7.042253521126761 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.022 * 
f64::tanh(-54.8780487804878 + (60.975609756097555 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) - (0.011 * f64::tanh(-5.486725663716814 + (44.24778761061947 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) + (0.0155 * f64::tanh(-3.6206896551724133 + (34.48275862068965 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))))) + (1e-06 * ((1.0 / f64::max(f64::min(0.8, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + f64::max(f64::min(0.8, 0.9999999999), 1e-10))))))))] ; f64,
pybamm_subexpression5: "r_i { (1.0 / max(min(0.6, 0.9999999999), 1e-10)),}" expect "r" vec![1.0 / f64::max(f64::min(0.6, 0.9999999999), 1e-10)] ; f64,
pybamm_subexpression6: "r_i { arcsinh(1.8973665961010275e-05), }" expect "r" vec![f64::asinh(1.897_366_596_101_027_5e-5)] ; f64,
pybamm_subexpression7: "r_i { (1.5 * exp(-120.0 * max(min(0.8, 0.9999999999), 1e-10))), }" expect "r" vec![1.5 * f64::exp(-120.0 * f64::max(f64::min(0.8, 0.9999999999), 1e-10))] ; f64,
pybamm_subexpression8: "r_i { (0.07645 * tanh(30.834 - (57.858397200000006 * max(min(0.6, 0.9999999999), 1e-10)))), }" expect "r" vec![0.07645 * f64::tanh(30.834 - (57.858397200000006 * f64::max(f64::min(0.6, 0.9999999999), 1e-10)))] ; f64,
pybamm_subexpression9: "r_i { (1e-06 * ((1.0 / max(min(0.8, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + max(min(0.8, 0.9999999999), 1e-10))))), }" expect "r" vec![1e-06 * ((1.0 / f64::max(f64::min(0.8, 0.9999999999), 1e-10)) + (1.0 / (-1.0 + f64::max(f64::min(0.8, 0.9999999999), 1e-10))))] ; f64,
pybamm_subexpression10: "r_i { (1.0 / (-1.0 + max(min(0.8, 0.9999999999), 1e-10))), }" expect "r" vec![1.0 / (-1.0 + f64::max(f64::min(0.8, 0.9999999999), 1e-10))] ; f64,
unary_negate_in_expr: "r_i { 1.0 / (-1.0 + 1.1) }" expect "r" vec![1.0 / (-1.0 + 1.1)] ; f64,
// applying a unary function to a sparse vector: the unstored zero entries
// still contribute f(0) to the expected dense result
exp_sparse_vec: "a_i { (1): 1 } r_i { exp(a_i) }" expect "r" vec![f64::exp(0.0), f64::exp(1.0)] ; f64,
log_sparse_vec: "a_i { (1): 1 } r_i { log(a_i + 1) }" expect "r" vec![f64::ln(1.0), f64::ln(2.0)] ; f64,
}
// Tensor algebra tests covering sparse, dense and diagonal layouts:
// broadcasting, concatenation, matrix-vector products, elementwise
// add/mul/div, contraction and indexing. Each entry checks the named tensor
// against its expected values; for sparse results the expected vector lists
// only the stored (structurally non-zero) entries, in storage order.
tensor_test! {
scalar_vec_index_bug: "a_i { b = 1 } c_i { (0:2): 0, (2:3): 0.5 } d_i { (0:1): 2 } r_i { c_i * (d * b), c_i * (d_i * b) }" expect "r" vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0],
sparse_mat_mul12: "A_ij { (0,0):1, (0,1):1, (1,1):1, (2,2):1, (3,3):1, (4,4):1, (5,5):1 } b2_i { (1:6): 1 } r_i { A_ij * b2_j }" expect "r" vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
sparse_mat_mul8: "A_ij { (0,1): -0.5, (0,0): 1.5, (6,5): 1.5, (6,4): -0.5 } b_i { (0:6): 1 } r_i { A_ij * b_j } A1_ij { (0,1): 1, (0,2): 1, (1,1): 1, (2,2): 1, (3,3): 1, (4,4): 1, (5,5): 1, (6,6): 1 } b2_i { (0:7): 1 } b3_i { (0:7):2 } r2_i { A1_ij * (r_j + b2_j) * b3_j }" expect "r2" vec![4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0],
sparse_dense_concat: "a_i { (0): 1, (2): 3 } b_i { 4, 5 } r_i { a_i, b_i }" expect "r" vec![1., 3., 4., 5.],
sparse_mat_mul9: "A_ij { (0,1): -0.5, (0,0): 1.5, (1,1): 0, (2,2): 0, (3,3): 0, (4,4): 0, (5,5): 0, (6,5): 1.5, (6,4): -0.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 1.0],
sparse_mat_mul10: "A_ij { (0,0): 1.5, (0,1): -0.5, (6,4): -0.5, (6,5): 1.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 1.0],
sparse_mat_mul11: "A_ij { (0,0): 1.5, (0,1): -0.5, (1,1): 0, (2,2): 0, (3,3): 0, (4,4): 0, (5,5): 0, (6,4): -0.5, (6,5): 1.5 } b_i { (0:6): 1 } r_i { A_ij * b_j }" expect "r" vec![1.0, 1.0],
sparse_nonsquare_mat_vec_mul: "A_ij { (0, 0): 1, (0, 1): 4, (1, 2): 2 } b_j { (0:3): 5 } r_i { A_ij * b_j }" expect "r" vec![25.0, 10.0],
sparse_nonsquare_mat_vec_mul2: "A_ij { (0, 0): 1, (0, 1): 4, (1, 2): 2 } b_j { (2): 5 } r_i { A_ij * b_j } B_ij { (0..2,0..2): 1 } s_i { B_ij * max(r_j, 1) }" expect "s" vec![1.0, 10.0],
max_sparse_vec: "A_ij { (0..2,0..2): 1 } b_j { (1): 5 } r_i { A_ij * max(b_j, 1) }" expect "r" vec![1.0, 5.0],
row_vec_col_vec_mul: "a_ij { (0, 1): 1, (0, 2): 2, (0, 3): 3 } b_i { 4, 5, 6, 7 } r_i { a_ij * b_j }" expect "r" vec![5. + 12. + 21.],
max_sparse_scalar: "a_i { (0): 1, (2): 3 } r_i { max(a_i, 2) }" expect "r" vec![2., 2., 3.],
contract_to_mat_vec: "A_ij { (0, 0): 1, (1, 0): 3, (1, 1): 4 } B_ij { (1, 1): 2 } b_i { B_ij } r_i { A_ij * b_j }" expect "r" vec![8.],
sparse_mat_vec_mul7: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 1 } b_i { (2): 5 } r_i { A_ij * (1 + b_j) }" expect "r" vec![4., 12., 6.],
sparse_mat_vec_mul6: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 1 } b_i { (2): 5 } r_i { A_ij * (1 * b_j) }" expect "r" vec![10., 5.],
sparse_mat_vec_mul3: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 0 } b_i { (2): 5 } c_j { (0:3): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![4., 12.],
sparse_mat_vec_mul5: "A_ij { (1, 1): 2 } b_j { (1): 3 } r_i { A_ij * (1 * b_j) }" expect "r" vec![6.0],
sparse_mat_vec_mul4: "A_ij { (0, 1): 4, (1, 2): 2, (2, 2): 0 } b_i { (0): 2, (2): 5 } c_j { (0:3): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![4., 12.],
sparse_mat_vec_mul2: "A_ij { (0, 1): 4, (1, 0): 2 } b_i { (1): 5 } c_j { (0:1): 1, (1:2): 1 } r_i { A_ij * (b_j + c_j) }" expect "r" vec![24.0, 2.0],
sparse_mat_vec_mul: "A_ij { (1, 1): 2 } b_j { (1): 3 } r_i { A_ij * b_j }" expect "r" vec![6.0],
// broadcasting a sparse vector into a sparse matrix block
sparse_broadcast_to_sparse: "A_i { (1): 2 } B_ij { (0:2, 0:2): A_i }" expect "B" vec![2.0],
sparse_broadcast_to_sparse_add: "A_i { (1): 2 } C_ij { (1, 1): 1 } B_ij { (0:2, 0:2): A_i + C_ij }" expect "B" vec![2.0, 3.0],
sparse_contract_to_sparse: "A_ij { (1, 1): 2 } B_i { A_ij }" expect "B" vec![2.0],
diag_sparse_add: "A_ij { (0..2, 0..2): 1 } B_ij { (1, 1): 3 } R_ij { A_ij + B_ij }" expect "R" vec![1.0, 4.0],
diag_sparse_add2: "A_ij { (0..2, 0..2): 1 } B_ij { (1, 1): 3 } R_ij { A_ij + B_ji }" expect "R" vec![1.0, 4.0],
diag_sparse_mul: "A_ij { (0..2, 0..2): 1 } B_ij { (1, 1): 3 } R_ij { A_ij * B_ij }" expect "R" vec![3.0],
diag_dense_add: "A_ij { (0..2, 0..2): 1 } B_ij { (0:2, 0:2): 2 } R_ij { A_ij + B_ij }" expect "R" vec![3.0, 2.0, 2.0, 3.0],
diag_dense_mul: "A_ij { (0..2, 0..2): 2 } B_ij { (0:2, 0:2): 3 } R_ij { A_ij * B_ij }" expect "R" vec![6.0, 6.0],
sparse_sparse_mat_add2: "A_ij { (0, 0): 1, (1, 0): 5, (1, 1): 2 } B_ij { (0, 1): 3, (1, 1): 4 } R_ij { A_ij + B_ij }" expect "R" vec![1.0, 3.0, 5.0, 6.0],
sparse_sparse_mat_add3: "A_ij { (0, 0): 1, (1, 0): 5, (1, 1): 2 } B_ij { (0, 1): 3, (1, 1): 4 } R_ij { A_ij + B_ji }" expect "R" vec![1.0, 8.0, 6.0],
sparse_sparse_mat_add: "A_ij { (1, 1): 2 } B_ij { (0, 1): 3, (1, 1): 4 } R_ij { A_ij + B_ij }" expect "R" vec![3.0, 6.0],
sparse_dense_mat_add: "A_ij { (1, 1): 2 } B_ij { (0:2, 0:2): 3 } R_ij { A_ij + B_ij }" expect "R" vec![3.0, 3.0, 3.0, 5.0],
sparse_dense_mat_mul: "A_ij { (1, 1): 2 } B_ij { (0:2, 0:2): 3 } R_ij { A_ij * B_ij }" expect "R" vec![6.0],
sparse_dense_mat_mul2: "A_ij { (1, 0): 1, (1, 1): 2 } B_ij { (0:2, 0:2): 3 } R_ij { A_ij * B_ij }" expect "R" vec![3.0, 6.0],
diag_dense_mat_mul: "A_ij { (0, 0): 1, (1, 1): 2 } B_ij { (0:2, 0:2): 3 } R_ij { A_ij * B_ij }" expect "R" vec![3.0, 6.0],
sparse_sparse_vec_add: "a_i { (0): 1, (2): 2 } b_i { (2): 4 } r_i { a_i + b_i }" expect "r" vec![1.0, 6.0],
sparse_sparse_vec_add2: "a_i { (0): 1, (2): 2 } b_i { (1): 2, (2): 4 } r_i { a_i + b_i }" expect "r" vec![1.0, 2.0, 6.0],
sparse_sparse_vec_mul: "a_i { (0): 1, (2): 2 } b_i { (2): 4 } r_i { a_i * b_i }" expect "r" vec![8.0],
sparse_sparse_vec_div: "a_i { (2): 2 } b_i { (2): 4 } r_i { a_i / b_i }" expect "r" vec![0.5],
sparse_sparse_vec_mul2: "a_i { (0): 1, (2): 2 } b_i { (2): 4 } r_i { b_i * a_i }" expect "r" vec![8.0],
sparse_dense_vec_add: "a_i { (0): 1, (2): 2 } b_i { (0:3): 3 } r_i { a_i + b_i }" expect "r" vec![4.0, 3.0, 5.0],
sparse_dense_vec_add2: "a_i { (0): 1, (2): 2 } b_i { (0:3): 3 } r_i { b_i + a_i }" expect "r" vec![4.0, 3.0, 5.0],
sparse_dense_vec_mul: "a_i { (0): 1, (2): 2 } b_i { (0:3): 3 } r_i { b_i * a_i }" expect "r" vec![3.0, 6.0],
sparse_dense_vec_mul2: "a_i { (0): 1, (2): 2 } b_i { (0:3): 3 } r_i { a_i * b_i }" expect "r" vec![3.0, 6.0],
sparse_vec_vec_mul: "a_i { (0): 1, (2): 2 } b_i { (2): 4 } r_i { a_i * b_i }" expect "r" vec![8.0],
sparse_vec_vec_add: "a_i { (0): 1, (2): 2 } b_i { (2): 4 } r_i { a_i + b_i }" expect "r" vec![1.0, 6.0],
// contraction over j sums each row of a 3x3 matrix of ones
contraction_2d_to_vector: "a_ij { (0:3, 0:3): 1.0 } r_i { a_ij }" expect "r" vec![3.0, 3.0, 3.0],
col_vec: "a_ij { (0, 0): 1, (1, 0): 2 } b_i { (0:2): a_ij }" expect "b" vec![1.0, 2.0],
indexing1: "a_i { 0.0, 1.0, 2.0, 3.0 } r { a_i[2] }" expect "r" vec![2.0],
// heaviside is defined here with heaviside(0) == 1
heaviside_function0: "r { heaviside(-0.1) }" expect "r" vec![0.0],
heaviside_function1: "r { heaviside(0.0) }" expect "r" vec![1.0],
abs_function: "r { abs(-2) }" expect "r" vec![f64::abs(-2.0)],
min_function: "r { min(2, 3) }" expect "r" vec![2.0],
max_function: "r { max(2, 3) }" expect "r" vec![3.0],
// `N` evaluates to 0 in these expressions (expected values assume N == 0)
n_expression_scalar: "r { 2.0 * N + 1.0 }" expect "r" vec![1.0],
n_expression_vector: "r_i { 2.0 * N + 1.0, 3.0 * N + 2.0 }" expect "r" vec![1.0, 2.0],
scalar: "r {2}" expect "r" vec![2.0,],
constant: "r_i {2, 3}" expect "r" vec![2., 3.],
derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.],
concatenate: "r_i {2, 3} k_i { r_i, 2 * r_i }" expect "k" vec![2., 3., 4., 6.],
ones_matrix_dense: "I_ij { (0:2, 0:2): 1 }" expect "I" vec![1., 1., 1., 1.],
dense_matrix: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 }" expect "A" vec![1., 2., 3., 4.],
dense_vector: "x_i { (0:4): 1, (4:5): 2 }" expect "x" vec![1., 1., 1., 1., 2.],
identity_matrix_diagonal: "I_ij { (0..2, 0..2): 1 }" expect "I" vec![1., 1.],
concatenate_diagonal: "A_ij { (0..2, 0..2): 1 } B_ij { (0:2, 0:2): A_ij, (2:4, 2:4): A_ij }" expect "B" vec![1., 1., 1., 1.],
identity_matrix_sparse: "I_ij { (0, 0): 1, (1, 1): 2 }" expect "I" vec![1., 2.],
concatenate_sparse: "A_ij { (0, 0): 1, (1, 1): 2 } B_ij { (0:2, 0:2): A_ij, (2:4, 2:4): A_ij }" expect "B" vec![1., 2., 1., 2.],
// entries given out of order must be stored in row-major sorted order
sparse_rearrange: "A_ij { (0, 0): 1, (1, 1): 2, (0, 1): 3 }" expect "A" vec![1., 3., 2.],
sparse_rearrange2: "A_ij { (0, 1): 1, (1, 1): 2, (1, 0): 3, (2, 2): 4, (2, 1): 5 }" expect "A" vec![1., 3., 2., 5., 4.],
sparse_expression: "A_ij { (0, 0): 1, (0, 1): 2, (1, 1): 3 } B_ij { 2 * A_ij }" expect "B" vec![2., 4., 6.],
sparse_matrix_vect_multiply: "A_ij { (0, 0): 1, (1, 0): 2, (1, 1): 3 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![1., 8.],
sparse_rearrange_matrix_vect_multiply: "A_ij { (0, 1): 1, (1, 1): 2, (1, 0): 3, (2, 2): 4, (2, 1): 5 } x_i { 1, 2, 3 } b_i { A_ij * x_j }" expect "b" vec![2., 7., 22.],
diag_matrix_vect_multiply: "A_ij { (0, 0): 1, (1, 1): 3 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![1., 6.],
dense_matrix_vect_multiply: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![5., 11.],
sparse_matrix_vect_multiply_zero_row: "A_ij { (0, 1): 2 } x_i { 1, 2 } b_i { A_ij * x_j }" expect "b" vec![4.],
// broadcasting a vector across matrix columns (a_j) vs rows (a_i)
broadcast_vector1: "a_i { 1, 2 } b_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } c_ij { b_ij * a_j }" expect "c" vec![1., 4., 3., 8.],
broadcast_vector2: "a_i { 1, 2 } b_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } c_ij { a_j * b_ij }" expect "c" vec![1., 4., 3., 8.],
broadcast_vector3: "a_i { 1, 2 } b_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } c_ij { b_ij * a_i }" expect "c" vec![1., 2., 6., 8.],
broadcast_vector4: "a_i { 1, 2 } b_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 } c_ij { a_i * b_ij }" expect "c" vec![1., 2., 6., 8.],
bidiagonal: "A_ij { (0..3, 0..3): 1, (1..3, 0..2): 2 }" expect "A" vec![1., 2., 1., 2., 1.],
}
// Generates one `#[test]` per entry. Each test embeds `$text` in a small
// one-state ODE model (state `y`, input `p`, dy/dt = p) and checks the
// gradient-related outputs of `tensor_test_common` for tensor `$tensor_name`:
//   results[1] vs `$expected_grad`     (forward gradient)
//   results[2] vs `$expected_rgrad`    (reverse gradient)
//   results[3], results[4] vs `$expected_sgrad`  (sensitivity gradients)
//   results[5], results[6] vs `$expected_srgrad` (reverse sensitivity gradients)
// LLVM backends run all checks for f32 and f64, in both single- and
// multi-threaded modes; the cranelift backend only checks results[1].
macro_rules! tensor_grad_test {
($($name:ident: $text:literal expect $tensor_name:literal $expected_grad:expr ; $expected_rgrad:expr; $expected_sgrad:expr; $expected_srgrad:expr,)*) => {
$(
#[test]
fn $name() {
// wrap the per-test tensor definitions in a complete model
let full_text = format!("
in {{ p = 1 }}
u_i {{
y = p,
}}
dudt_i {{
dydt = p,
}}
{}
M_i {{
dydt,
}}
F_i {{
y,
}}
out_i {{
y,
}}
", $text);
let model = parse_ds_string(full_text.as_str()).unwrap();
#[allow(unused_variables)]
let discrete_model = match DiscreteModel::build("$name", &model) {
Ok(model) => model,
Err(e) => {
panic!("{}", e.as_error_message(full_text.as_str()));
}
};
// multi-threaded runs (require the rayon feature)
#[cfg(feature = "rayon")]
{
#[cfg(feature = "llvm")]
{
use crate::execution::llvm::codegen::LlvmModule;
let results = tensor_test_common::<LlvmModule, f32>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
let results = tensor_test_common::<LlvmModule, f64>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
}
#[cfg(feature = "cranelift")]
{
let results = tensor_test_common::<crate::CraneliftJitModule, f32>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
let results = tensor_test_common::<crate::CraneliftJitModule, f64>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
}
}
// single-threaded runs
#[cfg(feature = "llvm")]
{
use crate::execution::llvm::codegen::LlvmModule;
let results = tensor_test_common::<LlvmModule, f32>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
let results = tensor_test_common::<LlvmModule, f64>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
}
#[cfg(feature = "cranelift")]
{
let results = tensor_test_common::<crate::CraneliftJitModule, f32>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
let results = tensor_test_common::<crate::CraneliftJitModule, f64>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
}
}
)*
}
}
// Gradient tests for simple expressions of the input `p` and state `y`
// (both equal to 1 in the generated model). The four expected vectors are,
// in order: forward gradient ; reverse gradient ; sensitivity gradient ;
// reverse sensitivity gradient (see the `tensor_grad_test!` pattern).
// Sensitivity gradients differentiate w.r.t. `p` only, so pure-state terms
// contribute zero there.
tensor_grad_test! {
const_grad: "r { 3 }" expect "r" vec![0.] ; vec![0.] ; vec![0.] ; vec![0.],
const_vec_grad: "r_i { 3, 4 }" expect "r" vec![0., 0.] ; vec![0.] ; vec![0., 0.] ; vec![0.],
input_grad: "r { 2 * p * p }" expect "r" vec![4.] ; vec![4.] ; vec![4.] ; vec![4.],
input_vec_grad: "r_i { 2 * p * p, 3 * p }" expect "r" vec![4., 3.] ; vec![7.] ; vec![4., 3.] ; vec![7.],
state_grad: "r { 2 * y }" expect "r" vec![2.] ; vec![2.] ; vec![0.] ; vec![0.],
input_and_state_grad: "r { 2 * y * p }" expect "r" vec![4.] ; vec![4.] ; vec![2.] ; vec![2.],
state_and_const_grad1: "r_i { 2 * y, 3 }" expect "r" vec![2., 0.] ; vec![2.] ; vec![0., 0.] ; vec![0.],
state_and_const_grad2: "r_i { 3 * y, 2 * y }" expect "r" vec![3., 2.] ; vec![5.] ; vec![0., 0.] ; vec![0.],
state_and_const_grad3: "r_i { 2 * p, 3 }" expect "r" vec![2., 0.] ; vec![2.] ; vec![2., 0.] ; vec![2.],
state_and_const_grad4: "r_i { 3 * p, 2 * p }" expect "r" vec![3., 2.] ; vec![5.] ; vec![3., 2.] ; vec![5.],
}
// Like `tensor_grad_test!` but for a 100-state model: states 0..50 are `x`
// and 50..100 are `y`, both initialised from input `p`. Each test checks, in
// order: tensor value (results[0]), forward gradient (results[1]), reverse
// gradient (results[2]), sensitivity gradient (results[3]) and reverse
// sensitivity gradient (results[5]). The results[4]/results[6] asserts are
// commented out, and multi-threaded LLVM runs are skipped on macOS — see the
// todo comments below.
macro_rules! tensor_test_big_state {
($($name:ident: $text:literal expect $tensor_name:literal $expected_value:expr ; $expected_grad:expr ; $expected_rgrad:expr ; $expected_sgrad:expr ; $expected_srgrad:expr,)*) => {
$(
#[test]
fn $name() {
// wrap the per-test tensor definitions in a 100-state model
let full_text = format!("
in {{ p = 1 }}
u_i {{
(0:50): x = p,
(50:100): y = p,
}}
{}
F_i {{ x_i, y_i, }}
", $text);
let model = parse_ds_string(full_text.as_str()).unwrap();
#[allow(unused_variables)]
let discrete_model = match DiscreteModel::build("$name", &model) {
Ok(model) => model,
Err(e) => {
panic!("{}", e.as_error_message(full_text.as_str()));
}
};
// single-threaded LLVM runs, f32 then f64
#[cfg(feature = "llvm")]
{
use crate::execution::llvm::codegen::LlvmModule;
let results = tensor_test_common::<LlvmModule, f32>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
//assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
//assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
let results = tensor_test_common::<LlvmModule, f64>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
//assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
//assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
}
// single-threaded cranelift runs check value and forward gradient only
#[cfg(feature = "cranelift")]
{
let results = tensor_test_common::<crate::CraneliftJitModule, f32>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
let results = tensor_test_common::<crate::CraneliftJitModule, f64>(&discrete_model, $tensor_name, CompilerMode::SingleThreaded);
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
}
// multi-threaded runs (require the rayon feature)
#[cfg(feature = "rayon")]
{
#[cfg(feature = "cranelift")]
{
let results = tensor_test_common::<crate::CraneliftJitModule, f32>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
let results = tensor_test_common::<crate::CraneliftJitModule, f64>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
}
// todo: multi-threaded llvm not working on macos
#[cfg(not(target_os = "macos"))]
#[cfg(feature = "llvm")]
{
use crate::execution::llvm::codegen::LlvmModule;
let results = tensor_test_common::<LlvmModule, f32>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
//assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
//assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
let results = tensor_test_common::<LlvmModule, f64>(&discrete_model, $tensor_name, CompilerMode::MultiThreaded(None));
assert_relative_eq!(results[0].as_slice(), $expected_value.as_slice());
assert_relative_eq!(results[1].as_slice(), $expected_grad.as_slice());
assert_relative_eq!(results[2].as_slice(), $expected_rgrad.as_slice());
assert_relative_eq!(results[3].as_slice(), $expected_sgrad.as_slice());
//assert_relative_eq!(results[4].as_slice(), $expected_sgrad.as_slice());
assert_relative_eq!(results[5].as_slice(), $expected_srgrad.as_slice());
//assert_relative_eq!(results[6].as_slice(), $expected_srgrad.as_slice());
}
}
}
)*
}
}
// 100-state model tests (x = u[0..50], y = u[50..100], all ones since p = 1).
// Expected lists are: value ; grad ; rgrad ; sgrad ; srgrad — see the
// `tensor_test_big_state!` pattern above.
tensor_test_big_state! {
big_state_expr: "r_i { x_i + y_i }" expect "r" vec![2.; 50] ; vec![2.; 50] ; vec![100.] ; vec![0.; 50] ; vec![0.],
big_state_multi: "r_i { x_i + y_i } b_i { x_i, r_i - y_i }" expect "b" vec![1.; 100] ; vec![1.; 100] ; vec![100.] ; vec![0.; 100] ; vec![0.],
big_state_multi_w_scalar: "r { 1.0 + 1.0 } b_i { x_i, r - y_i }" expect "b" vec![1.; 100] ; vec![1.; 50].into_iter().chain(vec![-1.; 50].into_iter()).collect::<Vec<_>>() ; vec![0.] ; vec![0.; 100] ; vec![0.],
big_state_diag: "b_ij { (0..100, 0..100): 3.0 } r_i { b_ij * u_j }" expect "r" vec![3.; 100] ; vec![3.; 100] ; vec![300.] ; vec![0.; 100] ; vec![0.],
// tridiagonal with wrap-around corner entries so every row sums to 6
big_state_tridiag: "b_ij { (0..100, 0..100): 3.0, (0..99, 1..100): 2.0, (1..100, 0..99): 1.0, (0, 99): 1.0, (99, 0): 2.0 } r_i { b_ij * u_j }" expect "r" vec![6.; 100]; vec![6.; 100]; vec![600.] ; vec![0.; 100]; vec![0.],
big_state_tridiag2: "b_ij { (0..100, 0..100): p + 2.0, (0..99, 1..100): 2.0, (1..100, 0..99): 1.0, (0, 99): 1.0, (99, 0): 2.0 } r_i { b_ij * u_j }" expect "r" vec![6.; 100]; vec![7.; 100]; vec![700.] ; vec![1.; 100]; vec![100.],
}
// Expand `test_repeated_grad_common` into concrete #[test] fns — presumably
// one per available backend/scalar combination (see `generate_tests!`).
generate_tests!(test_repeated_grad_common);
#[allow(dead_code)]
/// Checks that repeated forward-gradient evaluations of the same compiled
/// model are stable: `rhs_grad` must return the same value every time it is
/// called with identical inputs (guards against state leaking between calls).
///
/// The model has a single state `y = p`, rhs `r = 2 * y * p`; with the input
/// set to `p = 2` (so `y0 = 2`) and tangent `dp = 1`, the directional
/// derivative of the rhs is `d(2*y*p) = 2*p*dy + 2*y*dp = 2*2*2 + 2*2*1 = 8`
/// (with `dy0 = dp` propagated through `set_u0_grad`).
fn test_repeated_grad_common<M, T>()
where
    M: CodegenModuleCompile + CodegenModuleJit,
    T: Scalar + RelativeEq,
{
    let full_text = "
in { p = 1 }
u_i {
y = p,
}
dudt_i {
dydt = 1,
}
r {
2 * y * p,
}
M_i {
dydt,
}
F_i {
r,
}
out_i {
y,
}
";
    let model = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("test_repeated_grad", &model)
        .unwrap_or_else(|e| panic!("{}", e.as_error_message(full_text)));
    let compiler =
        Compiler::<M, T>::from_discrete_model(&discrete_model, Default::default(), Some(full_text))
            .unwrap();

    // working buffers: one state, one rhs entry, plus their tangents
    let mut state0 = vec![T::one()];
    let mut dstate0 = vec![T::one()];
    let mut rhs_val = vec![T::zero()];
    let mut rhs_tangent = vec![T::zero()];
    let mut data = compiler.get_new_data();
    let mut tangent_data = compiler.get_new_data();
    // only the input count is needed from the dimension tuple
    let n_inputs = compiler.get_dims().1;

    // set p = 2, initialise u0 from it, and do one plain rhs evaluation
    let inputs = vec![T::from_f64(2.).unwrap(); n_inputs];
    compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0);
    compiler.set_u0(state0.as_mut_slice(), data.as_mut_slice());
    compiler.rhs(
        T::zero(),
        state0.as_slice(),
        data.as_mut_slice(),
        rhs_val.as_mut_slice(),
    );

    // the gradient must come out identical on every repetition
    for _ in 0..3 {
        let dinputs = vec![T::one(); n_inputs];
        compiler.set_inputs_grad(
            inputs.as_slice(),
            dinputs.as_slice(),
            data.as_mut_slice(),
            tangent_data.as_mut_slice(),
            0,
        );
        compiler.set_u0_grad(
            state0.as_mut_slice(),
            dstate0.as_mut_slice(),
            data.as_mut_slice(),
            tangent_data.as_mut_slice(),
        );
        compiler.rhs_grad(
            T::zero(),
            state0.as_slice(),
            dstate0.as_slice(),
            data.as_mut_slice(),
            tangent_data.as_mut_slice(),
            rhs_val.as_mut_slice(),
            rhs_tangent.as_mut_slice(),
        );
        assert_relative_eq!(
            rhs_tangent.as_slice(),
            vec![T::from_f64(8.).unwrap()].as_slice()
        );
    }
}
#[test]
fn test_constant_and_input_deps() {
    // `m` is defined via `l`, which is rooted at the input `k`; the whole
    // chain must therefore be classified as input-dependent, not constant.
    let code = "
in { k = 1, }
l { k } m { l }
u { 1 }
F { u }
";
    let model = parse_ds_string(code).unwrap();
    let discrete_model = DiscreteModel::build("test_constant_and_input_deps", &model).unwrap();
    let m_is_input_dep = discrete_model
        .input_dep_defns()
        .iter()
        .any(|defn| defn.name() == "m");
    assert!(m_is_input_dep);
}
// Run the shared sparse-contraction regression against the Cranelift backend.
#[cfg(feature = "cranelift")]
#[test]
fn test_repeated_rhs_sparse_contraction_cranelift() {
test_repeated_rhs_sparse_contraction::<crate::CraneliftJitModule>();
}
// Run the shared sparse-contraction regression against the LLVM backend.
#[cfg(feature = "llvm")]
#[test]
fn test_repeated_rhs_sparse_contraction_llvm() {
test_repeated_rhs_sparse_contraction::<crate::execution::llvm::codegen::LlvmModule>();
}
/// Regression test: evaluating `rhs` twice over a sparse matrix-vector
/// contraction must give identical results both times (the contraction's
/// accumulator must not carry over between calls).
#[allow(dead_code)]
fn test_repeated_rhs_sparse_contraction<M: CodegenModuleCompile + CodegenModuleJit>() {
    let code = "
A_ij { (0, 0): 1, (0, 1): 2, (3, 3): 3 }
u_i {
(0:4): x = 1,
}
zeros_i {
(0:4): 0,
}
c_i { A_ij * u_j + zeros_i }
F_i {
c_i,
}
";
    let model = parse_ds_string(code).unwrap();
    let discrete_model =
        DiscreteModel::build("test_repeated_rhs_sparse_contraction", &model).unwrap();
    let compiler =
        Compiler::<M, f64>::from_discrete_model(&discrete_model, Default::default(), Some(code))
            .unwrap();
    let (_n_states, _n_inputs, _n_outputs, _n_data, _n_stop, _has_mass, _has_reset) =
        compiler.get_dims();
    let mut data = compiler.get_new_data();
    let mut state = vec![0.0; 4];
    compiler.set_u0(state.as_mut_slice(), data.as_mut_slice());
    let mut out = vec![0.0; 4];
    let expected = [3.0, 0.0, 0.0, 3.0];
    // Evaluate twice: the second call must not accumulate on top of the first.
    for _ in 0..2 {
        compiler.rhs(0.0, state.as_slice(), data.as_mut_slice(), out.as_mut_slice());
        assert_relative_eq!(out.as_slice(), expected.as_slice());
    }
}
/// Smoke-test of the full compiled API surface on a small DAE model:
/// dimensions, input round-trip, `set_id`, `set_u0`, `rhs`, `mass`, and
/// `calc_out`, against the Cranelift backend when it is enabled.
#[test]
fn test_additional_functions() {
    let full_text = "
in { k = 1, }
u_i {
y = 1,
x = 2,
}
dudt_i {
dydt = 0,
0,
}
M_i {
dydt,
0,
}
F_i {
y - 1,
x - 2,
}
out_i {
y,
x,
2*x,
}
";
    let model = parse_ds_string(full_text).unwrap();
    // Fix: the model name was the unexpanded macro placeholder "$name"
    // (a leftover from a test-generating macro); use the real test name.
    #[allow(unused_variables)]
    let discrete_model = DiscreteModel::build("test_additional_functions", &model).unwrap();
    #[cfg(feature = "cranelift")]
    {
        let compiler = Compiler::<crate::CraneliftJitModule, f64>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap();
        let (n_states, n_inputs, n_outputs, n_data, _n_stop, _has_mass, _has_reset) =
            compiler.get_dims();
        assert_eq!(n_states, 2);
        assert_eq!(n_inputs, 1);
        assert_eq!(n_outputs, 3);
        assert_eq!(n_data, compiler.data_len());
        // Inputs written via set_inputs must read back unchanged.
        let mut data = compiler.get_new_data();
        let inputs = vec![1.1];
        compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0);
        let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap();
        assert_relative_eq!(inputs, vec![1.1].as_slice());
        // Only the first state (y) is differential; x is algebraic.
        let mut id = vec![0.0, 0.0];
        compiler.set_id(id.as_mut_slice());
        assert_eq!(id, vec![1.0, 0.0]);
        let mut u = vec![0., 0.];
        compiler.set_u0(u.as_mut_slice(), data.as_mut_slice());
        assert_relative_eq!(u.as_slice(), vec![1., 2.].as_slice());
        // At the consistent initial state the residual is zero.
        let mut rr = vec![1., 1.];
        compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice());
        assert_relative_eq!(rr.as_slice(), vec![0., 0.].as_slice());
        // Mass matrix keeps only the dydt row; the algebraic row is zero.
        let up = vec![2., 3.];
        rr = vec![1., 1.];
        compiler.mass(0., up.as_slice(), data.as_mut_slice(), rr.as_mut_slice());
        assert_relative_eq!(rr.as_slice(), vec![2., 0.].as_slice());
        // Outputs: y, x, 2*x.
        let mut out = vec![0.; 3];
        compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice());
        assert_relative_eq!(out.as_slice(), vec![1., 2., 4.].as_slice());
    }
}
/// An input tensor with a ranged middle segment — a (1), b (2 via `(1:3)`),
/// c (1) — must expose 4 non-zeros, and values written with `set_inputs`
/// must round-trip through `get_tensor_data` on every enabled backend.
#[test]
fn test_inputs() {
    let full_text = "
in_i { a = 1, (1:3): b = 1, c = 1, }
u { y = 0 }
F { y }
out { y }
";
    let model = parse_ds_string(full_text).unwrap();
    #[allow(unused_variables)]
    let discrete_model = DiscreteModel::build("test_inputs", &model).unwrap();
    assert_eq!(discrete_model.input().unwrap().layout().nnz(), 4);
    #[cfg(feature = "cranelift")]
    {
        let compiler = Compiler::<crate::CraneliftJitModule, f64>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap();
        let mut data = compiler.get_new_data();
        let written = [1.0, 2.0, 3.0, 4.0];
        compiler.set_inputs(&written, data.as_mut_slice(), 0);
        let read_back = compiler.get_tensor_data("in", data.as_slice()).unwrap();
        assert_relative_eq!(read_back, written.as_slice());
    }
    #[cfg(feature = "llvm")]
    {
        let compiler = Compiler::<crate::LlvmModule, f64>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap();
        let mut data = compiler.get_new_data();
        let written = [1.0, 2.0, 3.0, 4.0];
        compiler.set_inputs(&written, data.as_mut_slice(), 0);
        let read_back = compiler.get_tensor_data("in", data.as_slice()).unwrap();
        assert_relative_eq!(read_back, written.as_slice());
    }
}
/// Mass-matrix forward and reverse products on the LLVM backend for
/// M expressed row-wise as [dxdt + dydt, dydt, dzdt].
#[cfg(feature = "llvm")]
#[test]
fn test_mass_llvm() {
    let full_text = "
dudt_i { dxdt = 1, dydt = 1, dzdt = 1 }
u_i { x = 1, y = 2, z = 3 }
F_i { x, y, z }
M_i { dxdt + dydt, dydt, dzdt }
";
    let model = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("test_mass", &model).unwrap();
    let compiler = Compiler::<crate::LlvmModule, f64>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(full_text),
    )
    .unwrap();
    let mut data = compiler.get_new_data();
    let mut u0 = vec![0.0; 3];
    compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
    // Forward product: M * [1,1,1] = [2,1,1].
    let mut v = vec![1.0; 3];
    let mut mv = vec![0.0; 3];
    compiler.mass(0.0, v.as_slice(), data.as_mut_slice(), mv.as_mut_slice());
    assert_relative_eq!(mv.as_slice(), [2.0, 1.0, 1.0].as_slice());
    // Reverse pass: v accumulates the transpose product, so starting from
    // [1,1,1] it ends at [1,1,1] + M^T * [1,1,1] = [2,3,2].
    mv = vec![1.0; 3];
    let mut ddata = compiler.get_new_data();
    compiler.mass_rgrad(
        0.0,
        v.as_mut_slice(),
        data.as_mut_slice(),
        ddata.as_mut_slice(),
        mv.as_mut_slice(),
    );
    assert_relative_eq!(v.as_slice(), [2.0, 3.0, 2.0].as_slice());
}
generate_tests!(test_send_sync);
/// The compiled model must be `Send + Sync`: share it across threads behind
/// an `Arc` and evaluate the rhs on a spawned thread.
#[allow(dead_code)]
fn test_send_sync<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
    let full_text = "
u { y = 1 }
F { -y }
";
    let model = parse_ds_string(full_text).unwrap();
    // Fix: the model name previously read "test_sens_sync" (typo for the
    // function's own name, "test_send_sync").
    let discrete_model = DiscreteModel::build("test_send_sync", &model).unwrap();
    let compiler = Arc::new(
        Compiler::<M, T>::from_discrete_model(
            &discrete_model,
            Default::default(),
            Some(full_text),
        )
        .unwrap(),
    );
    let compiler_clone = Arc::clone(&compiler);
    let handle = thread::spawn(move || {
        let mut data = compiler_clone.get_new_data();
        let mut u0 = vec![T::zero()];
        compiler_clone.set_u0(u0.as_mut_slice(), data.as_mut_slice());
        let mut res = vec![T::zero()];
        compiler_clone.rhs(
            T::zero(),
            u0.as_slice(),
            data.as_mut_slice(),
            res.as_mut_slice(),
        );
        // F = -y with y0 = 1, so the rhs is -1.
        assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice());
    });
    handle.join().unwrap();
}
// Run the u0 forward-sensitivity test at both precisions on the LLVM backend.
#[cfg(feature = "llvm")]
#[test]
fn test_u0_sgrad_llvm() {
test_u0_sgrad::<crate::LlvmModule, f32>();
test_u0_sgrad::<crate::LlvmModule, f64>();
}
/// Forward sensitivity of the initial condition: with u = 2*a*a the value
/// must be 2*a^2 and the tangent d(u0)/da must be 4*a*da.
#[allow(dead_code)]
fn test_u0_sgrad<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq>() {
    let full_text = "
in { a = 1, }
u { 2 * a * a }
F { -u }
";
    let model = parse_ds_string(full_text).unwrap();
    let discrete_model = DiscreteModel::build("test_u0_sgrad", &model).unwrap();
    let compiler =
        Compiler::<M, T>::from_discrete_model(&discrete_model, Default::default(), Some(full_text))
            .unwrap();
    let mut data = compiler.get_new_data();
    let mut ddata = compiler.get_new_data();
    // Seed a = 0.6 with a unit tangent.
    let a = vec![T::from_f64(0.6).unwrap()];
    let da = vec![T::one()];
    compiler.set_inputs(a.as_slice(), data.as_mut_slice(), 0);
    compiler.set_inputs_grad(
        a.as_slice(),
        da.as_slice(),
        data.as_slice(),
        ddata.as_mut_slice(),
        0,
    );
    let mut u0 = vec![T::zero()];
    let mut du0 = vec![T::zero()];
    compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice());
    compiler.set_u0_sgrad(
        u0.as_mut_slice(),
        du0.as_mut_slice(),
        data.as_slice(),
        ddata.as_mut_slice(),
    );
    let two = T::from_f64(2.0).unwrap();
    let four = T::from_f64(4.0).unwrap();
    assert_relative_eq!(u0.as_slice(), vec![two * a[0] * a[0]].as_slice());
    assert_relative_eq!(du0.as_slice(), vec![four * a[0] * da[0]].as_slice());
}
/// Compile (without running) a model whose matrix `a` has three overlapping
/// bands; this exercises sparse-layout handling during codegen. Uses captured
/// identifiers in `format!` instead of the original error-prone positional
/// argument list — the resulting string is identical.
#[cfg(feature = "llvm")]
#[test]
fn test_blah() {
    let n = 10;
    let m = n - 1;
    let full_text = format!(
        "
a_ij {{
(0..{m},1..{n}): 1.0,
(0..{n},0..{n}): 1.0,
(1..{n},0..{m}): 1.0
}}
u_i {{
(0:{n}): 1
}}
F_i {{
a_ij * u_j
}}
out_i {{
u_i
}}
"
    );
    let model = parse_ds_string(&full_text).unwrap();
    let discrete_model = DiscreteModel::build("blah", &model).unwrap();
    Compiler::<crate::LlvmModule, f64>::from_discrete_model(
        &discrete_model,
        Default::default(),
        Some(&full_text),
    )
    .unwrap();
}
}