use std::{collections::HashMap, iter::FromIterator, ops};
use crate::{
arith::{
Constraint, ConstraintSet, MapPrimitiveType, Num, NumArithmetic, ObjectSafeConstraint,
Substitutions, TypeArithmetic,
},
ast::TypeAst,
error::Errors,
types::{ParamConstraints, ParamQuantifier},
PrimitiveType, Type,
};
use arithmetic_parser::{grammars::Grammar, Block};
mod processor;
use self::processor::TypeProcessor;
/// Environment that associates variable names with their types.
///
/// Besides the name-to-type map, the environment owns a [`Substitutions`] value and
/// a [`ConstraintSet`]; both are crate-visible. The primitive type parameter `Prim`
/// defaults to [`Num`].
#[derive(Debug, Clone)]
pub struct TypeEnvironment<Prim: PrimitiveType = Num> {
/// Substitution state, starting out as `Substitutions::default()`.
/// NOTE(review): presumably updated by `TypeProcessor` during inference — confirm.
pub(crate) substitutions: Substitutions<Prim>,
/// Constraints known to this environment; seeded from `Prim::well_known_constraints()`
/// and extended via `insert_constraint` / `insert_object_safe_constraint`.
pub(crate) known_constraints: ConstraintSet<Prim>,
/// Types of variables keyed by variable name.
variables: HashMap<String, Type<Prim>>,
}
/// The default environment defines no variables, carries the constraints returned by
/// `Prim::well_known_constraints()`, and uses default substitutions.
impl<Prim: PrimitiveType> Default for TypeEnvironment<Prim> {
    fn default() -> Self {
        let variables = HashMap::new();
        let known_constraints = Prim::well_known_constraints();
        let substitutions = Substitutions::default();

        Self {
            substitutions,
            known_constraints,
            variables,
        }
    }
}
impl<Prim: PrimitiveType> TypeEnvironment<Prim> {
    /// Creates a new environment; equivalent to [`Default::default()`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Looks up the type of the variable called `name`.
    ///
    /// Returns `None` if no variable with this name has been defined.
    pub fn get(&self, name: &str) -> Option<&Type<Prim>> {
        let vars = &self.variables;
        vars.get(name)
    }

    /// Iterates over `(name, type)` pairs of all variables in the environment.
    ///
    /// The pairs come straight from the underlying `HashMap`, so their order
    /// is unspecified.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
        let entries = self.variables.iter();
        entries.map(|(key, value)| (key.as_str(), value))
    }

    /// Converts `ty` into a [`Type`] and readies it for storage in the environment.
    ///
    /// If the type is a function whose `params` are not yet set, they are initialized
    /// via `ParamQuantifier::set_params` with default parameter constraints.
    ///
    /// # Panics
    ///
    /// Panics if the converted type is not concrete (per `Type::is_concrete`).
    fn prepare_type(ty: impl Into<Type<Prim>>) -> Type<Prim> {
        let mut prepared: Type<Prim> = ty.into();
        assert!(prepared.is_concrete(), "Type {} is not concrete", prepared);

        match &mut prepared {
            // Only functions that lack parameter info need to be touched.
            Type::Function(function) if function.params.is_none() => {
                ParamQuantifier::set_params(function, ParamConstraints::default());
            }
            _ => {}
        }
        prepared
    }

    /// Sets the type of the variable `name`, replacing a previous definition if any.
    ///
    /// Returns `self` to allow chaining.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not concrete; see the assertion in `prepare_type`.
    pub fn insert(&mut self, name: &str, ty: impl Into<Type<Prim>>) -> &mut Self {
        let key = name.to_owned();
        let value = Self::prepare_type(ty);
        self.variables.insert(key, value);
        self
    }

    /// Adds `constraint` to the set of constraints known to this environment.
    ///
    /// Returns `self` to allow chaining.
    pub fn insert_constraint(&mut self, constraint: impl Constraint<Prim>) -> &mut Self {
        let constraints = &mut self.known_constraints;
        constraints.insert(constraint);
        self
    }

    /// Adds an object-safe `constraint` to the set of constraints known to this
    /// environment.
    ///
    /// Returns `self` to allow chaining.
    pub fn insert_object_safe_constraint(
        &mut self,
        constraint: impl ObjectSafeConstraint<Prim>,
    ) -> &mut Self {
        let constraints = &mut self.known_constraints;
        constraints.insert_object_safe(constraint);
        self
    }

    /// Processes the statements of `block` using the arithmetic produced by
    /// `NumArithmetic::without_comparisons()`.
    ///
    /// This is a shortcut for [`Self::process_with_arithmetic()`].
    ///
    /// # Errors
    ///
    /// Returns the [`Errors`] reported by the type processor.
    pub fn process_statements<'a, T>(
        &mut self,
        block: &Block<'a, T>,
    ) -> Result<Type<Prim>, Errors<'a, Prim>>
    where
        T: Grammar<'a, Type = TypeAst<'a>>,
        NumArithmetic: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
    {
        let arithmetic = NumArithmetic::without_comparisons();
        self.process_with_arithmetic(&arithmetic, block)
    }

    /// Processes the statements of `block` using the supplied `arithmetic`.
    ///
    /// The work is delegated to a `TypeProcessor` created over this environment.
    /// NOTE(review): the returned type is presumably the type of the block's
    /// return value — confirm against `TypeProcessor::process_statements`.
    ///
    /// # Errors
    ///
    /// Returns the [`Errors`] reported by the type processor.
    pub fn process_with_arithmetic<'a, T, A>(
        &mut self,
        arithmetic: &A,
        block: &Block<'a, T>,
    ) -> Result<Type<Prim>, Errors<'a, Prim>>
    where
        T: Grammar<'a, Type = TypeAst<'a>>,
        A: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
    {
        TypeProcessor::new(self, arithmetic).process_statements(block)
    }
}
/// Indexing by variable name; a panicking counterpart of [`TypeEnvironment::get()`].
impl<Prim: PrimitiveType> ops::Index<&str> for TypeEnvironment<Prim> {
    type Output = Type<Prim>;

    /// Returns the type of the variable `name`.
    ///
    /// # Panics
    ///
    /// Panics if the variable is not defined in the environment.
    fn index(&self, name: &str) -> &Self::Output {
        match self.get(name) {
            Some(ty) => ty,
            None => panic!("Variable `{}` is not defined", name),
        }
    }
}
/// Lazily maps `(name, type)` pairs into the `(String, Type<Prim>)` entries stored
/// by [`TypeEnvironment`].
///
/// Each type is passed through `TypeEnvironment::prepare_type`, so advancing the
/// returned iterator panics on a type that is not concrete.
fn convert_iter<Prim: PrimitiveType, S, Ty, I>(
    iter: I,
) -> impl Iterator<Item = (String, Type<Prim>)>
where
    I: IntoIterator<Item = (S, Ty)>,
    S: Into<String>,
    Ty: Into<Type<Prim>>,
{
    let pairs = iter.into_iter();
    pairs.map(|(raw_name, raw_ty)| {
        let name: String = raw_name.into();
        let ty = TypeEnvironment::<Prim>::prepare_type(raw_ty);
        (name, ty)
    })
}
impl<Prim: PrimitiveType, S, Ty> FromIterator<(S, Ty)> for TypeEnvironment<Prim>
where
S: Into<String>,
Ty: Into<Type<Prim>>,
{
fn from_iter<I: IntoIterator<Item = (S, Ty)>>(iter: I) -> Self {
Self {
variables: convert_iter(iter).collect(),
known_constraints: Prim::well_known_constraints(),
substitutions: Substitutions::default(),
}
}
}
/// Adds `(name, type)` pairs to an existing environment, overwriting the types of
/// variables that are already defined.
///
/// Types are prepared via `convert_iter`, which panics on a type that is not concrete.
impl<Prim: PrimitiveType, S, Ty> Extend<(S, Ty)> for TypeEnvironment<Prim>
where
    S: Into<String>,
    Ty: Into<Type<Prim>>,
{
    fn extend<I: IntoIterator<Item = (S, Ty)>>(&mut self, iter: I) {
        let prepared = convert_iter(iter);
        self.variables.extend(prepared);
    }
}
/// Shorthand for the pair of bounds `MapPrimitiveType<Val, Prim = Prim>` and
/// `TypeArithmetic<Prim>`; the trait declares no items of its own.
trait FullArithmetic<Val, Prim>: MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
where
    Prim: PrimitiveType,
{
}
/// Blanket implementation: every type satisfying both supertrait bounds is a
/// `FullArithmetic`.
impl<Val, Prim, T> FullArithmetic<Val, Prim> for T
where
    Prim: PrimitiveType,
    T: MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>,
{
}