arithmetic_typing/arith/
mod.rs

//! Types that allow customizing various aspects of the type system, such as type constraints
//! and the behavior of unary / binary ops.
3
4use num_traits::NumOps;
5
6use std::{fmt, str::FromStr};
7
8use crate::{
9    error::{ErrorKind, ErrorLocation, OpErrors},
10    PrimitiveType, Type,
11};
12use arithmetic_parser::{BinaryOp, UnaryOp};
13
14mod constraints;
15mod substitutions;
16
17pub(crate) use self::constraints::CompleteConstraints;
18pub use self::constraints::{
19    Constraint, ConstraintSet, LinearType, Linearity, ObjectSafeConstraint, Ops, StructConstraint,
20};
21pub use self::substitutions::Substitutions;
22
23/// Maps a literal value from a certain [`Grammar`] to its type. This assumes that all literals
24/// are primitive.
25///
26/// [`Grammar`]: arithmetic_parser::grammars::Grammar
27pub trait MapPrimitiveType<Val> {
28    /// Types of literals output by this mapper.
29    type Prim: PrimitiveType;
30
31    /// Gets the type of the provided literal value.
32    fn type_of_literal(&self, lit: &Val) -> Self::Prim;
33}
34
35/// Arithmetic allowing to customize primitive types and how unary and binary operations are handled
36/// during type inference.
37///
38/// # Examples
39///
40/// See crate examples for examples how define custom arithmetics.
41pub trait TypeArithmetic<Prim: PrimitiveType> {
42    /// Handles a unary operation.
43    fn process_unary_op(
44        &self,
45        substitutions: &mut Substitutions<Prim>,
46        context: &UnaryOpContext<Prim>,
47        errors: OpErrors<'_, Prim>,
48    ) -> Type<Prim>;
49
50    /// Handles a binary operation.
51    fn process_binary_op(
52        &self,
53        substitutions: &mut Substitutions<Prim>,
54        context: &BinaryOpContext<Prim>,
55        errors: OpErrors<'_, Prim>,
56    ) -> Type<Prim>;
57}
58
59/// Code spans related to a unary operation.
60///
61/// Used in [`TypeArithmetic::process_unary_op()`].
62#[derive(Debug, Clone)]
63pub struct UnaryOpContext<Prim: PrimitiveType> {
64    /// Unary operation.
65    pub op: UnaryOp,
66    /// Operation argument.
67    pub arg: Type<Prim>,
68}
69
70/// Code spans related to a binary operation.
71///
72/// Used in [`TypeArithmetic::process_binary_op()`].
73#[derive(Debug, Clone)]
74pub struct BinaryOpContext<Prim: PrimitiveType> {
75    /// Binary operation.
76    pub op: BinaryOp,
77    /// Spanned left-hand side.
78    pub lhs: Type<Prim>,
79    /// Spanned right-hand side.
80    pub rhs: Type<Prim>,
81}
82
83/// [`PrimitiveType`] that has Boolean type as one of its variants.
84pub trait WithBoolean: PrimitiveType {
85    /// Boolean type.
86    const BOOL: Self;
87}
88
/// Minimal [`TypeArithmetic`] that supports unary / binary ops on the Boolean type only.
///
/// Intended as a building block for richer arithmetics.
#[derive(Debug, Default, Clone, Copy)]
pub struct BoolArithmetic;
93
94impl<Prim: WithBoolean> TypeArithmetic<Prim> for BoolArithmetic {
95    /// Processes a unary operation.
96    ///
97    /// - `!` requires a Boolean input and outputs a Boolean.
98    /// - Other operations fail with [`ErrorKind::UnsupportedFeature`].
99    fn process_unary_op<'a>(
100        &self,
101        substitutions: &mut Substitutions<Prim>,
102        context: &UnaryOpContext<Prim>,
103        mut errors: OpErrors<'_, Prim>,
104    ) -> Type<Prim> {
105        let op = context.op;
106        if op == UnaryOp::Not {
107            substitutions.unify(&Type::BOOL, &context.arg, errors);
108            Type::BOOL
109        } else {
110            let err = ErrorKind::unsupported(op);
111            errors.push(err);
112            substitutions.new_type_var()
113        }
114    }
115
116    /// Processes a binary operation.
117    ///
118    /// - `==` and `!=` require LHS and RHS to have the same type (no matter which one).
119    ///   These ops return `Bool`.
120    /// - `&&` and `||` require LHS and RHS to have `Bool` type. These ops return `Bool`.
121    /// - Other operations fail with [`ErrorKind::UnsupportedFeature`].
122    fn process_binary_op(
123        &self,
124        substitutions: &mut Substitutions<Prim>,
125        context: &BinaryOpContext<Prim>,
126        mut errors: OpErrors<'_, Prim>,
127    ) -> Type<Prim> {
128        match context.op {
129            BinaryOp::Eq | BinaryOp::NotEq => {
130                substitutions.unify(&context.lhs, &context.rhs, errors);
131                Type::BOOL
132            }
133
134            BinaryOp::And | BinaryOp::Or => {
135                substitutions.unify(
136                    &Type::BOOL,
137                    &context.lhs,
138                    errors.with_location(ErrorLocation::Lhs),
139                );
140                substitutions.unify(
141                    &Type::BOOL,
142                    &context.rhs,
143                    errors.with_location(ErrorLocation::Rhs),
144                );
145                Type::BOOL
146            }
147
148            _ => {
149                errors.push(ErrorKind::unsupported(context.op));
150                substitutions.new_type_var()
151            }
152        }
153    }
154}
155
156/// Settings for constraints placed on arguments of binary arithmetic operations.
157#[derive(Debug)]
158pub struct OpConstraintSettings<'a, Prim: PrimitiveType> {
159    /// Constraint applied to the argument of `T op Num` / `Num op T` ops.
160    pub lin: &'a dyn Constraint<Prim>,
161    /// Constraint applied to the arguments of in-kind binary arithmetic ops (`T op T`).
162    pub ops: &'a dyn Constraint<Prim>,
163}
164
165impl<Prim: PrimitiveType> Clone for OpConstraintSettings<'_, Prim> {
166    fn clone(&self) -> Self {
167        Self {
168            lin: self.lin,
169            ops: self.ops,
170        }
171    }
172}
173
174impl<Prim: PrimitiveType> Copy for OpConstraintSettings<'_, Prim> {}
175
/// Primitive types for the numeric arithmetic: `Num`eric type and `Bool`ean.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Num {
    /// Numeric type (e.g., 1).
    Num,
    /// Boolean type (`true` or `false`).
    Bool,
}
184
185impl fmt::Display for Num {
186    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
187        formatter.write_str(match self {
188            Self::Num => "Num",
189            Self::Bool => "Bool",
190        })
191    }
192}
193
194impl FromStr for Num {
195    type Err = anyhow::Error;
196
197    fn from_str(s: &str) -> Result<Self, Self::Err> {
198        match s {
199            "Num" => Ok(Self::Num),
200            "Bool" => Ok(Self::Bool),
201            _ => Err(anyhow::anyhow!("Expected `Num` or `Bool`")),
202        }
203    }
204}
205
206impl PrimitiveType for Num {
207    fn well_known_constraints() -> ConstraintSet<Self> {
208        let mut constraints = ConstraintSet::default();
209        constraints.insert_object_safe(Linearity);
210        constraints.insert(Ops);
211        constraints
212    }
213}
214
215impl WithBoolean for Num {
216    const BOOL: Self = Self::Bool;
217}
218
219/// `Num`bers are linear, `Bool`ean values are not.
220impl LinearType for Num {
221    fn is_linear(&self) -> bool {
222        matches!(self, Self::Num)
223    }
224}
225
/// Arithmetic on [`Num`]bers.
///
/// # Unary ops
///
/// - Unary minus follows the equation `-T == T`, where `T` is any [linear](Linearity) type.
/// - Unary negation (`!`) is only defined for `Bool`s.
///
/// # Binary ops
///
/// Binary ops fall into 3 cases: `Num op T == T`, `T op Num == T`, or `T op T == T`,
/// where `T` is any linear type (that is, `Num` or tuple of linear types).
/// `T op T` is assumed by default. The two other cases apply only if one of the operands
/// is known to be primitive (e.g., `Num`) and the other one is known to be non-primitive.
///
/// # Comparisons
///
/// Order comparisons (`>`, `<`, `>=`, `<=`) can be switched on or off. Use the
/// [`Self::with_comparisons()`] constructor to switch them on. If switched on, both arguments
/// of the order comparison must be numbers.
#[derive(Debug, Clone)]
pub struct NumArithmetic {
    comparisons_enabled: bool,
}
249
250impl NumArithmetic {
251    /// Creates an instance of arithmetic that does not support order comparisons.
252    pub const fn without_comparisons() -> Self {
253        Self {
254            comparisons_enabled: false,
255        }
256    }
257
258    /// Creates an instance of arithmetic that supports order comparisons.
259    pub const fn with_comparisons() -> Self {
260        Self {
261            comparisons_enabled: true,
262        }
263    }
264
265    /// Applies [binary ops](#binary-ops) logic to unify the given LHS and RHS types.
266    /// Returns the result type of the binary operation.
267    ///
268    /// This logic can be reused by other [`TypeArithmetic`] implementations.
269    ///
270    /// # Arguments
271    ///
272    /// - `settings` are applied to arguments of arithmetic ops.
273    pub fn unify_binary_op<Prim: PrimitiveType>(
274        substitutions: &mut Substitutions<Prim>,
275        context: &BinaryOpContext<Prim>,
276        mut errors: OpErrors<'_, Prim>,
277        settings: OpConstraintSettings<'_, Prim>,
278    ) -> Type<Prim> {
279        let lhs_ty = &context.lhs;
280        let rhs_ty = &context.rhs;
281        let resolved_lhs_ty = substitutions.fast_resolve(lhs_ty);
282        let resolved_rhs_ty = substitutions.fast_resolve(rhs_ty);
283
284        match (
285            resolved_lhs_ty.is_primitive(),
286            resolved_rhs_ty.is_primitive(),
287        ) {
288            (Some(true), Some(false)) => {
289                let resolved_rhs_ty = resolved_rhs_ty.clone();
290                settings
291                    .lin
292                    .visitor(substitutions, errors.with_location(ErrorLocation::Lhs))
293                    .visit_type(lhs_ty);
294                settings
295                    .lin
296                    .visitor(substitutions, errors.with_location(ErrorLocation::Rhs))
297                    .visit_type(rhs_ty);
298                resolved_rhs_ty
299            }
300            (Some(false), Some(true)) => {
301                let resolved_lhs_ty = resolved_lhs_ty.clone();
302                settings
303                    .lin
304                    .visitor(substitutions, errors.with_location(ErrorLocation::Lhs))
305                    .visit_type(lhs_ty);
306                settings
307                    .lin
308                    .visitor(substitutions, errors.with_location(ErrorLocation::Rhs))
309                    .visit_type(rhs_ty);
310                resolved_lhs_ty
311            }
312            _ => {
313                let lhs_is_valid = errors.with_location(ErrorLocation::Lhs).check(|errors| {
314                    settings
315                        .ops
316                        .visitor(substitutions, errors)
317                        .visit_type(lhs_ty);
318                });
319                let rhs_is_valid = errors.with_location(ErrorLocation::Rhs).check(|errors| {
320                    settings
321                        .ops
322                        .visitor(substitutions, errors)
323                        .visit_type(rhs_ty);
324                });
325
326                if lhs_is_valid && rhs_is_valid {
327                    substitutions.unify(lhs_ty, rhs_ty, errors);
328                }
329                if lhs_is_valid {
330                    lhs_ty.clone()
331                } else {
332                    rhs_ty.clone()
333                }
334            }
335        }
336    }
337
338    /// Processes a unary operation according to [the numeric arithmetic rules](#unary-ops).
339    /// Returns the result type of the unary operation.
340    ///
341    /// This logic can be reused by other [`TypeArithmetic`] implementations.
342    pub fn process_unary_op<Prim: WithBoolean>(
343        substitutions: &mut Substitutions<Prim>,
344        context: &UnaryOpContext<Prim>,
345        mut errors: OpErrors<'_, Prim>,
346        constraints: &impl Constraint<Prim>,
347    ) -> Type<Prim> {
348        match context.op {
349            UnaryOp::Not => BoolArithmetic.process_unary_op(substitutions, context, errors),
350            UnaryOp::Neg => {
351                constraints
352                    .visitor(substitutions, errors)
353                    .visit_type(&context.arg);
354                context.arg.clone()
355            }
356            _ => {
357                errors.push(ErrorKind::unsupported(context.op));
358                substitutions.new_type_var()
359            }
360        }
361    }
362
363    /// Processes a binary operation according to [the numeric arithmetic rules](#binary-ops).
364    /// Returns the result type of the unary operation.
365    ///
366    /// This logic can be reused by other [`TypeArithmetic`] implementations.
367    ///
368    /// # Arguments
369    ///
370    /// - If `comparable_type` is set to `Some(_)`, it will be used to unify arguments of
371    ///   order comparisons. If `comparable_type` is `None`, order comparisons are not supported.
372    /// - `constraints` are applied to arguments of arithmetic ops.
373    pub fn process_binary_op<Prim: WithBoolean>(
374        substitutions: &mut Substitutions<Prim>,
375        context: &BinaryOpContext<Prim>,
376        mut errors: OpErrors<'_, Prim>,
377        comparable_type: Option<Prim>,
378        settings: OpConstraintSettings<'_, Prim>,
379    ) -> Type<Prim> {
380        match context.op {
381            BinaryOp::And | BinaryOp::Or | BinaryOp::Eq | BinaryOp::NotEq => {
382                BoolArithmetic.process_binary_op(substitutions, context, errors)
383            }
384
385            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
386                Self::unify_binary_op(substitutions, context, errors, settings)
387            }
388
389            BinaryOp::Ge | BinaryOp::Le | BinaryOp::Lt | BinaryOp::Gt => {
390                if let Some(ty) = comparable_type {
391                    let ty = Type::Prim(ty);
392                    substitutions.unify(
393                        &ty,
394                        &context.lhs,
395                        errors.with_location(ErrorLocation::Lhs),
396                    );
397                    substitutions.unify(
398                        &ty,
399                        &context.rhs,
400                        errors.with_location(ErrorLocation::Rhs),
401                    );
402                } else {
403                    let err = ErrorKind::unsupported(context.op);
404                    errors.push(err);
405                }
406                Type::BOOL
407            }
408
409            _ => {
410                errors.push(ErrorKind::unsupported(context.op));
411                substitutions.new_type_var()
412            }
413        }
414    }
415}
416
417impl<Val> MapPrimitiveType<Val> for NumArithmetic
418where
419    Val: Clone + NumOps + PartialEq,
420{
421    type Prim = Num;
422
423    fn type_of_literal(&self, _: &Val) -> Self::Prim {
424        Num::Num
425    }
426}
427
428impl TypeArithmetic<Num> for NumArithmetic {
429    fn process_unary_op<'a>(
430        &self,
431        substitutions: &mut Substitutions<Num>,
432        context: &UnaryOpContext<Num>,
433        errors: OpErrors<'_, Num>,
434    ) -> Type<Num> {
435        Self::process_unary_op(substitutions, context, errors, &Linearity)
436    }
437
438    fn process_binary_op<'a>(
439        &self,
440        substitutions: &mut Substitutions<Num>,
441        context: &BinaryOpContext<Num>,
442        errors: OpErrors<'_, Num>,
443    ) -> Type<Num> {
444        const OP_SETTINGS: OpConstraintSettings<'static, Num> = OpConstraintSettings {
445            lin: &Linearity,
446            ops: &Ops,
447        };
448
449        let comparable_type = if self.comparisons_enabled {
450            Some(Num::Num)
451        } else {
452            None
453        };
454
455        Self::process_binary_op(substitutions, context, errors, comparable_type, OP_SETTINGS)
456    }
457}