1use std::{collections::HashMap, iter::FromIterator, ops};
4
5use crate::{
6 arith::{
7 Constraint, ConstraintSet, MapPrimitiveType, Num, NumArithmetic, ObjectSafeConstraint,
8 Substitutions, TypeArithmetic,
9 },
10 ast::TypeAst,
11 error::Errors,
12 types::{ParamConstraints, ParamQuantifier},
13 PrimitiveType, Type,
14};
15use arithmetic_parser::{grammars::Grammar, Block};
16
17mod processor;
18
19use self::processor::TypeProcessor;
20
21#[derive(Debug, Clone)]
50pub struct TypeEnvironment<Prim: PrimitiveType = Num> {
51 pub(crate) substitutions: Substitutions<Prim>,
52 pub(crate) known_constraints: ConstraintSet<Prim>,
53 variables: HashMap<String, Type<Prim>>,
54}
55
56impl<Prim: PrimitiveType> Default for TypeEnvironment<Prim> {
57 fn default() -> Self {
58 Self {
59 variables: HashMap::new(),
60 known_constraints: Prim::well_known_constraints(),
61 substitutions: Substitutions::default(),
62 }
63 }
64}
65
66impl<Prim: PrimitiveType> TypeEnvironment<Prim> {
67 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn get(&self, name: &str) -> Option<&Type<Prim>> {
74 self.variables.get(name)
75 }
76
77 pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
79 self.variables.iter().map(|(name, ty)| (name.as_str(), ty))
80 }
81
82 fn prepare_type(ty: impl Into<Type<Prim>>) -> Type<Prim> {
83 let mut ty = ty.into();
84 assert!(ty.is_concrete(), "Type {} is not concrete", ty);
85
86 if let Type::Function(function) = &mut ty {
87 if function.params.is_none() {
88 ParamQuantifier::set_params(function, ParamConstraints::default());
89 }
90 }
91 ty
92 }
93
94 pub fn insert(&mut self, name: &str, ty: impl Into<Type<Prim>>) -> &mut Self {
101 self.variables
102 .insert(name.to_owned(), Self::prepare_type(ty));
103 self
104 }
105
106 pub fn insert_constraint(&mut self, constraint: impl Constraint<Prim>) -> &mut Self {
112 self.known_constraints.insert(constraint);
113 self
114 }
115
116 pub fn insert_object_safe_constraint(
122 &mut self,
123 constraint: impl ObjectSafeConstraint<Prim>,
124 ) -> &mut Self {
125 self.known_constraints.insert_object_safe(constraint);
126 self
127 }
128
129 pub fn process_statements<'a, T>(
135 &mut self,
136 block: &Block<'a, T>,
137 ) -> Result<Type<Prim>, Errors<'a, Prim>>
138 where
139 T: Grammar<'a, Type = TypeAst<'a>>,
140 NumArithmetic: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
141 {
142 self.process_with_arithmetic(&NumArithmetic::without_comparisons(), block)
143 }
144
145 pub fn process_with_arithmetic<'a, T, A>(
154 &mut self,
155 arithmetic: &A,
156 block: &Block<'a, T>,
157 ) -> Result<Type<Prim>, Errors<'a, Prim>>
158 where
159 T: Grammar<'a, Type = TypeAst<'a>>,
160 A: MapPrimitiveType<T::Lit, Prim = Prim> + TypeArithmetic<Prim>,
161 {
162 TypeProcessor::new(self, arithmetic).process_statements(block)
163 }
164}
165
166impl<Prim: PrimitiveType> ops::Index<&str> for TypeEnvironment<Prim> {
167 type Output = Type<Prim>;
168
169 fn index(&self, name: &str) -> &Self::Output {
170 self.get(name)
171 .unwrap_or_else(|| panic!("Variable `{}` is not defined", name))
172 }
173}
174
175fn convert_iter<Prim: PrimitiveType, S, Ty, I>(
176 iter: I,
177) -> impl Iterator<Item = (String, Type<Prim>)>
178where
179 I: IntoIterator<Item = (S, Ty)>,
180 S: Into<String>,
181 Ty: Into<Type<Prim>>,
182{
183 iter.into_iter()
184 .map(|(name, ty)| (name.into(), TypeEnvironment::prepare_type(ty)))
185}
186
187impl<Prim: PrimitiveType, S, Ty> FromIterator<(S, Ty)> for TypeEnvironment<Prim>
188where
189 S: Into<String>,
190 Ty: Into<Type<Prim>>,
191{
192 fn from_iter<I: IntoIterator<Item = (S, Ty)>>(iter: I) -> Self {
193 Self {
194 variables: convert_iter(iter).collect(),
195 known_constraints: Prim::well_known_constraints(),
196 substitutions: Substitutions::default(),
197 }
198 }
199}
200
201impl<Prim: PrimitiveType, S, Ty> Extend<(S, Ty)> for TypeEnvironment<Prim>
202where
203 S: Into<String>,
204 Ty: Into<Type<Prim>>,
205{
206 fn extend<I: IntoIterator<Item = (S, Ty)>>(&mut self, iter: I) {
207 self.variables.extend(convert_iter(iter))
208 }
209}
210
211trait FullArithmetic<Val, Prim: PrimitiveType>:
213 MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
214{
215}
216
217impl<Val, Prim: PrimitiveType, T> FullArithmetic<Val, Prim> for T where
218 T: MapPrimitiveType<Val, Prim = Prim> + TypeArithmetic<Prim>
219{
220}