arithmetic_typing/arith/
constraints.rs

//! [`Constraint`] trait, constraint sets and built-in constraint implementations.
2
3use std::{collections::HashMap, fmt, marker::PhantomData};
4
5use crate::{
6    arith::Substitutions,
7    error::{ErrorKind, OpErrors},
8    visit::{self, Visit},
9    Function, Object, PrimitiveType, Slice, Tuple, Type, TypeVar,
10};
11
12/// Constraint that can be placed on [`Type`]s.
13///
14/// Constraints can be placed on [`Function`] type variables, and can be applied
15/// to types in [`TypeArithmetic`] impls. For example, [`NumArithmetic`] places
16/// the [`Linearity`] constraint on types involved in arithmetic ops.
17///
18/// The constraint mechanism is similar to trait constraints in Rust, but is much more limited:
19///
20/// - Constraints cannot be parametric (cf. parameters in traits, such `AsRef<_>`
21///   or `Iterator<Item = _>`).
22/// - Constraints are applied to types in separation; it is impossible to create a constraint
23///   involving several type variables.
24/// - Constraints cannot contradict each other.
25///
26/// # Implementation rules
27///
28/// - [`Display`](fmt::Display) must display constraint as an identifier (e.g., `Lin`).
29///   The string presentation of a constraint must be unique within a [`PrimitiveType`];
30///   it is used to identify constraints in a [`ConstraintSet`].
31///
32/// [`TypeArithmetic`]: crate::arith::TypeArithmetic
33/// [`NumArithmetic`]: crate::arith::NumArithmetic
34pub trait Constraint<Prim: PrimitiveType>: fmt::Display + Send + Sync + 'static {
35    /// Returns a [`Visit`]or that will be applied to constrained [`Type`]s. The visitor
36    /// may use `substitutions` to resolve types and `errors` to record constraint errors.
37    ///
38    /// # Tips
39    ///
40    /// - You can use [`StructConstraint`] for typical use cases, which involve recursively
41    ///   traversing `ty`.
42    fn visitor<'r>(
43        &self,
44        substitutions: &'r mut Substitutions<Prim>,
45        errors: OpErrors<'r, Prim>,
46    ) -> Box<dyn Visit<Prim> + 'r>;
47
48    /// Clones this constraint into a `Box`.
49    ///
50    /// This method should be implemented by implementing [`Clone`] and boxing its output.
51    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>>;
52}
53
54impl<Prim: PrimitiveType> fmt::Debug for dyn Constraint<Prim> {
55    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
56        formatter
57            .debug_tuple("dyn Constraint")
58            .field(&self.to_string())
59            .finish()
60    }
61}
62
63impl<Prim: PrimitiveType> Clone for Box<dyn Constraint<Prim>> {
64    fn clone(&self) -> Self {
65        self.clone_boxed()
66    }
67}
68
69/// Marker trait for object-safe constraints, i.e., constraints that can be included
70/// into a [`DynConstraints`](crate::DynConstraints).
71///
72/// Object safety is similar to this notion in Rust. For a constraint `C` to be object-safe,
73/// it should be the case that `dyn C` (the untagged union of all types implementing `C`)
74/// implements `C`. As an example, this is the case for [`Linearity`], but is not the case
75/// for [`Ops`]. Indeed, [`Ops`] requires the type to be addable to itself,
76/// which would be impossible for `dyn Ops`.
77pub trait ObjectSafeConstraint<Prim: PrimitiveType>: Constraint<Prim> {}
78
79/// Helper to define *structural* [`Constraint`]s, i.e., constraints recursively checking
80/// the provided type.
81///
82/// The following logic is used to check whether a type satisfies the constraint:
83///
84/// - Primitive types satisfy the constraint iff the predicate provided in [`Self::new()`]
85///   returns `true`.
86/// - [`Type::Any`] always satisfies the constraint.
87/// - [`Type::Dyn`] types satisfy the constraint iff the [`Constraint`] wrapped by this helper
88///   is present among [`DynConstraints`](crate::DynConstraints). Thus,
89///   if the wrapped constraint is not [object-safe](ObjectSafeConstraint), it will not be satisfied
90///   by any `Dyn` type.
91/// - Functional types never satisfy the constraint.
92/// - A compound type (i.e., a tuple) satisfies the constraint iff all its items satisfy
93///   the constraint.
94/// - If [`Self::deny_dyn_slices()`] is set, tuple types need to have static length.
95///
96/// # Examples
97///
98/// Defining a constraint type using `StructConstraint`:
99///
100/// ```
101/// # use arithmetic_typing::{
102/// #     arith::{Constraint, StructConstraint, Substitutions}, error::OpErrors, visit::Visit,
103/// #     PrimitiveType, Type,
104/// # };
105/// # use std::fmt;
106/// /// Constraint for hashable types.
107/// #[derive(Clone, Copy)]
108/// struct Hashed;
109///
110/// impl fmt::Display for Hashed {
111///     fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
112///         formatter.write_str("Hash")
113///     }
114/// }
115///
116/// impl<Prim: PrimitiveType> Constraint<Prim> for Hashed {
117///     fn visitor<'r>(
118///         &self,
119///         substitutions: &'r mut Substitutions<Prim>,
120///         errors: OpErrors<'r, Prim>,
121///     ) -> Box<dyn Visit<Prim> + 'r> {
122///         // We can hash everything except for functions (and thus,
123///         // types containing functions).
124///         StructConstraint::new(*self, |_| true)
125///             .visitor(substitutions, errors)
126///     }
127///
128///     fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
129///         Box::new(*self)
130///     }
131/// }
132/// ```
133#[derive(Debug)]
134pub struct StructConstraint<Prim, C, F> {
135    constraint: C,
136    predicate: F,
137    deny_dyn_slices: bool,
138    _prim: PhantomData<Prim>,
139}
140
141impl<Prim, C, F> StructConstraint<Prim, C, F>
142where
143    Prim: PrimitiveType,
144    C: Constraint<Prim> + Clone,
145    F: Fn(&Prim) -> bool + 'static,
146{
147    /// Creates a new helper. `predicate` determines whether a particular primitive type
148    /// should satisfy the `constraint`.
149    pub fn new(constraint: C, predicate: F) -> Self {
150        Self {
151            constraint,
152            predicate,
153            deny_dyn_slices: false,
154            _prim: PhantomData,
155        }
156    }
157
158    /// Marks that dynamically sized slices should fail the constraint check.
159    pub fn deny_dyn_slices(mut self) -> Self {
160        self.deny_dyn_slices = true;
161        self
162    }
163
164    /// Returns a [`Visit`]or that can be used for [`Constraint::visitor()`] implementations.
165    pub fn visitor<'r>(
166        self,
167        substitutions: &'r mut Substitutions<Prim>,
168        errors: OpErrors<'r, Prim>,
169    ) -> Box<dyn Visit<Prim> + 'r> {
170        Box::new(StructConstraintVisitor {
171            inner: self,
172            substitutions,
173            errors,
174        })
175    }
176}
177
178#[derive(Debug)]
179struct StructConstraintVisitor<'r, Prim: PrimitiveType, C, F> {
180    inner: StructConstraint<Prim, C, F>,
181    substitutions: &'r mut Substitutions<Prim>,
182    errors: OpErrors<'r, Prim>,
183}
184
185impl<'r, Prim, C, F> Visit<Prim> for StructConstraintVisitor<'r, Prim, C, F>
186where
187    Prim: PrimitiveType,
188    C: Constraint<Prim> + Clone,
189    F: Fn(&Prim) -> bool + 'static,
190{
191    fn visit_type(&mut self, ty: &Type<Prim>) {
192        match ty {
193            Type::Dyn(constraints) => {
194                if !constraints.inner.simple.contains(&self.inner.constraint) {
195                    self.errors.push(ErrorKind::failed_constraint(
196                        ty.clone(),
197                        self.inner.constraint.clone(),
198                    ));
199                }
200            }
201            _ => visit::visit_type(self, ty),
202        }
203    }
204
205    fn visit_var(&mut self, var: TypeVar) {
206        debug_assert!(var.is_free());
207        self.substitutions.insert_constraint(
208            var.index(),
209            &self.inner.constraint,
210            self.errors.by_ref(),
211        );
212
213        let resolved = self.substitutions.fast_resolve(&Type::Var(var)).clone();
214        if let Type::Var(_) = resolved {
215            // Avoid infinite recursion.
216        } else {
217            visit::visit_type(self, &resolved);
218        }
219    }
220
221    fn visit_primitive(&mut self, primitive: &Prim) {
222        if !(self.inner.predicate)(primitive) {
223            self.errors.push(ErrorKind::failed_constraint(
224                Type::Prim(primitive.clone()),
225                self.inner.constraint.clone(),
226            ));
227        }
228    }
229
230    fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
231        if self.inner.deny_dyn_slices {
232            let middle_len = tuple.parts().1.map(Slice::len);
233            if let Some(middle_len) = middle_len {
234                if let Err(err) = self.substitutions.apply_static_len(middle_len) {
235                    self.errors.push(err);
236                }
237            }
238        }
239
240        for (i, element) in tuple.element_types() {
241            self.errors.push_location(i);
242            self.visit_type(element);
243            self.errors.pop_location();
244        }
245    }
246
247    fn visit_object(&mut self, obj: &Object<Prim>) {
248        for (name, element) in obj.iter() {
249            self.errors.push_location(name);
250            self.visit_type(element);
251            self.errors.pop_location();
252        }
253    }
254
255    fn visit_function(&mut self, function: &Function<Prim>) {
256        self.errors.push(ErrorKind::failed_constraint(
257            function.clone().into(),
258            self.inner.constraint.clone(),
259        ));
260    }
261}
262
/// [`Constraint`] for numeric types that can be subject to unary `-` and can participate
/// in `T op Num` / `Num op T` operations.
///
/// Defined recursively as [linear](LinearType) primitive types and tuples / objects consisting
/// of linear types.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Linearity;

/// The constraint is displayed as the `Lin` identifier.
impl fmt::Display for Linearity {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Lin")
    }
}
276
277impl<Prim: LinearType> Constraint<Prim> for Linearity {
278    fn visitor<'r>(
279        &self,
280        substitutions: &'r mut Substitutions<Prim>,
281        errors: OpErrors<'r, Prim>,
282    ) -> Box<dyn Visit<Prim> + 'r> {
283        StructConstraint::new(*self, LinearType::is_linear).visitor(substitutions, errors)
284    }
285
286    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
287        Box::new(*self)
288    }
289}
290
291impl<Prim: LinearType> ObjectSafeConstraint<Prim> for Linearity {}
292
293/// Primitive type which supports a notion of *linearity*. Linear types are types that
294/// can be used in arithmetic ops.
295pub trait LinearType: PrimitiveType {
296    /// Returns `true` iff this type is linear.
297    fn is_linear(&self) -> bool;
298}
299
/// [`Constraint`] for numeric types that can participate in binary arithmetic ops (`T op T`).
///
/// Defined as a subset of `Lin` types without dynamically sized slices and
/// any types containing dynamically sized slices.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Ops;

/// The constraint is displayed as the `Ops` identifier.
impl fmt::Display for Ops {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Ops")
    }
}
312
313impl<Prim: LinearType> Constraint<Prim> for Ops {
314    fn visitor<'r>(
315        &self,
316        substitutions: &'r mut Substitutions<Prim>,
317        errors: OpErrors<'r, Prim>,
318    ) -> Box<dyn Visit<Prim> + 'r> {
319        StructConstraint::new(*self, LinearType::is_linear)
320            .deny_dyn_slices()
321            .visitor(substitutions, errors)
322    }
323
324    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
325        Box::new(*self)
326    }
327}
328
329/// Set of [`Constraint`]s.
330///
331/// [`Display`](fmt::Display)ed as `Foo + Bar + Quux`, where `Foo`, `Bar` and `Quux` are
332/// constraints in the set.
333#[derive(Debug, Clone)]
334pub struct ConstraintSet<Prim: PrimitiveType> {
335    inner: HashMap<String, (Box<dyn Constraint<Prim>>, bool)>,
336}
337
338impl<Prim: PrimitiveType> Default for ConstraintSet<Prim> {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
344impl<Prim: PrimitiveType> PartialEq for ConstraintSet<Prim> {
345    fn eq(&self, other: &Self) -> bool {
346        if self.inner.len() == other.inner.len() {
347            self.inner.keys().all(|key| other.inner.contains_key(key))
348        } else {
349            false
350        }
351    }
352}
353
354impl<Prim: PrimitiveType> fmt::Display for ConstraintSet<Prim> {
355    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
356        let len = self.inner.len();
357        for (i, (constraint, _)) in self.inner.values().enumerate() {
358            fmt::Display::fmt(constraint, formatter)?;
359            if i + 1 < len {
360                formatter.write_str(" + ")?;
361            }
362        }
363        Ok(())
364    }
365}
366
367impl<Prim: PrimitiveType> ConstraintSet<Prim> {
368    /// Creates an empty set.
369    pub fn new() -> Self {
370        Self {
371            inner: HashMap::new(),
372        }
373    }
374
375    /// Creates a set with one constraint.
376    pub fn just(constraint: impl Constraint<Prim>) -> Self {
377        let mut this = Self::new();
378        this.insert(constraint);
379        this
380    }
381
382    /// Checks if this constraint set is empty.
383    pub fn is_empty(&self) -> bool {
384        self.inner.is_empty()
385    }
386
387    fn contains(&self, constraint: &impl Constraint<Prim>) -> bool {
388        self.inner.contains_key(&constraint.to_string())
389    }
390
391    /// Inserts a constraint into this set.
392    pub fn insert(&mut self, constraint: impl Constraint<Prim>) {
393        self.inner
394            .insert(constraint.to_string(), (Box::new(constraint), false));
395    }
396
397    /// Inserts an object-safe constraint into this set.
398    pub fn insert_object_safe(&mut self, constraint: impl ObjectSafeConstraint<Prim>) {
399        self.inner
400            .insert(constraint.to_string(), (Box::new(constraint), true));
401    }
402
403    /// Inserts a boxed constraint into this set.
404    pub(crate) fn insert_boxed(&mut self, constraint: Box<dyn Constraint<Prim>>) {
405        self.inner
406            .insert(constraint.to_string(), (constraint, false));
407    }
408
409    /// Returns the link to constraint and an indicator whether it is object-safe.
410    pub(crate) fn get_by_name(&self, name: &str) -> Option<(&dyn Constraint<Prim>, bool)> {
411        self.inner
412            .get(name)
413            .map(|(constraint, is_object_safe)| (constraint.as_ref(), *is_object_safe))
414    }
415
416    /// Applies all constraints from this set.
417    pub(crate) fn apply_all(
418        &self,
419        ty: &Type<Prim>,
420        substitutions: &mut Substitutions<Prim>,
421        mut errors: OpErrors<'_, Prim>,
422    ) {
423        for (constraint, _) in self.inner.values() {
424            constraint
425                .visitor(substitutions, errors.by_ref())
426                .visit_type(ty);
427        }
428    }
429
430    /// Applies all constraints from this set to an object.
431    pub(crate) fn apply_all_to_object(
432        &self,
433        object: &Object<Prim>,
434        substitutions: &mut Substitutions<Prim>,
435        mut errors: OpErrors<'_, Prim>,
436    ) {
437        for (constraint, _) in self.inner.values() {
438            constraint
439                .visitor(substitutions, errors.by_ref())
440                .visit_object(object);
441        }
442    }
443}
444
445/// Extended [`ConstraintSet`] that additionally supports object constraints.
446#[derive(Debug, Clone, PartialEq)]
447pub(crate) struct CompleteConstraints<Prim: PrimitiveType> {
448    pub simple: ConstraintSet<Prim>,
449    /// Object constraint. Stored as `Type` for convenience.
450    pub object: Option<Object<Prim>>,
451}
452
453impl<Prim: PrimitiveType> Default for CompleteConstraints<Prim> {
454    fn default() -> Self {
455        Self {
456            simple: ConstraintSet::new(),
457            object: None,
458        }
459    }
460}
461
462impl<Prim: PrimitiveType> From<ConstraintSet<Prim>> for CompleteConstraints<Prim> {
463    fn from(constraints: ConstraintSet<Prim>) -> Self {
464        Self {
465            simple: constraints,
466            object: None,
467        }
468    }
469}
470
471impl<Prim: PrimitiveType> From<Object<Prim>> for CompleteConstraints<Prim> {
472    fn from(object: Object<Prim>) -> Self {
473        Self {
474            simple: ConstraintSet::default(),
475            object: Some(object),
476        }
477    }
478}
479
480impl<Prim: PrimitiveType> fmt::Display for CompleteConstraints<Prim> {
481    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
482        match (&self.object, self.simple.is_empty()) {
483            (Some(object), false) => write!(formatter, "{} + {}", object, self.simple),
484            (Some(object), true) => fmt::Display::fmt(object, formatter),
485            (None, _) => fmt::Display::fmt(&self.simple, formatter),
486        }
487    }
488}
489
490impl<Prim: PrimitiveType> CompleteConstraints<Prim> {
491    /// Checks if this constraint set is empty.
492    pub fn is_empty(&self) -> bool {
493        self.object.is_none() && self.simple.is_empty()
494    }
495
496    /// Inserts a constraint into this set.
497    pub fn insert(
498        &mut self,
499        constraint: impl Constraint<Prim>,
500        substitutions: &mut Substitutions<Prim>,
501        errors: OpErrors<'_, Prim>,
502    ) {
503        self.simple.insert(constraint);
504        self.check_object_consistency(substitutions, errors);
505    }
506
507    /// Applies all constraints from this set.
508    pub fn apply_all(
509        &self,
510        ty: &Type<Prim>,
511        substitutions: &mut Substitutions<Prim>,
512        mut errors: OpErrors<'_, Prim>,
513    ) {
514        self.simple.apply_all(ty, substitutions, errors.by_ref());
515        if let Some(lhs) = &self.object {
516            lhs.apply_as_constraint(ty, substitutions, errors);
517        }
518    }
519
520    /// Maps the object constraint if present.
521    pub fn map_object(self, map: impl FnOnce(&mut Object<Prim>)) -> Self {
522        Self {
523            simple: self.simple,
524            object: self.object.map(|mut object| {
525                map(&mut object);
526                object
527            }),
528        }
529    }
530
531    /// Inserts an object constraint into this set.
532    pub fn insert_obj_constraint(
533        &mut self,
534        object: Object<Prim>,
535        substitutions: &mut Substitutions<Prim>,
536        mut errors: OpErrors<'_, Prim>,
537    ) {
538        if let Some(existing_object) = &mut self.object {
539            existing_object.extend_from(object, substitutions, errors.by_ref());
540        } else {
541            self.object = Some(object);
542        }
543        self.check_object_consistency(substitutions, errors);
544    }
545
546    fn check_object_consistency(
547        &self,
548        substitutions: &mut Substitutions<Prim>,
549        errors: OpErrors<'_, Prim>,
550    ) {
551        if let Some(object) = &self.object {
552            self.simple
553                .apply_all_to_object(&object, substitutions, errors);
554        }
555    }
556}