use std::collections::HashMap;
use log::{debug, log_enabled, Level};
use ndarray::s;
use crate::{
ast::{self, Ast, AstKind, StringSpan},
discretise::layout::NonZero,
};
use super::{
can_broadcast_to, layout::ArcLayout, Layout, LayoutKind, Shape, Tensor, TensorBlock,
ValidationError, ValidationErrors,
};
/// Metadata recorded for each variable known to the [`Env`]: its layout
/// plus the dependence flags used when validating tensor definitions.
pub struct EnvVar {
    // Shared pointer to the layout (shape/sparsity) of the variable.
    layout: ArcLayout,
    // Transitively depends on time `t`.
    is_time_dependent: bool,
    // Transitively depends on the state vector `u`.
    is_state_dependent: bool,
    // Transitively depends on the state derivative `dudt`.
    is_dstatedt_dependent: bool,
    // Transitively depends on the inputs (`in`).
    is_input_dependent: bool,
    // Transitively depends on the model-size parameter `N`.
    is_model_dependent: bool,
    // Algebraic flag; NOTE(review): always initialised to `true` in this module.
    is_algebraic: bool,
}
impl EnvVar {
    /// Whether this variable (transitively) depends on time `t`.
    pub fn is_time_dependent(&self) -> bool {
        self.is_time_dependent
    }
    /// Whether this variable depends on the state vector `u`.
    pub fn is_state_dependent(&self) -> bool {
        self.is_state_dependent
    }
    /// Whether this variable depends on the state derivative `dudt`.
    pub fn is_dstatedt_dependent(&self) -> bool {
        self.is_dstatedt_dependent
    }
    /// Whether this variable is algebraic (set to `true` everywhere in this module).
    pub fn is_algebraic(&self) -> bool {
        self.is_algebraic
    }
    /// Whether this variable depends on the inputs.
    pub fn is_input_dependent(&self) -> bool {
        self.is_input_dependent
    }
    /// Whether this variable depends on the model-size parameter `N`.
    pub fn is_model_dependent(&self) -> bool {
        self.is_model_dependent
    }
    /// The variable's layout.
    pub fn layout(&self) -> &Layout {
        self.layout.as_ref()
    }
}
/// Validation environment: tracks the variables seen so far, accumulated
/// validation errors, and the span currently being processed.
pub struct Env {
    // Span of the item currently being validated, if any.
    current_span: Option<StringSpan>,
    // Errors accumulated during validation.
    errs: ValidationErrors,
    // All variables visible so far, keyed by name ("t" and "N" are builtin).
    vars: HashMap<String, EnvVar>,
    // Input-dependent non-zeros of the initial state — presumably u0; TODO confirm with callers.
    pub(crate) state0_input_deps: Vec<NonZero>,
    // Input-dependent non-zeros of the initial state derivative — TODO confirm with callers.
    pub(crate) dstate0_input_deps: Vec<NonZero>,
}
impl Env {
/// Evaluates `expr` as a compile-time integer constant.
///
/// Supports integer literals, whole-number float literals, unary `+`/`-`,
/// and the binary operators `+`, `-`, `*`, `/`, `%`. Returns `None` when
/// the expression is not a constant integer, on division/remainder by
/// zero, or on arithmetic overflow (e.g. `i64::MIN / -1`).
fn eval_const_integer_expr(expr: &Ast) -> Option<i64> {
    match &expr.kind {
        AstKind::Integer(v) => Some(*v),
        // A float literal counts only when it has no fractional part.
        AstKind::Number(v) => {
            if v.fract() == 0.0 {
                Some(*v as i64)
            } else {
                None
            }
        }
        AstKind::Monop(op) => {
            let child = Self::eval_const_integer_expr(op.child.as_ref())?;
            match op.op {
                '+' => Some(child),
                '-' => child.checked_neg(),
                _ => None,
            }
        }
        AstKind::Binop(op) => {
            let left = Self::eval_const_integer_expr(op.left.as_ref())?;
            let right = Self::eval_const_integer_expr(op.right.as_ref())?;
            match op.op {
                '+' => left.checked_add(right),
                '-' => left.checked_sub(right),
                '*' => left.checked_mul(right),
                // checked_div/checked_rem return None both on a zero
                // divisor and on i64::MIN / -1, which plain `/` and `%`
                // would panic on.
                '/' => left.checked_div(right),
                '%' => left.checked_rem(right),
                _ => None,
            }
        }
        _ => None,
    }
}
/// Evaluates `expr` as an integer, substituting `n` for the name `N`.
///
/// Otherwise behaves like [`Env::eval_const_integer_expr`]: literals,
/// unary `+`/`-` and binary `+`, `-`, `*`, `/`, `%` are supported, and
/// `None` is returned for non-integer expressions, division/remainder by
/// zero, or arithmetic overflow (e.g. `i64::MIN / -1`).
fn eval_integer_expr_with_n(expr: &Ast, n: i64) -> Option<i64> {
    match &expr.kind {
        AstKind::Integer(v) => Some(*v),
        // A float literal counts only when it has no fractional part.
        AstKind::Number(v) => {
            if v.fract() == 0.0 {
                Some(*v as i64)
            } else {
                None
            }
        }
        // The only name allowed is the builtin size parameter `N`.
        AstKind::Name(name) => {
            if name.name == "N" {
                Some(n)
            } else {
                None
            }
        }
        AstKind::Monop(op) => {
            let child = Self::eval_integer_expr_with_n(op.child.as_ref(), n)?;
            match op.op {
                '+' => Some(child),
                '-' => child.checked_neg(),
                _ => None,
            }
        }
        AstKind::Binop(op) => {
            let left = Self::eval_integer_expr_with_n(op.left.as_ref(), n)?;
            let right = Self::eval_integer_expr_with_n(op.right.as_ref(), n)?;
            match op.op {
                '+' => left.checked_add(right),
                '-' => left.checked_sub(right),
                '*' => left.checked_mul(right),
                // checked_div/checked_rem return None both on a zero
                // divisor and on i64::MIN / -1, which plain `/` and `%`
                // would panic on.
                '/' => left.checked_div(right),
                '%' => left.checked_rem(right),
                _ => None,
            }
        }
        _ => None,
    }
}
/// Determines the width (`end - start`) of a range indice when that
/// width is a constant, i.e. independent of the model-size parameter `N`.
///
/// Constant folding of both bounds is attempted first. Failing that, the
/// bounds are evaluated at several sample values of `N`; if every sample
/// yields the same width it is returned, otherwise `None`.
fn eval_constant_range_width(start: &Ast, end: &Ast) -> Option<i64> {
    let const_start = Self::eval_const_integer_expr(start);
    let const_end = Self::eval_const_integer_expr(end);
    if let (Some(first), Some(last)) = (const_start, const_end) {
        return last.checked_sub(first);
    }
    // Probe a handful of N values and require that they all agree.
    let samples = [0_i64, 1, 2, 3, 7, 16];
    let mut agreed: Option<i64> = None;
    for &n in samples.iter() {
        let lo = Self::eval_integer_expr_with_n(start, n)?;
        let hi = Self::eval_integer_expr_with_n(end, n)?;
        let w = hi.checked_sub(lo)?;
        match agreed {
            None => agreed = Some(w),
            Some(prev) if prev == w => {}
            Some(_) => return None,
        }
    }
    agreed
}
/// Creates a fresh environment pre-populated with the builtin scalar
/// variables `t` (time, time-dependent) and `N` (model size,
/// model-dependent).
pub fn new() -> Self {
    let t_var = EnvVar {
        layout: ArcLayout::new(Layout::new_scalar()),
        is_time_dependent: true,
        is_state_dependent: false,
        is_dstatedt_dependent: false,
        is_input_dependent: false,
        is_model_dependent: false,
        is_algebraic: true,
    };
    let n_var = EnvVar {
        layout: ArcLayout::new(Layout::new_scalar()),
        is_time_dependent: false,
        is_state_dependent: false,
        is_dstatedt_dependent: false,
        is_input_dependent: false,
        is_model_dependent: true,
        is_algebraic: true,
    };
    let mut vars = HashMap::new();
    vars.insert("t".to_string(), t_var);
    vars.insert("N".to_string(), n_var);
    Env {
        current_span: None,
        errs: ValidationErrors::default(),
        vars,
        state0_input_deps: Vec::new(),
        dstate0_input_deps: Vec::new(),
    }
}
/// Returns a shared pointer to `layout`, reusing an existing variable's
/// allocation when one with identical non-zeros and dependencies is
/// already registered.
pub fn new_layout_ptr(&mut self, layout: Layout) -> ArcLayout {
    let existing = self
        .vars
        .values()
        .find(|var| var.layout.as_ref().eq_nonzeros_and_deps(&layout));
    match existing {
        Some(var) => var.layout.clone(),
        None => ArcLayout::new(layout),
    }
}
/// True when `tensor` (or any variable its elements reference) depends
/// on time. The state `u` and its derivative `dudt` always do.
pub fn is_tensor_time_dependent(&self, tensor: &Tensor) -> bool {
    let name = tensor.name();
    if name == "u" || name == "dudt" {
        return true;
    }
    for block in tensor.elmts().iter() {
        for &dep in block.expr().get_dependents().iter() {
            // Direct reference to `t`, or transitive dependence through
            // an already-registered variable.
            if dep == "t" || self.vars[dep].is_time_dependent() {
                return true;
            }
        }
    }
    false
}
/// True when `tensor` (or anything it references) depends on the state vector `u`.
pub fn is_tensor_state_dependent(&self, tensor: &Tensor) -> bool {
    self.is_tensor_dependent_on(tensor, "u")
}
/// True when `tensor` (or anything it references) depends on the inputs `in`.
pub fn is_tensor_input_dependent(&self, tensor: &Tensor) -> bool {
    self.is_tensor_dependent_on(tensor, "in")
}
/// True when `tensor` (or any variable its elements reference) depends
/// on the model-size parameter `N`.
pub fn is_tensor_model_dependent(&self, tensor: &Tensor) -> bool {
    if tensor.name() == "N" {
        return true;
    }
    for block in tensor.elmts().iter() {
        for &dep in block.expr().get_dependents().iter() {
            // Direct reference to `N`, or transitive dependence through
            // an already-registered variable.
            if dep == "N" || self.vars[dep].is_model_dependent() {
                return true;
            }
        }
    }
    false
}
/// True when `tensor` (or anything it references) depends on the state derivative `dudt`.
pub fn is_tensor_dstatedt_dependent(&self, tensor: &Tensor) -> bool {
    self.is_tensor_dependent_on(tensor, "dudt")
}
/// Shared implementation for the state / dstatedt / input dependence
/// checks; `var` must be one of `"u"`, `"dudt"` or `"in"`.
fn is_tensor_dependent_on(&self, tensor: &Tensor, var: &str) -> bool {
    if tensor.name() == var {
        return true;
    }
    // Transitive flag to consult for a registered dependency.
    let dep_flag = |dep: &str| match var {
        "u" => self.vars[dep].is_state_dependent(),
        "dudt" => self.vars[dep].is_dstatedt_dependent(),
        "in" => self.vars[dep].is_input_dependent(),
        // Callers only ever pass the three names above.
        _ => unreachable!(),
    };
    tensor.elmts().iter().any(|block| {
        block
            .expr()
            .get_dependents()
            .iter()
            .any(|&dep| dep == var || dep_flag(dep))
    })
}
/// Registers `var` in the environment, computing its dependence flags
/// from its element expressions.
pub fn push_var(&mut self, var: &Tensor) {
    let env_var = EnvVar {
        layout: var.layout_ptr().clone(),
        is_algebraic: true,
        is_time_dependent: self.is_tensor_time_dependent(var),
        is_state_dependent: self.is_tensor_state_dependent(var),
        is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var),
        is_input_dependent: self.is_tensor_input_dependent(var),
        is_model_dependent: self.is_tensor_model_dependent(var),
    };
    self.vars.insert(var.name().to_string(), env_var);
}
/// Registers a named tensor block under its own name; the dependence
/// flags are computed from the *enclosing* tensor `var`.
pub fn push_var_blk(&mut self, var: &Tensor, var_blk: &TensorBlock) {
    let env_var = EnvVar {
        layout: var_blk.layout_ptr().clone(),
        is_algebraic: true,
        is_time_dependent: self.is_tensor_time_dependent(var),
        is_state_dependent: self.is_tensor_state_dependent(var),
        is_dstatedt_dependent: self.is_tensor_dstatedt_dependent(var),
        is_input_dependent: self.is_tensor_input_dependent(var),
        is_model_dependent: self.is_tensor_model_dependent(var),
    };
    // Blocks pushed here must be named; an unnamed block is a caller
    // error (panics via unwrap, as in the original).
    self.vars.insert(var_blk.name().unwrap().to_string(), env_var);
}
/// Looks up a variable by name.
pub fn get(&self, name: &str) -> Option<&EnvVar> {
    self.vars.get(name)
}
/// Computes the layout of `left op right` by broadcasting the two
/// operand layouts together. Records a validation error (anchored on the
/// lhs span) and returns `None` when broadcasting fails.
fn get_layout_binary_op<'s>(
    &mut self,
    left: &Ast<'s>,
    right: &Ast<'s>,
    op: &ast::Binop,
    indices: &Vec<char>,
) -> Option<Layout> {
    let lhs = self.get_layout(left, indices)?;
    let rhs = self.get_layout(right, indices)?;
    let broadcast = Layout::broadcast(vec![lhs, rhs], Some(op.op));
    match broadcast {
        Ok(layout) => Some(layout),
        Err(e) => {
            let msg = format!("{}. Op is {}, lhs is {}, rhs is {}.", e, op.op, left, right);
            self.errs.push(ValidationError::new(msg, left.span));
            None
        }
    }
}
/// Resolves the layout of a named variable reference on the rhs of a
/// tensor definition.
///
/// The variable's layout is permuted so its axes follow `lhs_indices`
/// (the indices of the tensor being defined), then an optional explicit
/// `indice` (scalar index or range) is applied. Only dense variables
/// with at most one non-unit axis may be explicitly indexed.
///
/// Records a validation error and returns `None` on any failure.
fn get_layout_name(
    &mut self,
    name: &str,
    ast: &Ast,
    rhs_indices: &[char],
    lhs_indices: &[char],
    indice: Option<&Ast>,
) -> Option<Layout> {
    let var = self.get(name);
    if var.is_none() {
        self.errs.push(ValidationError::new(
            format!("cannot find variable {name}"),
            ast.span,
        ));
        return None;
    }
    let var = var.unwrap();
    let layout = var.layout();
    if rhs_indices.len() < layout.min_rank() {
        self.errs.push(ValidationError::new(
            format!(
                "cannot index variable {} with {} indices. Expected at least {} indices",
                name,
                rhs_indices.len(),
                layout.rank()
            ),
            ast.span,
        ));
        return None;
    }
    // Build the permutation mapping each rhs index to its axis position
    // on the lhs.
    let mut permutation = vec![0; rhs_indices.len()];
    for i in 0..rhs_indices.len() {
        permutation[i] = match lhs_indices.iter().position(|&x| x == rhs_indices[i]) {
            Some(pos) => pos,
            None => {
                // A rhs index absent from the lhs is tolerated only when
                // an explicit indice (scalar, or range without an end)
                // will select along that axis.
                let mut allow_missing = false;
                if let Some(indice) = indice {
                    let indice = indice.kind.as_indice().unwrap();
                    if indice.sep.is_none() || indice.last.is_none() {
                        allow_missing = true;
                    }
                };
                if !allow_missing {
                    self.errs.push(ValidationError::new(
                        format!(
                            "cannot find index {} in lhs indices {:?} ",
                            rhs_indices[i], lhs_indices
                        ),
                        ast.span,
                    ));
                    return None;
                }
                0
            }
        }
    }
    let layout_permuted = match layout.permute(permutation.as_slice()) {
        Ok(layout) => layout,
        Err(e) => {
            self.errs
                .push(ValidationError::new(format!("{e}"), ast.span));
            return None;
        }
    };
    if let Some(indice) = indice {
        let indice = indice.kind.as_indice().unwrap();
        // Explicit indexing requires a dense layout with exactly one
        // non-unit axis (i.e. effectively 1D).
        let is_one_d = layout_permuted.shape().iter().filter(|&&d| d != 1).count() == 1;
        if !is_one_d || layout_permuted.kind() != &LayoutKind::Dense {
            self.errs.push(ValidationError::new(
                format!(
                    "can only index dense 1D variables. Variable {} has layout {}",
                    name, layout_permuted
                ),
                ast.span,
            ));
            return None;
        }
        if indice.sep.is_some() && indice.last.is_none() {
            self.errs.push(ValidationError::new(
                "range indice must have an end value".to_string(),
                ast.span,
            ));
            return None;
        }
        if indice.sep.is_none() {
            // Scalar index: the result is a scalar layout whose
            // dependencies are filtered to the selected entry, or to the
            // whole non-unit axis when the index is not a constant.
            let mut new_layout = Layout::new_scalar();
            let (first, last) =
                if let Some(first) = Self::eval_const_integer_expr(indice.first.as_ref()) {
                    (first, first + 1)
                } else {
                    let axis = layout_permuted
                        .shape()
                        .iter()
                        .position(|&d| d != 1)
                        .unwrap_or(0);
                    let dim = *layout_permuted.shape().get(axis).unwrap_or(&1);
                    (0, i64::try_from(dim).unwrap())
                };
            new_layout.filter_deps_from(layout_permuted, first, last);
            return Some(new_layout);
        } else {
            // Range index: the width must be a constant, independent of N.
            let end_expr = indice.last.as_ref().unwrap().as_ref();
            let Some(width) = Self::eval_constant_range_width(indice.first.as_ref(), end_expr)
            else {
                self.errs.push(ValidationError::new(
                    "range indice width must be an integer constant (independent of N)"
                        .to_string(),
                    ast.span,
                ));
                return None;
            };
            if width < 0 {
                self.errs.push(ValidationError::new(
                    format!("invalid range indice: width {} is negative", width),
                    ast.span,
                ));
                return None;
            }
            let dim = usize::try_from(width).unwrap();
            // Dense result layout with the non-unit axis resized to the
            // range width.
            let shape = layout_permuted
                .shape()
                .map(|&d| if d != 1 { dim } else { 1 });
            let mut new_layout = Layout::new_dense(Shape::from(shape));
            let first_const = Self::eval_const_integer_expr(indice.first.as_ref());
            if let Some(first) = first_const {
                let last = first + width;
                new_layout.filter_deps_from(layout_permuted, first, last);
            } else if dim != 0 {
                // Unknown start: union the dependencies of every possible
                // window of this width so the result is conservative.
                let axis = layout_permuted
                    .shape()
                    .iter()
                    .position(|&d| d != 1)
                    .unwrap_or(0);
                let source_dim =
                    i64::try_from(*layout_permuted.shape().get(axis).unwrap_or(&1)).unwrap();
                let max_start = source_dim.saturating_sub(width);
                let mut merged_layout: Option<Layout> = None;
                for start in 0..=max_start {
                    let mut window_layout = Layout::new_dense(new_layout.shape().clone());
                    window_layout.filter_deps_from(
                        layout_permuted.clone(),
                        start,
                        start + width,
                    );
                    merged_layout = Some(match merged_layout {
                        Some(accumulated) => accumulated.union(window_layout),
                        None => window_layout,
                    });
                }
                if let Some(merged_layout) = merged_layout {
                    new_layout = merged_layout;
                }
            }
            return Some(new_layout);
        }
    }
    Some(layout_permuted)
}
/// Computes the layout of a function call by broadcasting the layouts of
/// all of its arguments together. Records a validation error and returns
/// `None` when any argument fails or broadcasting fails.
fn get_layout_call(
    &mut self,
    call: &ast::Call,
    ast: &Ast,
    indices: &Vec<char>,
) -> Option<Layout> {
    let mut arg_layouts = Vec::with_capacity(call.args.len());
    for arg in call.args.iter() {
        arg_layouts.push(self.get_layout(arg, indices)?);
    }
    match Layout::broadcast(arg_layouts, None) {
        Ok(layout) => Some(layout),
        Err(e) => {
            self.errs
                .push(ValidationError::new(format!("{e}"), ast.span));
            None
        }
    }
}
/// Computes the layout of an arbitrary rhs expression by dispatching on
/// the AST node kind. Validation errors are recorded by the helpers;
/// `None` means no layout could be determined.
///
/// Panics on AST node kinds that cannot appear in a tensor expression.
pub fn get_layout(&mut self, ast: &Ast, indices: &Vec<char>) -> Option<Layout> {
    let layout = match &ast.kind {
        AstKind::Assignment(a) => self.get_layout(a.expr.as_ref(), indices),
        AstKind::Binop(binop) => {
            self.get_layout_binary_op(binop.left.as_ref(), binop.right.as_ref(), binop, indices)
        }
        // A unary op does not change its operand's layout.
        AstKind::Monop(monop) => self.get_layout(monop.child.as_ref(), indices),
        AstKind::Call(call) => self.get_layout_call(call, ast, indices),
        AstKind::CallArg(arg) => self.get_layout(arg.expression.as_ref(), indices),
        // Literals are scalars.
        AstKind::Number(_) => Some(Layout::new_scalar()),
        AstKind::Integer(_) => Some(Layout::new_scalar()),
        // A domain becomes a dense 1D layout of the domain's dimension.
        AstKind::Domain(d) => Some(Layout::new_dense(Shape::zeros(1) + d.dim)),
        AstKind::Name(name) => self.get_layout_name(
            name.name,
            ast,
            &name.indices,
            indices,
            name.indice.as_ref().map(|i| i.as_ref()),
        ),
        _ => panic!("unrecognised ast node {:#?}", ast.kind),
    };
    // Trace the computed layout and its explicit indices when debug
    // logging is enabled; the string building is guarded so it only
    // happens at debug level.
    if log_enabled!(Level::Debug) {
        let indices_str = layout.as_ref().map(|l| {
            l.explicit_indices()
                .iter()
                .map(|i| {
                    i.into_iter()
                        .map(|x| x.to_string())
                        .collect::<Vec<String>>()
                        .join(", ")
                })
                .collect::<Vec<String>>()
        });
        debug!(
            "layout for ast {} with indices {:?} is {} with indices {:?}",
            ast,
            indices,
            layout.as_ref().unwrap_or(&Layout::new_scalar()),
            indices_str.unwrap_or_default()
        );
    }
    layout
}
/// Computes the pair `(expression layout, element layout)` for a tensor
/// element definition.
///
/// `indices` are the index characters of the tensor being defined. Any
/// extra index appearing only in the expression triggers a contraction
/// over the last axis (only 2D -> 1D is supported). Explicit element
/// indices (scalar, `a..b` ranges, or open ranges) then determine the
/// element's shape and, when all indices are ranges, a diagonal layout.
///
/// Records a validation error and returns `None` on any inconsistency.
pub fn get_layout_tensor_elmt(
    &mut self,
    elmt: &ast::TensorElmt,
    indices: &[char],
) -> Option<(Layout, Layout)> {
    let expr_indices = elmt.expr.get_indices();
    // Indices used by the expression but not declared on the tensor are
    // appended; they will be contracted away below.
    let mut new_indices = indices.to_vec();
    for i in expr_indices {
        if !indices.contains(&i) && !new_indices.contains(&i) {
            new_indices.push(i);
        }
    }
    // Only a single extra index (2D expression to 1D tensor) is allowed.
    if new_indices.len() > indices.len() && (new_indices.len() != 2 || indices.len() != 1) {
        self.errs.push(ValidationError::new(
            format!(
                "contraction only supported from 2D to 1D tensors. Got {}D to {}D",
                new_indices.len(),
                indices.len()
            ),
            elmt.expr.span,
        ));
        return None;
    }
    debug!(
        "calculating expr layout for tensor element with expr: {}",
        elmt.expr
    );
    let expr_layout = self.get_layout(elmt.expr.as_ref(), &new_indices)?;
    // Either contract the extra axis away, or broadcast the expression
    // layout up to the tensor's rank.
    let expr_layout_to_rank = if new_indices.len() > indices.len() {
        match expr_layout.contract_last_axis() {
            Ok(layout) => layout,
            Err(e) => {
                self.errs
                    .push(ValidationError::new(format!("{e}"), elmt.expr.span));
                return None;
            }
        }
    } else {
        expr_layout.broadcast_to_rank(indices.len())
    };
    let elmt_layout = if let Some(elmt_indices) = elmt.indices.as_ref() {
        let given_indices_ast = &elmt_indices.kind.as_vector().unwrap().data;
        let given_indices: Vec<&ast::Indice> = given_indices_ast
            .iter()
            .map(|i| i.kind.as_indice().unwrap())
            .collect();
        if given_indices.len() != indices.len() {
            self.errs.push(ValidationError::new(
                format!(
                    "number of dimensions of tensor element ({}) does not match number of dimensions of tensor ({})",
                    given_indices.len(), indices.len()
                ),
                elmt_indices.span,
            ));
            return None;
        }
        // Expected element shape: start from the expression shape padded
        // with ones, then overwrite each axis from the given indices.
        let mut exp_expr_shape = Shape::ones(indices.len());
        exp_expr_shape
            .slice_mut(s![..expr_layout_to_rank.rank()])
            .assign(expr_layout_to_rank.shape());
        let all_range_indices = given_indices.iter().all(|i| i.sep == Some(".."));
        let mut old_dim = None;
        for (i, indice) in given_indices.iter().enumerate() {
            let first = indice.first.kind.as_integer().unwrap();
            if !all_range_indices && matches!(indice.sep, Some("..")) {
                self.errs.push(ValidationError::new(
                    "can only use range separator if all indices are ranges".to_string(),
                    given_indices_ast[i].span,
                ));
            }
            // Axis dimension implied by this indice: a bounded range
            // gives its width, an open range keeps the expression's
            // dimension, a scalar index gives 1.
            let dim = if indice.sep.is_some() {
                if let Some(second) = &indice.last {
                    let second = second.kind.as_integer().unwrap();
                    if second < first {
                        self.errs.push(ValidationError::new(
                            "range end must be greater than range start".to_string(),
                            given_indices_ast[i].span,
                        ));
                        return None;
                    }
                    usize::try_from(second - first).unwrap()
                } else {
                    exp_expr_shape[i]
                }
            } else {
                1usize
            };
            // Diagonal placement requires all ranges to agree in width.
            if all_range_indices && old_dim.is_some() && dim != old_dim.unwrap() {
                self.errs.push(ValidationError::new(
                    "range indices must have the same dimension".to_string(),
                    given_indices_ast[i].span,
                ));
                return None;
            }
            old_dim = Some(dim);
            exp_expr_shape[i] = dim;
        }
        if !can_broadcast_to(&exp_expr_shape, expr_layout_to_rank.shape()) {
            self.errs.push(ValidationError::new(
                format!(
                    "cannot broadcast expression shape {} to tensor element shape {}",
                    expr_layout_to_rank.shape(),
                    exp_expr_shape
                ),
                elmt.expr.span,
            ));
            return None;
        }
        if all_range_indices && expr_layout_to_rank.kind() == &LayoutKind::Sparse {
            self.errs.push(ValidationError::new(
                "cannot use range indices with sparse expression".to_string(),
                elmt.expr.span,
            ));
            return None;
        }
        if all_range_indices && expr_layout_to_rank.kind() == &LayoutKind::Diagonal {
            self.errs.push(ValidationError::new(
                "cannot use range indices with diagonal expression".to_string(),
                elmt.expr.span,
            ));
            return None;
        }
        if all_range_indices {
            // All-range indices place the element on a diagonal.
            match Layout::new_diagonal_from(exp_expr_shape, &expr_layout_to_rank) {
                Some(layout) => layout,
                None => {
                    self.errs.push(ValidationError::new(
                        "when using all range indices, the expression layout must be scalar or 1D with dimension matching the range".to_string(),
                        elmt.expr.span,
                    ));
                    return None;
                }
            }
        } else {
            expr_layout_to_rank.broadcast_to_shape(&exp_expr_shape)
        }
    } else {
        // No explicit indices: the element layout is the expression
        // layout brought to the tensor's rank.
        expr_layout_to_rank
    };
    Some((expr_layout, elmt_layout))
}
/// The span currently being validated, if any.
pub fn current_span(&self) -> Option<StringSpan> {
    self.current_span
}
/// Sets (or clears, with `None`) the span currently being validated.
pub fn set_current_span(&mut self, current_span: Option<StringSpan>) {
    self.current_span = current_span;
}
/// The validation errors accumulated so far.
pub fn errs(&self) -> &ValidationErrors {
    &self.errs
}
/// Mutable access to the accumulated validation errors.
pub fn errs_mut(&mut self) -> &mut ValidationErrors {
    &mut self.errs
}
}
impl Default for Env {
    /// Equivalent to [`Env::new`].
    fn default() -> Self {
        Self::new()
    }
}