arithmetic_typing/types/
object.rs

1//! Object types.
2
use std::{
    collections::{hash_map::Entry, HashMap, HashSet},
    fmt,
    iter::{self, FromIterator},
};

use crate::{
    arith::Substitutions,
    error::{ErrorKind, OpErrors},
    DynConstraints, PrimitiveType, Type,
};
14
/// Object type: a collection of named fields with heterogeneous types.
///
/// # Notation
///
/// Object types are denoted using a brace notation such as `{ x: Num, y: [(Num, 'T)] }`.
/// Here, `x` and `y` are field names, and `Num` / `[(Num, 'T)]` are types of the corresponding
/// object fields.
///
/// # As constraint
///
/// Object types are *exact*; their extensions cannot be unified with the original types.
/// For example, if a function argument is `{ x: Num, y: Num }`,
/// the function cannot be called with an argument of type `{ x: Num, y: Num, z: Num }`:
///
/// ```
/// # use arithmetic_parser::grammars::{Parse, F32Grammar};
/// # use arithmetic_typing::{error::ErrorKind, Annotated, TypeEnvironment};
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let code = r#"
///     sum_coords = |pt: { x: Num, y: Num }| pt.x + pt.y;
///     sum_coords(#{ x: 3, y: 4 }); // OK
///     sum_coords(#{ x: 3, y: 4, z: 5 }); // fails
/// "#;
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
/// let err = TypeEnvironment::new().process_statements(&ast).unwrap_err();
/// # assert_eq!(err.len(), 1);
/// let err = err.iter().next().unwrap();
/// assert_matches!(err.kind(), ErrorKind::FieldsMismatch { .. });
/// # Ok(())
/// # }
/// ```
///
/// To bridge this gap, objects can be used as a constraint on types, similarly to [`Constraint`]s.
/// As a constraint, an object specifies the *necessary* fields; the constrained type
/// may have arbitrary additional fields.
///
/// Whenever possible, the type inference algorithm uses object constraints rather than
/// concrete object types:
///
/// ```
/// # use arithmetic_parser::grammars::{Parse, F32Grammar};
/// # use arithmetic_typing::{error::ErrorKind, Annotated, TypeEnvironment};
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let code = r#"
///     sum_coords = |pt| pt.x + pt.y;
///     sum_coords(#{ x: 3, y: 4 }); // OK
///     sum_coords(#{ x: 3, y: 4, z: 5 }); // also OK
/// "#;
/// let ast = Annotated::<F32Grammar>::parse_statements(code)?;
/// let mut env = TypeEnvironment::new();
/// env.process_statements(&ast)?;
/// assert_eq!(
///     env["sum_coords"].to_string(),
///     "for<'T: { x: 'U, y: 'U }, 'U: Ops> ('T) -> 'U"
/// );
/// # Ok(())
/// # }
/// ```
///
/// Note that the object constraint in this case refers to another type param, which is
/// constrained on its own!
///
/// [`Constraint`]: crate::arith::Constraint
#[derive(Debug, Clone, PartialEq)]
pub struct Object<Prim: PrimitiveType> {
    // Field types keyed by field name. Being a `HashMap`, iteration order is unspecified;
    // `Display` sorts the fields by name before printing.
    fields: HashMap<String, Type<Prim>>,
}
83
84impl<Prim: PrimitiveType> Default for Object<Prim> {
85    fn default() -> Self {
86        Self {
87            fields: HashMap::new(),
88        }
89    }
90}
91
92impl<Prim, S, V> FromIterator<(S, V)> for Object<Prim>
93where
94    Prim: PrimitiveType,
95    S: Into<String>,
96    V: Into<Type<Prim>>,
97{
98    fn from_iter<T: IntoIterator<Item = (S, V)>>(iter: T) -> Self {
99        Self {
100            fields: iter
101                .into_iter()
102                .map(|(name, ty)| (name.into(), ty.into()))
103                .collect(),
104        }
105    }
106}
107
108impl<Prim: PrimitiveType> fmt::Display for Object<Prim> {
109    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
110        let mut sorted_fields: Vec<_> = self.fields.iter().collect();
111        sorted_fields.sort_unstable_by_key(|(name, _)| *name);
112
113        formatter.write_str("{")?;
114        for (i, (name, ty)) in sorted_fields.into_iter().enumerate() {
115            write!(formatter, " {}: {}", name, ty)?;
116            if i + 1 < self.fields.len() {
117                formatter.write_str(",")?;
118            }
119        }
120        formatter.write_str(" }")
121    }
122}
123
124impl<Prim: PrimitiveType> Object<Prim> {
125    /// Creates an empty object.
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    /// Creates an object with a single field.
131    pub fn just(field: impl Into<String>, ty: impl Into<Type<Prim>>) -> Self {
132        Self {
133            fields: iter::once((field.into(), ty.into())).collect(),
134        }
135    }
136
137    pub(crate) fn from_map(fields: HashMap<String, Type<Prim>>) -> Self {
138        Self { fields }
139    }
140
141    /// Returns type of a field with the specified `name`.
142    pub fn field(&self, name: &str) -> Option<&Type<Prim>> {
143        self.fields.get(name)
144    }
145
146    /// Iterates over fields in this object.
147    pub fn iter(&self) -> impl Iterator<Item = (&str, &Type<Prim>)> + '_ {
148        self.fields.iter().map(|(name, ty)| (name.as_str(), ty))
149    }
150
151    /// Iterates over field names in this object.
152    pub fn field_names(&self) -> impl Iterator<Item = &str> + '_ {
153        self.fields.keys().map(String::as_str)
154    }
155
156    /// Converts this object into a corresponding dynamic constraint.
157    pub fn into_dyn(self) -> Type<Prim> {
158        Type::Dyn(DynConstraints::from(self))
159    }
160
161    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut Type<Prim>)> + '_ {
162        self.fields.iter_mut().map(|(name, ty)| (name.as_str(), ty))
163    }
164
165    pub(crate) fn is_concrete(&self) -> bool {
166        self.fields.values().all(Type::is_concrete)
167    }
168
169    pub(crate) fn extend_from(
170        &mut self,
171        other: Self,
172        substitutions: &mut Substitutions<Prim>,
173        mut errors: OpErrors<'_, Prim>,
174    ) {
175        for (field_name, ty) in other.fields {
176            if let Some(this_field) = self.fields.get(&field_name) {
177                substitutions.unify(this_field, &ty, errors.with_location(field_name.as_str()));
178            } else {
179                self.fields.insert(field_name, ty);
180            }
181        }
182    }
183
184    pub(crate) fn apply_as_constraint(
185        &self,
186        ty: &Type<Prim>,
187        substitutions: &mut Substitutions<Prim>,
188        mut errors: OpErrors<'_, Prim>,
189    ) {
190        let resolved_ty = if let Type::Var(var) = ty {
191            debug_assert!(var.is_free());
192            substitutions.insert_obj_constraint(var.index(), self, errors.by_ref());
193            substitutions.fast_resolve(ty)
194        } else {
195            ty
196        };
197
198        match resolved_ty {
199            Type::Object(rhs) => {
200                self.constraint_object(&rhs.clone(), substitutions, errors);
201            }
202            Type::Dyn(constraints) => {
203                if let Some(object) = constraints.inner.object.clone() {
204                    self.constraint_object(&object, substitutions, errors);
205                } else {
206                    errors.push(ErrorKind::CannotAccessFields);
207                }
208            }
209            Type::Any | Type::Var(_) => { /* OK */ }
210            _ => errors.push(ErrorKind::CannotAccessFields),
211        }
212    }
213
214    /// Places an object constraint encoded in `lhs` on a (concrete) object in `rhs`.
215    fn constraint_object(
216        &self,
217        rhs: &Object<Prim>,
218        substitutions: &mut Substitutions<Prim>,
219        mut errors: OpErrors<'_, Prim>,
220    ) {
221        let mut missing_fields = HashSet::new();
222        for (field_name, lhs_ty) in self.iter() {
223            if let Some(rhs_ty) = rhs.field(field_name) {
224                substitutions.unify(lhs_ty, rhs_ty, errors.with_location(field_name));
225            } else {
226                missing_fields.insert(field_name.to_owned());
227            }
228        }
229
230        if !missing_fields.is_empty() {
231            errors.push(ErrorKind::MissingFields {
232                fields: missing_fields,
233                available_fields: rhs.field_names().map(String::from).collect(),
234            });
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arith::Num;

    use assert_matches::assert_matches;

    /// Extracts the only error accumulated in `errors`; fails if there are zero or several.
    fn get_err(errors: OpErrors<'_, Num>) -> ErrorKind<Num> {
        let errors = errors.into_vec();
        assert_eq!(errors.len(), 1, "{:?}", errors);
        errors.into_iter().next().unwrap()
    }

    /// Builds an object from `(name, type)` pairs.
    fn obj(fields: Vec<(&str, Type<Num>)>) -> Object<Num> {
        fields.into_iter().collect()
    }

    #[test]
    fn placing_obj_constraint() {
        let constraint = obj(vec![("x", Type::NUM)]);
        let mut substitutions = Substitutions::default();

        // An object trivially satisfies itself as a constraint.
        let mut errors = OpErrors::new();
        constraint.constraint_object(&constraint, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());

        // A type var in the field gets unified with the constraint's field type.
        let with_var = obj(vec![("x", Type::free_var(0))]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&with_var, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(0)), Type::NUM);

        // Extra fields in RHS are fine.
        let with_extra = obj(vec![("x", Type::free_var(1)), ("y", Type::BOOL)]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&with_extra, &mut substitutions, errors.by_ref());
        assert!(errors.into_vec().is_empty());
        assert_eq!(*substitutions.fast_resolve(&Type::free_var(1)), Type::NUM);

        // A required field absent from RHS is reported with the list of available fields.
        let without_x = obj(vec![("y", Type::free_var(2))]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&without_x, &mut substitutions, errors.by_ref());
        assert_matches!(
            get_err(errors),
            ErrorKind::MissingFields { fields, available_fields }
                if fields.len() == 1 && fields.contains("x") &&
                available_fields.len() == 1 && available_fields.contains("y")
        );

        // A field with a non-unifiable type produces a type mismatch.
        let wrong_type = obj(vec![("x", Type::BOOL)]);
        let mut errors = OpErrors::new();
        constraint.constraint_object(&wrong_type, &mut substitutions, errors.by_ref());
        assert_matches!(
            get_err(errors),
            ErrorKind::TypeMismatch(expected, actual)
                if expected == Type::NUM && actual == Type::BOOL
        );
    }
}