use anyhow::{anyhow, Result};
use serde::Deserialize;
use std::io::{Read, Write};
use crate::builder::Builder;
use crate::config::Config;
use crate::expr::Expr;
use crate::node::Node;
use crate::utils::Storage;
pub trait Transformer {
fn transform(&self, builder: &mut Builder) -> Result<Node>;
}
#[derive(Debug, Clone)]
pub struct Program {
pub builder: Builder,
pub count_states: usize,
pub count_params: usize,
pub count_obs: usize,
pub count_diffs: usize,
pub count_loops: usize,
}
impl Program {
const MAGIC: usize = 0xc2b244aefb8e4d5d;
pub fn new(ml: &CellModel, config: Config) -> Result<Program> {
let mut builder = Builder::new(config.clone());
for v in &ml.states {
if v.name.starts_with("__") {
builder.block().create_tmp_named(&v.name);
} else {
builder.block().create_mem(&v.name);
}
}
for v in &ml.params {
builder.symbol_table().add_param(&v.name);
}
let mut count_obs = 0;
for eq in &ml.obs {
if let Expr::Special = eq.lhs {
} else if let Some(name) = eq.lhs.normal_var() {
if !builder.block().var_exists(&name) {
if name.starts_with("__") {
builder.block().create_tmp_named(&name);
} else {
builder.block().create_mem(&name);
count_obs += 1;
}
}
} else {
return Err(anyhow!("lhs var not found"));
}
}
for eq in &ml.odes {
if let Some(name) = eq.lhs.diff_var() {
let name = format!("δ{}", name);
builder.symbol_table().add_mem(&name);
} else {
return Err(anyhow!("lhs diff var not found"));
}
}
ml.transform(&mut builder)?;
let k = if config.is_complex() { 2 } else { 1 };
let count_loops = builder.count_loops;
let prog = Program {
builder,
count_states: ml.states.len() * k,
count_params: ml.params.len() * k,
count_obs: count_obs * k,
count_diffs: ml.odes.len() * k,
count_loops,
};
Ok(prog)
}
pub fn config(&self) -> &Config {
&self.builder.config
}
pub fn mem_size(&self) -> usize {
self.count_states + self.count_obs + self.count_diffs + 1
}
pub fn clear(&mut self) {
self.builder.block().clear();
}
}
impl Storage for Program {
fn save(&self, stream: &mut impl Write) -> Result<()> {
stream.write_all(&Self::MAGIC.to_le_bytes())?;
self.config().save(stream)?;
stream.write_all(&self.count_states.to_le_bytes())?;
stream.write_all(&self.count_params.to_le_bytes())?;
stream.write_all(&self.count_obs.to_le_bytes())?;
stream.write_all(&self.count_diffs.to_le_bytes())?;
stream.write_all(&self.count_loops.to_le_bytes())?;
self.builder.save(stream)?;
Ok(())
}
fn load(stream: &mut impl Read, config: &Config) -> Result<Self> {
let mut bytes: [u8; 8] = [0; 8];
stream.read_exact(&mut bytes)?;
if usize::from_le_bytes(bytes) != Self::MAGIC {
return Err(anyhow!("invalid magic number (Program)"));
}
let config = Config::load(stream, config)?;
stream.read_exact(&mut bytes)?;
let count_states = usize::from_le_bytes(bytes);
stream.read_exact(&mut bytes)?;
let count_params = usize::from_le_bytes(bytes);
stream.read_exact(&mut bytes)?;
let count_obs = usize::from_le_bytes(bytes);
stream.read_exact(&mut bytes)?;
let count_diffs = usize::from_le_bytes(bytes);
stream.read_exact(&mut bytes)?;
let count_loops = usize::from_le_bytes(bytes);
let builder = Builder::load(stream, &config)?;
Ok(Program {
builder,
count_states,
count_params,
count_obs,
count_diffs,
count_loops,
})
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct Variable {
pub name: String,
}
impl Transformer for Variable {
fn transform(&self, builder: &mut Builder) -> Result<Node> {
builder.create_var(&self.name)
}
}
impl Expr {
pub fn diff_var(&self) -> Option<String> {
if let Expr::Tree { args, op } = self {
if op != "Differential" {
return None;
}
if let Expr::Var { name } = &args[0] {
return Some(name.clone());
}
};
None
}
pub fn normal_var(&self) -> Option<String> {
if let Expr::Var { name } = self {
return Some(name.clone());
};
None
}
fn transform_unary(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
let x = args[0].transform(builder)?;
builder.add_unary(op, x)
}
fn transform_binary(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
let l = args[0].transform(builder)?;
let r = args[1].transform(builder)?;
builder.add_binary(op, l, r)
}
fn transform_ternary(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
if op != "ifelse" {
return self.transform_poly(builder, op, args);
}
let cond = args[0].transform(builder)?;
let true_val = args[1].transform(builder)?;
let false_val = args[2].transform(builder)?;
builder.create_ifelse(cond, true_val, false_val)
}
fn transform_poly(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
if op == "Sum" || op == "Product" {
return self.transform_loop(builder, op, args);
}
if !(op == "plus" || op == "times" || op == "min" || op == "max") {
return Err(anyhow!("missing poly op: {}", op));
}
let n = args.len();
if n == 1 {
let x = args[0].transform(builder)?;
Ok(x)
} else if n == 2 {
let x = args[0].transform(builder)?;
let y = args[1].transform(builder)?;
let z = builder.create_binary(op, x, y)?;
Ok(z)
} else {
let x = self.transform_poly(builder, op, &args[..n >> 1])?;
let y = self.transform_poly(builder, op, &args[n >> 1..])?;
let z = builder.create_binary(op, x, y)?;
Ok(z)
}
}
fn transform_loop(&self, builder: &mut Builder, op: &str, args: &[Expr]) -> Result<Node> {
let var = builder
.block()
.create_tmp_named(&args[1].normal_var().unwrap());
let start = args[2].transform(builder)?;
let (accum_var, loop_id) = builder.add_loop_prefix(op, var.clone(), start)?;
let eq = args[0].transform(builder)?;
let end = args[3].transform(builder)?;
builder.add_loop_body(op, eq, var, end, accum_var, loop_id)
}
}
impl Transformer for Expr {
fn transform(&self, builder: &mut Builder) -> Result<Node> {
let dst = match self {
Expr::Const { val } => builder.create_const(*val)?,
Expr::Var { name } => builder.create_var(name)?,
Expr::Tree { op, args } => match args.len() {
1 => self.transform_unary(builder, op.as_str(), args)?,
2 => self.transform_binary(builder, op.as_str(), args)?,
3 => self.transform_ternary(builder, op.as_str(), args)?,
_ => self.transform_poly(builder, op.as_str(), args)?,
},
_ => builder.create_void()?, };
Ok(dst)
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct Equation {
pub lhs: Expr,
pub rhs: Expr,
}
impl Transformer for Equation {
fn transform(&self, builder: &mut Builder) -> Result<Node> {
if let Expr::Special = self.lhs {
return self.special(builder);
}
let var = if let Some(var) = self.lhs.diff_var() {
format!("δ{}", var)
} else if let Some(var) = self.lhs.normal_var() {
var
} else {
return Err(anyhow!("lhs should be a variable"));
};
let rhs = self.rhs.transform(builder)?;
let lhs = builder.create_var(var.as_str())?;
builder.add_assign(lhs, rhs)?;
builder.create_void()
}
}
impl Equation {
fn special(&self, builder: &mut Builder) -> Result<Node> {
match &self.rhs {
Expr::Label { id } => {
let label = format!("L.{}", id);
builder.block().add_label(&label);
}
Expr::Branch { id } => {
let label = format!("L.{}", id);
builder.block().add_branch(&label);
}
Expr::BranchIf { cond, id, is_else } => {
let cond = cond.transform(builder)?;
let label = format!("L.{}", id);
builder.block().add_branch_if(cond, &label, *is_else);
}
_ => return Err(anyhow!("Special expressions are Label and IfElse")),
}
builder.create_void()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct CellModel {
pub iv: Variable,
pub params: Vec<Variable>,
pub states: Vec<Variable>,
#[allow(dead_code)]
pub algs: Vec<Equation>,
pub odes: Vec<Equation>,
pub obs: Vec<Equation>,
}
impl CellModel {
pub fn new() -> CellModel {
CellModel {
iv: Expr::var("$_").to_variable().unwrap(),
params: Vec::new(),
states: Vec::new(),
algs: Vec::new(),
odes: Vec::new(),
obs: Vec::new(),
}
}
pub fn load(text: &str) -> Result<CellModel> {
Ok(serde_json::from_str(text)?)
}
}
impl Transformer for CellModel {
fn transform(&self, builder: &mut Builder) -> Result<Node> {
for eq in &self.obs {
eq.transform(builder)?;
}
for eq in &self.odes {
eq.transform(builder)?;
}
builder.create_void()
}
}