1use std::{
4 collections::{HashMap, HashSet},
5 fmt,
6 iter::{self, FromIterator},
7};
8
9use crate::{
10 arith::Substitutions,
11 error::{ErrorKind, OpErrors},
12 DynConstraints, PrimitiveType, Type,
13};
14
/// Object type: a set of named fields, each with its own type.
///
/// Fields are kept in a `HashMap`, so they carry no inherent order; the `Display`
/// implementation sorts them by name to produce deterministic output.
#[derive(Debug, Clone, PartialEq)]
pub struct Object<Prim: PrimitiveType> {
    /// Field types keyed by field name.
    fields: HashMap<String, Type<Prim>>,
}
83
84impl<Prim: PrimitiveType> Default for Object<Prim> {
85 fn default() -> Self {
86 Self {
87 fields: HashMap::new(),
88 }
89 }
90}
91
92impl<Prim, S, V> FromIterator<(S, V)> for Object<Prim>
93where
94 Prim: PrimitiveType,
95 S: Into<String>,
96 V: Into<Type<Prim>>,
97{
98 fn from_iter<T: IntoIterator<Item = (S, V)>>(iter: T) -> Self {
99 Self {
100 fields: iter
101 .into_iter()
102 .map(|(name, ty)| (name.into(), ty.into()))
103 .collect(),
104 }
105 }
106}
107
108impl<Prim: PrimitiveType> fmt::Display for Object<Prim> {
109 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
110 let mut sorted_fields: Vec<_> = self.fields.iter().collect();
111 sorted_fields.sort_unstable_by_key(|(name, _)| *name);
112
113 formatter.write_str("{")?;
114 for (i, (name, ty)) in sorted_fields.into_iter().enumerate() {
115 write!(formatter, " {}: {}", name, ty)?;
116 if i + 1 < self.fields.len() {
117 formatter.write_str(",")?;
118 }
119 }
120 formatter.write_str(" }")
121 }
122}
123
impl<Prim: PrimitiveType> Object<Prim> {
    /// Creates an object with no fields.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an object with a single field `field` of type `ty`.
    pub fn just(field: impl Into<String>, ty: impl Into<Type<Prim>>) -> Self {
        Self {
            fields: iter::once((field.into(), ty.into())).collect(),
        }
    }

    /// Wraps an already-built map from field names to field types.
    pub(crate) fn from_map(fields: HashMap<String, Type<Prim>>) -> Self {
        Self { fields }
    }

    /// Returns the type of the field named `name`, or `None` if there is no such field.
    pub fn field(&self, name: &str) -> Option<&Type<Prim>> {
        self.fields.get(name)
    }

    /// Iterates over `(field_name, field_type)` pairs.
    ///
    /// The order follows the backing `HashMap` and is therefore unspecified.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
        self.fields.iter().map(|(name, ty)| (name.as_str(), ty))
    }

    /// Iterates over field names in unspecified (hash map) order.
    pub fn field_names(&self) -> impl Iterator<Item = &str> + '_ {
        self.fields.keys().map(String::as_str)
    }

    /// Converts this object into a `Type::Dyn` whose constraints are built from the object
    /// via `DynConstraints::from`.
    pub fn into_dyn(self) -> Type<Prim> {
        Type::Dyn(DynConstraints::from(self))
    }

    /// Iterates over `(field_name, field_type)` pairs with mutable access to field types.
    ///
    /// Names are yielded as `&str`, so the set of fields cannot be changed through this
    /// iterator; only the types can.
    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut Type<Prim>)> + '_ {
        self.fields.iter_mut().map(|(name, ty)| (name.as_str(), ty))
    }

    /// Returns `true` if every field type is concrete according to `Type::is_concrete`.
    /// An object without fields is vacuously concrete.
    pub(crate) fn is_concrete(&self) -> bool {
        self.fields.values().all(Type::is_concrete)
    }

    /// Merges the fields of `other` into this object.
    ///
    /// Fields present only in `other` are inserted as-is. For fields present in both
    /// objects, the two types are unified via `substitutions`, and any unification errors are
    /// reported to `errors` with the field name as the location. The existing field type in
    /// `self` is kept (it is not replaced by the type from `other`).
    pub(crate) fn extend_from(
        &mut self,
        other: Self,
        substitutions: &mut Substitutions<Prim>,
        mut errors: OpErrors<'_, Prim>,
    ) {
        for (field_name, ty) in other.fields {
            if let Some(this_field) = self.fields.get(&field_name) {
                substitutions.unify(this_field, &ty, errors.with_location(field_name.as_str()));
            } else {
                self.fields.insert(field_name, ty);
            }
        }
    }

    /// Applies this object as a constraint on `ty`.
    ///
    /// `ty` must provide at least all fields of `self`, with types unifiable with the
    /// corresponding field types in `self`; extra fields are allowed.
    ///
    /// - If `ty` is a type variable (expected to be free; checked only in debug builds), the
    ///   constraint is first recorded for that variable via `insert_obj_constraint`, and `ty`
    ///   is then resolved with `fast_resolve`.
    /// - An `Object`, or a `Dyn` type carrying an object constraint, is checked field by field.
    /// - `Any` and (still) unresolved type variables are accepted without further checks.
    /// - Any other type, including a `Dyn` type without an object constraint, produces
    ///   `ErrorKind::CannotAccessFields`.
    pub(crate) fn apply_as_constraint(
        &self,
        ty: &Type<Prim>,
        substitutions: &mut Substitutions<Prim>,
        mut errors: OpErrors<'_, Prim>,
    ) {
        let resolved_ty = if let Type::Var(var) = ty {
            debug_assert!(var.is_free());
            // Record the constraint on the variable *before* resolving it, so that the
            // resolution below sees the state after the constraint was inserted.
            substitutions.insert_obj_constraint(var.index(), self, errors.by_ref());
            substitutions.fast_resolve(ty)
        } else {
            ty
        };

        match resolved_ty {
            Type::Object(rhs) => {
                // `resolved_ty` may be a reference obtained from `substitutions.fast_resolve`;
                // presumably the clone is what releases that borrow so `substitutions` can be
                // passed mutably below.
                self.constraint_object(&rhs.clone(), substitutions, errors);
            }
            Type::Dyn(constraints) => {
                if let Some(object) = constraints.inner.object.clone() {
                    self.constraint_object(&object, substitutions, errors);
                } else {
                    errors.push(ErrorKind::CannotAccessFields);
                }
            }
            Type::Any | Type::Var(_) => { /* no constraint to check here */ }
            _ => errors.push(ErrorKind::CannotAccessFields),
        }
    }

    /// Checks that the object `rhs` satisfies the object constraint encoded in `self`.
    ///
    /// Every field of `self` that also exists in `rhs` is unified with the `rhs` field type
    /// (errors are reported with the field name as the location). Extra fields in `rhs` are
    /// allowed. If some fields of `self` are absent from `rhs`, a single
    /// `ErrorKind::MissingFields` error is pushed listing all missing fields together with
    /// the fields `rhs` does have.
    fn constraint_object(
        &self,
        rhs: &Object<Prim>,
        substitutions: &mut Substitutions<Prim>,
        mut errors: OpErrors<'_, Prim>,
    ) {
        let mut missing_fields = HashSet::new();
        for (field_name, lhs_ty) in self.iter() {
            if let Some(rhs_ty) = rhs.field(field_name) {
                substitutions.unify(lhs_ty, rhs_ty, errors.with_location(field_name));
            } else {
                missing_fields.insert(field_name.to_owned());
            }
        }

        if !missing_fields.is_empty() {
            errors.push(ErrorKind::MissingFields {
                fields: missing_fields,
                available_fields: rhs.field_names().map(String::from).collect(),
            });
        }
    }
}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arith::Num;

    use assert_matches::assert_matches;

    /// Applies `lhs` as an object constraint to `rhs` and returns every reported error.
    fn constraint_errors(
        lhs: &Object<Num>,
        rhs: &Object<Num>,
        substitutions: &mut Substitutions<Num>,
    ) -> Vec<ErrorKind<Num>> {
        let mut errors = OpErrors::new();
        lhs.constraint_object(rhs, substitutions, errors.by_ref());
        errors.into_vec()
    }

    /// Asserts that exactly one error was reported and returns it.
    fn single_err(mut errors: Vec<ErrorKind<Num>>) -> ErrorKind<Num> {
        assert_eq!(errors.len(), 1, "{:?}", errors);
        errors.pop().unwrap()
    }

    #[test]
    fn placing_obj_constraint() {
        let lhs: Object<Num> = vec![("x", Type::NUM)].into_iter().collect();
        let mut substitutions = Substitutions::default();

        // An object trivially satisfies itself as a constraint.
        assert!(constraint_errors(&lhs, &lhs, &mut substitutions).is_empty());

        // A matching field typed with a free variable gets unified with the constraint type.
        let var_rhs: Object<Num> = vec![("x", Type::free_var(0))].into_iter().collect();
        assert!(constraint_errors(&lhs, &var_rhs, &mut substitutions).is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(0)), Type::NUM);

        // Extra fields on the constrained object are allowed.
        let extra_rhs: Object<Num> = vec![("x", Type::free_var(1)), ("y", Type::BOOL)]
            .into_iter()
            .collect();
        assert!(constraint_errors(&lhs, &extra_rhs, &mut substitutions).is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(1)), Type::NUM);

        // A missing field is reported together with the fields that are available.
        let missing_field_rhs: Object<Num> = vec![("y", Type::free_var(2))].into_iter().collect();
        let missing_err = single_err(constraint_errors(
            &lhs,
            &missing_field_rhs,
            &mut substitutions,
        ));
        assert_matches!(
            missing_err,
            ErrorKind::MissingFields { fields, available_fields }
                if fields.len() == 1 && fields.contains("x") &&
                    available_fields.len() == 1 && available_fields.contains("y")
        );

        // A field with an incompatible type produces a type mismatch.
        let incompatible_field_rhs: Object<Num> = vec![("x", Type::BOOL)].into_iter().collect();
        let mismatch_err = single_err(constraint_errors(
            &lhs,
            &incompatible_field_rhs,
            &mut substitutions,
        ));
        assert_matches!(
            mismatch_err,
            ErrorKind::TypeMismatch(lhs, rhs) if lhs == Type::NUM && rhs == Type::BOOL
        );
    }
}