use thiserror::Error;
use crate::Statement;
use crate::type_variable::TypeVariable;
use crate::typed_ast::{
DType, DTypeFactor, DefineVariable, Expression, StructInfo, StructKind, Type,
};
/// A substitution: an ordered list of `type variable ↦ type` bindings.
///
/// When the same variable is bound more than once, [`Substitution::lookup`] returns the
/// binding that appears first in the list.
#[derive(Debug, Clone)]
pub struct Substitution(pub Vec<(TypeVariable, Type)>);

impl Substitution {
    /// Returns a substitution without any bindings.
    pub fn empty() -> Substitution {
        Substitution(Vec::new())
    }

    /// Returns a substitution containing exactly one binding, `v ↦ t`.
    pub fn single(v: TypeVariable, t: Type) -> Substitution {
        let binding = (v, t);
        Substitution(vec![binding])
    }

    /// Returns the type bound to `v`, or `None` if `v` has no binding.
    ///
    /// Bindings are searched front to back; the first match wins.
    pub fn lookup(&self, v: &TypeVariable) -> Option<&Type> {
        for (var, bound) in &self.0 {
            if var == v {
                return Some(bound);
            }
        }
        None
    }

    /// Merges `other` into this substitution.
    ///
    /// `other` is first applied to the right-hand side of every existing binding. After that,
    /// all of `other`'s bindings are appended. Because [`Substitution::lookup`] returns the first
    /// match, bindings already present in `self` take precedence over bindings for the same
    /// variable coming from `other`.
    ///
    /// # Panics
    ///
    /// Panics if applying `other` to an existing binding returns a [`SubstitutionError`].
    pub fn extend(&mut self, other: Substitution) {
        self.0
            .iter_mut()
            .for_each(|(_, bound)| bound.apply(&other).unwrap());
        self.0.extend(other.0);
    }

    /// Adds the binding `v ↦ t`, merging it in the same way as [`Substitution::extend`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Substitution::extend`].
    pub fn append(&mut self, v: TypeVariable, t: Type) {
        let addition = Substitution::single(v, t);
        self.extend(addition);
    }
}
/// Errors that can occur while applying a [`Substitution`].
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum SubstitutionError {
    /// A type variable (or type parameter) occurring inside a dimension type was bound to a
    /// type that is neither a dimension type nor a type variable. The payload is that
    /// offending bound type.
    #[error("Used non-dimension type '{0}' in a dimension expression")]
    SubstitutedNonDTypeWithinDType(Type),
}
/// Values that contain types which can be rewritten in place by a [`Substitution`].
pub trait ApplySubstitution {
    /// Applies `substitution` to `self` in place.
    ///
    /// # Errors
    ///
    /// Returns a [`SubstitutionError`] if a binding cannot be applied.
    /// NOTE(review): implementations return early on the first error, so `self` may be
    /// partially rewritten when an error is returned — callers should not rely on `self`
    /// being unchanged.
    fn apply(&mut self, substitution: &Substitution) -> Result<(), SubstitutionError>;
}
impl ApplySubstitution for Type {
    /// Applies `s` to this type in place.
    ///
    /// * A type variable, or a type parameter `n` (looked up as `TypeVariable::new(n)`), is
    ///   replaced wholesale by its binding, if it has one.
    /// * A dimension type that consists of a single type variable is also replaced wholesale,
    ///   so its binding does not have to be a dimension type. Every other dimension type is
    ///   substituted factor by factor via the [`DType`] implementation.
    /// * Function, struct and list types are traversed recursively. Booleans, strings and
    ///   date-times contain nothing to substitute and are left unchanged.
    ///
    /// # Errors
    ///
    /// Propagates [`SubstitutionError::SubstitutedNonDTypeWithinDType`] from nested dimension
    /// types. Components visited before the failing one may already have been rewritten.
    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
        match self {
            Type::TVar(v) => {
                if let Some(type_) = s.lookup(v) {
                    *self = type_.clone();
                }
                Ok(())
            }
            Type::TPar(n) => {
                if let Some(type_) = s.lookup(&TypeVariable::new(n)) {
                    *self = type_.clone();
                }
                Ok(())
            }
            // Deconstruct once. This replaces the previous pattern of calling the helper in a
            // match guard and then again (with `unwrap`) inside the arm.
            Type::Dimension(dtype) => match dtype.deconstruct_as_single_type_variable() {
                // A lone type variable is replaced by whatever it is bound to, which may be a
                // non-dimension type.
                Some(v) => {
                    if let Some(type_) = s.lookup(&v) {
                        *self = type_.clone();
                    }
                    Ok(())
                }
                None => dtype.apply(s),
            },
            Type::Boolean | Type::String | Type::DateTime => Ok(()),
            Type::Fn(param_types, return_type) => {
                for param_type in param_types {
                    param_type.apply(s)?;
                }
                return_type.apply(s)
            }
            // Identical traversal to the `StructInfo` impl; delegate so the two cannot drift.
            Type::Struct(info) => info.apply(s),
            Type::List(element_type) => element_type.apply(s),
        }
    }
}
impl ApplySubstitution for DType {
fn apply(&mut self, substitution: &Substitution) -> Result<(), SubstitutionError> {
let mut new_dtype = self.clone();
for (f, power) in self.factors() {
match f {
DTypeFactor::TVar(tv) => {
if let Some(type_) = substitution.lookup(tv) {
let dtype = match type_ {
Type::Dimension(dt) => dt.clone(),
Type::TVar(tv) => DType::from_type_variable(tv.clone()),
t => {
return Err(SubstitutionError::SubstitutedNonDTypeWithinDType(
t.clone(),
));
}
};
new_dtype =
new_dtype.divide(&DType::from_type_variable(tv.clone()).power(*power));
new_dtype = new_dtype.multiply(&dtype.power(*power));
}
}
DTypeFactor::TPar(name) => {
let tv = TypeVariable::new(name);
if let Some(type_) = substitution.lookup(&tv) {
let dtype = match type_ {
Type::Dimension(dt) => dt.clone(),
Type::TVar(tv) => DType::from_type_variable(tv.clone()),
t => {
return Err(SubstitutionError::SubstitutedNonDTypeWithinDType(
t.clone(),
));
}
};
new_dtype = new_dtype
.divide(&DType::from_type_parameter(name.clone()).power(*power));
new_dtype = new_dtype.multiply(&dtype.power(*power));
}
}
DTypeFactor::BaseDimension(_) => {}
}
}
*self = new_dtype;
Ok(())
}
}
impl ApplySubstitution for StructInfo {
    /// Applies `s` to the type arguments of a [`StructKind::Instance`] and then to the type
    /// of every field.
    ///
    /// # Errors
    ///
    /// Returns the first [`SubstitutionError`] encountered. Type arguments and fields visited
    /// before it may already have been rewritten.
    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
        if let StructKind::Instance(type_args) = &mut self.kind {
            type_args.iter_mut().try_for_each(|arg| arg.apply(s))?;
        }

        self.fields
            .values_mut()
            .try_for_each(|(_, field_type)| field_type.apply(s))
    }
}
impl ApplySubstitution for Expression<'_> {
    /// Applies `s` to every type annotation stored in this expression tree.
    ///
    /// For each node, sub-expressions are visited first (in source order) and the node's own
    /// type information last.
    ///
    /// # Errors
    ///
    /// The first [`SubstitutionError`] aborts the walk. Nodes visited earlier may already have
    /// been rewritten.
    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
        match self {
            // Literals without any type information.
            Expression::Boolean(_, _) | Expression::String(_, _) => Ok(()),

            // Leaves: only the node's own type needs rewriting.
            Expression::Scalar { type_scheme, .. } => type_scheme.apply(s),
            Expression::Identifier { type_scheme, .. } => type_scheme.apply(s),
            Expression::UnitIdentifier { type_scheme, .. } => type_scheme.apply(s),
            Expression::TypedHole(_, type_) => type_.apply(s),

            // Operators: operands first, then the result type.
            Expression::UnaryOperator {
                expr, type_scheme, ..
            } => {
                expr.apply(s)?;
                type_scheme.apply(s)
            }
            Expression::BinaryOperator {
                lhs,
                rhs,
                type_scheme,
                ..
            } => {
                lhs.apply(s)?;
                rhs.apply(s)?;
                type_scheme.apply(s)
            }
            Expression::BinaryOperatorForDate {
                lhs,
                rhs,
                type_scheme,
                ..
            } => {
                lhs.apply(s)?;
                rhs.apply(s)?;
                type_scheme.apply(s)
            }

            // Calls: callee (if any), then arguments, then the result type.
            Expression::FunctionCall {
                args, type_scheme, ..
            } => {
                args.iter_mut().try_for_each(|arg| arg.apply(s))?;
                type_scheme.apply(s)
            }
            Expression::CallableCall {
                callable,
                args,
                type_scheme,
                ..
            } => {
                callable.apply(s)?;
                args.iter_mut().try_for_each(|arg| arg.apply(s))?;
                type_scheme.apply(s)
            }

            // Composite expressions.
            Expression::Condition {
                condition,
                then_expr,
                else_expr,
                ..
            } => {
                condition.apply(s)?;
                then_expr.apply(s)?;
                else_expr.apply(s)
            }
            Expression::InstantiateStruct {
                fields,
                struct_info,
                ..
            } => {
                fields
                    .iter_mut()
                    .try_for_each(|(_, field_expr)| field_expr.apply(s))?;
                struct_info.apply(s)
            }
            Expression::AccessField {
                expr,
                struct_type,
                field_type,
                ..
            } => {
                expr.apply(s)?;
                struct_type.apply(s)?;
                field_type.apply(s)
            }
            Expression::List {
                elements,
                type_scheme,
                ..
            } => {
                elements
                    .iter_mut()
                    .try_for_each(|element| element.apply(s))?;
                type_scheme.apply(s)
            }
        }
    }
}
impl ApplySubstitution for Statement<'_> {
    /// Applies `s` to every expression and type annotation contained in this statement.
    ///
    /// Contained expressions are visited before the statement's own type information.
    /// Dimension definitions carry nothing to substitute.
    ///
    /// # Errors
    ///
    /// Returns the first [`SubstitutionError`] encountered. Parts visited earlier may already
    /// have been rewritten.
    fn apply(&mut self, s: &Substitution) -> Result<(), SubstitutionError> {
        match self {
            Statement::DefineDimension(_, _) => Ok(()),
            Statement::Expression(e) => e.apply(s),
            Statement::DefineVariable(DefineVariable {
                expr, type_scheme, ..
            }) => {
                expr.apply(s)?;
                type_scheme.apply(s)
            }
            Statement::DefineFunction {
                body,
                local_variables,
                fn_type,
                ..
            } => {
                // Each local variable: its defining expression, then its type scheme.
                local_variables.iter_mut().try_for_each(|local| {
                    local.expr.apply(s)?;
                    local.type_scheme.apply(s)
                })?;
                // The body is optional; `iter_mut` yields it zero or one times.
                body.iter_mut()
                    .try_for_each(|function_body| function_body.apply(s))?;
                fn_type.apply(s)
            }
            Statement::DefineBaseUnit { type_scheme, .. } => type_scheme.apply(s),
            Statement::DefineDerivedUnit {
                expr, type_scheme, ..
            } => {
                expr.apply(s)?;
                type_scheme.apply(s)
            }
            Statement::ProcedureCall { args, .. } => {
                args.iter_mut().try_for_each(|arg| arg.apply(s))
            }
            Statement::DefineStruct(info) => info.apply(s),
        }
    }
}