symjit 2.14.9

a lightweight just-in-time (JIT) optimizer compiler
Documentation
use anyhow::{anyhow, Result};
use serde::Deserialize;
use std::io::{Read, Write};

use crate::builder::Builder;
use crate::config::Config;
use crate::expr::Expr;
use crate::node::Node;
use crate::utils::Storage;

pub trait Transformer {
    fn transform(&self, builder: &mut Builder) -> Result<Node>;
}

/// Collects the intermediate code (builder) and interface variables
#[derive(Debug, Clone)]
pub struct Program {
    pub builder: Builder,
    pub count_states: usize,
    pub count_params: usize,
    pub count_obs: usize,
    pub count_diffs: usize,
    pub count_loops: usize,
}

impl Program {
    const MAGIC: usize = 0xc2b244aefb8e4d5d;

    pub fn new(ml: &CellModel, config: Config) -> Result<Program> {
        /*
            this section lays the memory format
            the order of different sections is important!

            the layout is:

            +------------------------+
            | state variables        |
            +------------------------+
            | independent variable   | *
            +------------------------+
            | parameters             |
            +------------------------+
            | observables (output)   | **
            +------------------------+
            | differentials (output) |
            +------------------------+

            * => the independent variable slot is always allocated, even if not an ODE
            ** => => the first observable is the return value for fast functions
        */

        let mut builder = Builder::new(config);

        for v in &ml.states {
            builder.symbol_table().add_mem(&v.name);
        }

        // builder.symbol_table().add_mem(&ml.iv.name);

        for v in &ml.params {
            builder.symbol_table().add_param(&v.name);
        }

        let mut count_obs = 0;

        for eq in &ml.obs {
            if let Expr::Special = eq.lhs {
                // pass
            } else if let Some(name) = eq.lhs.normal_var() {
                if !builder.block().var_exists(&name) {
                    if name.starts_with("__") {
                        builder.block().create_tmp_named(&name);
                    } else {
                        builder.block().create_mem(&name);
                        count_obs += 1;
                    }
                }
            } else {
                return Err(anyhow!("lhs var not found"));
            }
        }

        for eq in &ml.odes {
            if let Some(name) = eq.lhs.diff_var() {
                let name = format!("δ{}", name);
                builder.symbol_table().add_mem(&name);
            } else {
                return Err(anyhow!("lhs diff var not found"));
            }
        }

        ml.transform(&mut builder)?;

        let k = if config.is_complex() { 2 } else { 1 };

        let count_loops = builder.count_loops;

        let prog = Program {
            builder,
            count_states: ml.states.len() * k,
            count_params: ml.params.len() * k,
            count_obs: count_obs * k,
            count_diffs: ml.odes.len() * k,
            count_loops,
        };

        Ok(prog)
    }

    pub fn config(&self) -> &Config {
        &self.builder.config
    }

    pub fn mem_size(&self) -> usize {
        self.count_states + self.count_obs + self.count_diffs + 1
    }

    pub fn clear(&mut self) {
        self.builder.block().clear();
    }
}

impl Storage for Program {
    fn save(&self, stream: &mut impl Write) -> Result<()> {
        stream.write_all(&Self::MAGIC.to_le_bytes())?;
        self.config().save(stream)?;
        stream.write_all(&self.count_states.to_le_bytes())?;
        stream.write_all(&self.count_params.to_le_bytes())?;
        stream.write_all(&self.count_obs.to_le_bytes())?;
        stream.write_all(&self.count_diffs.to_le_bytes())?;
        stream.write_all(&self.count_loops.to_le_bytes())?;
        Ok(())
    }

    fn load(stream: &mut impl Read) -> Result<Self> {
        let mut bytes: [u8; 8] = [0; 8];

        stream.read_exact(&mut bytes)?;

        if usize::from_le_bytes(bytes) != Self::MAGIC {
            return Err(anyhow!("invalid magic number"));
        }

        let config = Config::load(stream)?;

        stream.read_exact(&mut bytes)?;
        let count_states = usize::from_le_bytes(bytes);

        stream.read_exact(&mut bytes)?;
        let count_params = usize::from_le_bytes(bytes);

        stream.read_exact(&mut bytes)?;
        let count_obs = usize::from_le_bytes(bytes);

        stream.read_exact(&mut bytes)?;
        let count_diffs = usize::from_le_bytes(bytes);

        stream.read_exact(&mut bytes)?;
        let count_loops = usize::from_le_bytes(bytes);

        let builder = Builder::new(config);

        Ok(Program {
            builder,
            count_states,
            count_params,
            count_obs,
            count_diffs,
            count_loops,
        })
    }
}

/// A defined (state or param) variable
#[derive(Debug, Clone, Deserialize)]
pub struct Variable {
    pub name: String,
}

/// Transforms the input tree to the intermediate representation (tree-like)
impl Transformer for Variable {
    fn transform(&self, builder: &mut Builder) -> Result<Node> {
        builder.create_var(&self.name)
    }
}

// Expr tree
// #[derive(Debug, Clone, Deserialize)]
// #[serde(tag = "type")]
// pub enum Expr {
//     Tree { op: String, args: Vec<Expr> },
//     Const { val: f64 },
//     Var { name: String },
// }

impl Expr {
    /// Extracts the differentiated variable from the lhs of a diff eq
    pub fn diff_var(&self) -> Option<String> {
        if let Expr::Tree { args, op } = self {
            if op != "Differential" {
                return None;
            }
            if let Expr::Var { name } = &args[0] {
                return Some(name.clone());
            }
        };
        None
    }

    /// Extracts the regular variable from the lhs of an observable eq
    pub fn normal_var(&self) -> Option<String> {
        if let Expr::Var { name } = self {
            return Some(name.clone());
        };
        None
    }

    //**************** Transformations *****************//

    fn transform_unary(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
        let x = args[0].transform(builder)?;
        builder.add_unary(op, x)
    }

    fn transform_binary(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
        let l = args[0].transform(builder)?;
        let r = args[1].transform(builder)?;
        builder.add_binary(op, l, r)
    }

    /// Ternary operator is the conditional select operator
    fn transform_ternary(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
        if op != "ifelse" {
            return self.transform_poly(builder, op, args);
        }

        let cond = args[0].transform(builder)?;
        let true_val = args[1].transform(builder)?;
        let false_val = args[2].transform(builder)?;

        builder.create_ifelse(cond, true_val, false_val)
    }

    /// Addition and Multiplication can haev multiple arguments
    /// The intermediate tree has only unary and binary nodes
    fn transform_poly(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
        if op == "Sum" || op == "Product" {
            return self.transform_loop(builder, op, args);
        }

        if !(op == "plus" || op == "times" || op == "min" || op == "max") {
            return Err(anyhow!("missing poly op: {}", op));
        }

        let n = args.len();

        if n == 1 {
            let x = args[0].transform(builder)?;
            Ok(x)
        } else if n == 2 {
            let x = args[0].transform(builder)?;
            let y = args[1].transform(builder)?;
            let z = builder.create_binary(op, x, y)?;
            Ok(z)
        } else {
            let x = self.transform_poly(builder, op, &args[..n >> 1])?;
            let y = self.transform_poly(builder, op, &args[n >> 1..])?;
            let z = builder.create_binary(op, x, y)?;
            Ok(z)
        }

        // for arg in args.iter().skip(1) {
        //     let y = arg.transform(builder)?;
        //     x = builder.create_binary(op, x, y)?;
        // }

        // Ok(x)
    }

    fn transform_loop(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
        let var = builder
            .block()
            .create_tmp_named(&args[1].normal_var().unwrap());
        let start = args[2].transform(builder)?;
        let (accum_var, loop_id) = builder.add_loop_prefix(op, var.clone(), start)?;
        let eq = args[0].transform(builder)?;
        let end = args[3].transform(builder)?;
        builder.add_loop_body(op, eq, var, end, accum_var, loop_id)
    }
}

impl Transformer for Expr {
    fn transform(&self, builder: &mut Builder) -> Result<Node> {
        let dst = match self {
            Expr::Const { val } => builder.create_const(*val)?,
            Expr::Var { name } => builder.create_var(name)?,
            Expr::Tree { op, args } => match args.len() {
                1 => self.transform_unary(builder, op.as_str(), args)?,
                2 => self.transform_binary(builder, op.as_str(), args)?,
                3 => self.transform_ternary(builder, op.as_str(), args)?,
                _ => self.transform_poly(builder, op.as_str(), args)?,
            },
            _ => builder.create_void()?, // should be handled by Equation::special
        };
        Ok(dst)
    }
}

/// Represents lhs ~ rhs
#[derive(Debug, Clone, Deserialize)]
pub struct Equation {
    pub lhs: Expr,
    pub rhs: Expr,
}

impl Transformer for Equation {
    fn transform(&self, builder: &mut Builder) -> Result<Node> {
        if let Expr::Special = self.lhs {
            return self.special(builder);
        }

        let var = if let Some(var) = self.lhs.diff_var() {
            format!("δ{}", var)
        } else if let Some(var) = self.lhs.normal_var() {
            var
        } else {
            return Err(anyhow!("lhs should be a variable"));
        };

        let rhs = self.rhs.transform(builder)?;
        let lhs = builder.create_var(var.as_str())?;
        builder.add_assign(lhs, rhs)?;
        builder.create_void()
    }
}

impl Equation {
    fn special(&self, builder: &mut Builder) -> Result<Node> {
        match &self.rhs {
            Expr::Label { id } => {
                let label = format!("L.{}", id);
                builder.block().add_label(&label);
            }
            Expr::Branch { id } => {
                let label = format!("L.{}", id);
                builder.block().add_branch(&label);
            }
            Expr::BranchIf { cond, id, is_else } => {
                let cond = cond.transform(builder)?;
                let label = format!("L.{}", id);
                builder.block().add_branch_if(cond, &label, *is_else);
            }
            _ => return Err(anyhow!("Special expressions are Label and IfElse")),
        }

        builder.create_void()
    }
}

/// Loads a model from a JSON file
/// Historically from a CellML source; hence the name.
#[derive(Debug, Clone, Deserialize)]
pub struct CellModel {
    pub iv: Variable,
    pub params: Vec<Variable>,
    pub states: Vec<Variable>,
    #[allow(dead_code)]
    pub algs: Vec<Equation>,
    pub odes: Vec<Equation>,
    pub obs: Vec<Equation>,
}

impl CellModel {
    pub fn load(text: &str) -> Result<CellModel> {
        Ok(serde_json::from_str(text)?)
    }
}

impl Transformer for CellModel {
    fn transform(&self, builder: &mut Builder) -> Result<Node> {
        for eq in &self.obs {
            eq.transform(builder)?;
        }

        for eq in &self.odes {
            eq.transform(builder)?;
        }

        builder.create_void()
    }
}