arithmetic_typing/ast/
conversion.rs

1//! Logic for converting `*Ast` types into their "main" counterparts.
2
3use nom::Err as NomErr;
4
5use std::{
6    collections::{HashMap, HashSet},
7    convert::TryFrom,
8    fmt,
9};
10
11use crate::{
12    arith::{CompleteConstraints, Constraint, ConstraintSet},
13    ast::{
14        ConstraintsAst, FunctionAst, ObjectAst, SliceAst, SpannedTypeAst, TupleAst, TupleLenAst,
15        TypeAst, TypeConstraintsAst,
16    },
17    error::{Error, Errors},
18    types::{ParamConstraints, ParamQuantifier},
19    DynConstraints, Function, Object, PrimitiveType, Slice, Tuple, Type, TypeEnvironment,
20    UnknownLen,
21};
22use arithmetic_parser::{ErrorKind as ParseErrorKind, InputSpan, NomResult, Spanned, SpannedError};
23
/// Kinds of errors that can occur when converting `*Ast` types into their "main" counterparts.
///
/// During type inference, errors of this type are wrapped into the [`AstConversion`]
/// variant of typing errors.
///
/// [`AstConversion`]: crate::error::ErrorKind::AstConversion
///
/// # Examples
///
/// ```
/// use arithmetic_parser::grammars::{Parse, F32Grammar};
/// use arithmetic_typing::{
///     ast::AstConversionError, error::ErrorKind, Annotated, TypeEnvironment,
/// };
/// # use assert_matches::assert_matches;
///
/// # fn main() -> anyhow::Result<()> {
/// let code = "bogus_slice: ['T; _] = (1, 2, 3);";
/// let code = Annotated::<F32Grammar>::parse_statements(code)?;
///
/// let errors = TypeEnvironment::new().process_statements(&code).unwrap_err();
/// let err = errors.into_iter().next().unwrap();
/// assert_eq!(*err.main_span().fragment(), "'T");
/// assert_matches!(
///     err.kind(),
///     ErrorKind::AstConversion(AstConversionError::FreeTypeVar(id))
///         if id == "T"
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AstConversionError {
    /// Embedded param quantifier: a `for<...>` quantifier attached to a function type
    /// that is nested inside another function type.
    EmbeddedQuantifier,
    /// Length param not scoped by a function.
    FreeLengthVar(String),
    /// Type param not scoped by a function. The payload is the param name without
    /// the leading `'`.
    FreeTypeVar(String),
    /// Length param mentioned in a `for<...>` quantifier, but not used in the function
    /// signature.
    UnusedLength(String),
    /// Type param mentioned in a `for<...>` quantifier, but not used in the function
    /// signature.
    UnusedTypeParam(String),
    /// Unknown type name (i.e., a name that is not a primitive type).
    UnknownType(String),
    /// Unknown constraint name.
    UnknownConstraint(String),
    /// Some type (`_`) encountered when parsing a standalone type.
    ///
    /// `_` types are only allowed in the context of a [`TypeEnvironment`]. It is a logical
    /// error to use them when parsing standalone types.
    InvalidSomeType,
    /// Some length (`_`) encountered when parsing a standalone type.
    ///
    /// `_` lengths are only allowed in the context of a [`TypeEnvironment`]. It is a logical
    /// error to use them when parsing standalone types.
    InvalidSomeLength,
    /// Field with the same name is defined multiple times in an object type.
    DuplicateField(String),
    /// Constraint is not object-safe, and thus cannot be used in a `dyn` type.
    NotObjectSafe(String),
}

impl fmt::Display for AstConversionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmbeddedQuantifier => {
                formatter.write_str("`for` quantifier for a function that is not top-level")
            }

            Self::FreeLengthVar(name) => {
                write!(
                    formatter,
                    "Length param `{}` is not scoped by function definition",
                    name
                )
            }
            Self::FreeTypeVar(name) => {
                write!(
                    formatter,
                    "Type param `{}` is not scoped by function definition",
                    name
                )
            }

            Self::UnusedLength(name) => {
                write!(formatter, "Unused length param `{}`", name)
            }
            Self::UnusedTypeParam(name) => {
                write!(formatter, "Unused type param `{}`", name)
            }
            Self::UnknownType(name) => {
                write!(formatter, "Unknown type `{}`", name)
            }
            Self::UnknownConstraint(name) => {
                write!(formatter, "Unknown constraint `{}`", name)
            }

            Self::InvalidSomeType => {
                formatter.write_str("`_` type is disallowed when parsing standalone type")
            }
            Self::InvalidSomeLength => {
                formatter.write_str("`_` length is disallowed when parsing standalone type")
            }

            Self::DuplicateField(name) => {
                write!(formatter, "Duplicate field `{}` in object type", name)
            }

            Self::NotObjectSafe(name) => {
                write!(formatter, "Constraint `{}` is not object-safe", name)
            }
        }
    }
}

impl std::error::Error for AstConversionError {}
142
143/// Intermediate conversion state.
144#[derive(Debug)]
145pub(crate) struct AstConversionState<'r, 'a, Prim: PrimitiveType> {
146    env: Option<&'r mut TypeEnvironment<Prim>>,
147    known_constraints: ConstraintSet<Prim>,
148    errors: &'r mut Errors<'a, Prim>,
149    len_params: HashMap<&'a str, usize>,
150    type_params: HashMap<&'a str, usize>,
151    is_in_function: bool,
152}
153
154impl<'r, 'a, Prim: PrimitiveType> AstConversionState<'r, 'a, Prim> {
155    pub fn new(env: &'r mut TypeEnvironment<Prim>, errors: &'r mut Errors<'a, Prim>) -> Self {
156        let known_constraints = env.known_constraints.clone();
157        Self {
158            env: Some(env),
159            known_constraints,
160            errors,
161            len_params: HashMap::new(),
162            type_params: HashMap::new(),
163            is_in_function: false,
164        }
165    }
166
167    fn without_env(errors: &'r mut Errors<'a, Prim>) -> Self {
168        Self {
169            env: None,
170            known_constraints: Prim::well_known_constraints(),
171            errors,
172            len_params: HashMap::new(),
173            type_params: HashMap::new(),
174            is_in_function: false,
175        }
176    }
177
178    fn type_param_idx(&mut self, param_name: &'a str) -> usize {
179        let type_param_count = self.type_params.len();
180        *self
181            .type_params
182            .entry(param_name)
183            .or_insert(type_param_count)
184    }
185
186    fn len_param_idx(&mut self, param_name: &'a str) -> usize {
187        let len_param_count = self.len_params.len();
188        *self.len_params.entry(param_name).or_insert(len_param_count)
189    }
190
191    fn new_type(&mut self, span: Option<&SpannedTypeAst<'a>>) -> Type<Prim> {
192        let errors = &mut *self.errors;
193        self.env.as_deref_mut().map_or_else(
194            || {
195                if let Some(span) = span {
196                    let err = AstConversionError::InvalidSomeType;
197                    errors.push(Error::conversion(err, span));
198                }
199                // We don't particularly care about the returned value; the enclosing type
200                // will be discarded anyway.
201                Type::free_var(0)
202            },
203            |env| env.substitutions.new_type_var(),
204        )
205    }
206
207    fn new_len(&mut self, span: Option<&Spanned<'a, TupleLenAst>>) -> UnknownLen {
208        let errors = &mut *self.errors;
209        self.env.as_deref_mut().map_or_else(
210            || {
211                if let Some(span) = span {
212                    let err = AstConversionError::InvalidSomeLength;
213                    errors.push(Error::conversion(err, span));
214                }
215                // We don't particularly care about the returned value; the enclosing type
216                // will be discarded anyway.
217                UnknownLen::free_var(0)
218            },
219            |env| env.substitutions.new_len_var(),
220        )
221    }
222
223    fn resolve_constraint(&self, name: &str) -> Option<(Box<dyn Constraint<Prim>>, bool)> {
224        self.known_constraints
225            .get_by_name(name)
226            .map(|(constraint, is_object_safe)| (constraint.clone_boxed(), is_object_safe))
227    }
228
229    pub(crate) fn convert_type(&mut self, ty: &SpannedTypeAst<'a>) -> Type<Prim> {
230        match &ty.extra {
231            TypeAst::Some => self.new_type(Some(ty)),
232            TypeAst::Any => Type::Any,
233            TypeAst::Dyn(constraints) => Type::Dyn(constraints.convert_dyn(self)),
234            TypeAst::Ident => {
235                let ident = *ty.fragment();
236                if let Ok(prim_type) = Prim::from_str(ident) {
237                    Type::Prim(prim_type)
238                } else {
239                    let err = AstConversionError::UnknownType(ident.to_owned());
240                    self.errors.push(Error::conversion(err, ty));
241                    self.new_type(None)
242                }
243            }
244
245            TypeAst::Param => {
246                let name = &ty.fragment()[1..];
247                if self.is_in_function {
248                    let idx = self.type_param_idx(name);
249                    Type::param(idx)
250                } else {
251                    let err = AstConversionError::FreeTypeVar(name.to_owned());
252                    self.errors.push(Error::conversion(err, ty));
253                    self.new_type(None)
254                }
255            }
256
257            TypeAst::Function(function) => self.convert_fn(function, None),
258            TypeAst::FunctionWithConstraints {
259                function,
260                constraints,
261            } => self.convert_fn(&function.extra, Some(constraints)),
262
263            TypeAst::Tuple(tuple) => tuple.convert(self).into(),
264            TypeAst::Slice(slice) => slice.convert(self).into(),
265            TypeAst::Object(object) => object.convert(self).into(),
266        }
267    }
268
269    fn convert_fn(
270        &mut self,
271        function: &FunctionAst<'a>,
272        constraints: Option<&Spanned<'a, ConstraintsAst<'a>>>,
273    ) -> Type<Prim> {
274        if self.is_in_function {
275            if let Some(constraints) = constraints {
276                let err = AstConversionError::EmbeddedQuantifier;
277                self.errors.push(Error::conversion(err, constraints));
278            }
279            function.convert(self).into()
280        } else {
281            self.is_in_function = true;
282            let mut converted_fn = function.convert(self);
283            let constraints =
284                constraints.map_or_else(ParamConstraints::default, |c| c.extra.convert(self));
285            ParamQuantifier::set_params(&mut converted_fn, constraints);
286
287            self.is_in_function = false;
288            self.type_params.clear();
289            self.len_params.clear();
290            converted_fn.into()
291        }
292    }
293}
294
295impl<'a> TypeConstraintsAst<'a> {
296    fn convert<Prim: PrimitiveType>(
297        &self,
298        state: &mut AstConversionState<'_, 'a, Prim>,
299    ) -> CompleteConstraints<Prim> {
300        self.do_convert(state, false)
301    }
302
303    fn convert_dyn<Prim: PrimitiveType>(
304        &self,
305        state: &mut AstConversionState<'_, 'a, Prim>,
306    ) -> DynConstraints<Prim> {
307        DynConstraints {
308            inner: self.do_convert(state, true),
309        }
310    }
311
312    fn do_convert<Prim: PrimitiveType>(
313        &self,
314        state: &mut AstConversionState<'_, 'a, Prim>,
315        require_object_safety: bool,
316    ) -> CompleteConstraints<Prim> {
317        let mut constraints = CompleteConstraints::default();
318        if let Some(object) = &self.object {
319            constraints.object = Some(object.convert(state));
320        }
321
322        self.terms.iter().fold(constraints, |mut acc, input| {
323            let input_str = *input.fragment();
324            if let Some((constraint, is_object_safe)) = state.resolve_constraint(input_str) {
325                if require_object_safety && !is_object_safe {
326                    let err = AstConversionError::NotObjectSafe(input_str.to_owned());
327                    state.errors.push(Error::conversion(err, input));
328                } else {
329                    acc.simple.insert_boxed(constraint);
330                }
331            } else {
332                let err = AstConversionError::UnknownConstraint(input_str.to_owned());
333                state.errors.push(Error::conversion(err, input));
334            }
335            acc
336        })
337    }
338}
339
340impl<'a> ConstraintsAst<'a> {
341    fn convert<Prim: PrimitiveType>(
342        &self,
343        state: &mut AstConversionState<'_, 'a, Prim>,
344    ) -> ParamConstraints<Prim> {
345        let mut static_lengths = HashSet::with_capacity(self.static_lengths.len());
346        for dyn_length in &self.static_lengths {
347            let name = *dyn_length.fragment();
348            if let Some(index) = state.len_params.get(name) {
349                static_lengths.insert(*index);
350            } else {
351                let err = AstConversionError::UnusedLength(name.to_owned());
352                state.errors.push(Error::conversion(err, dyn_length));
353            }
354        }
355
356        let mut type_params = HashMap::with_capacity(self.type_params.len());
357        for (param, constraints) in &self.type_params {
358            let name = *param.fragment();
359            if let Some(index) = state.type_params.get(name) {
360                type_params.insert(*index, constraints.convert(state));
361            } else {
362                let err = AstConversionError::UnusedTypeParam(name.to_owned());
363                state.errors.push(Error::conversion(err, param));
364            }
365        }
366
367        ParamConstraints {
368            type_params,
369            static_lengths,
370        }
371    }
372}
373
374impl<'a> TupleAst<'a> {
375    fn convert<Prim: PrimitiveType>(
376        &self,
377        state: &mut AstConversionState<'_, 'a, Prim>,
378    ) -> Tuple<Prim> {
379        let start = self
380            .start
381            .iter()
382            .map(|element| state.convert_type(element))
383            .collect();
384        let middle = self
385            .middle
386            .as_ref()
387            .map(|middle| middle.extra.convert(state));
388        let end = self
389            .end
390            .iter()
391            .map(|element| state.convert_type(element))
392            .collect();
393        Tuple::from_parts(start, middle, end)
394    }
395}
396
397impl<'a> SliceAst<'a> {
398    fn convert<Prim: PrimitiveType>(
399        &self,
400        state: &mut AstConversionState<'_, 'a, Prim>,
401    ) -> Slice<Prim> {
402        let element = state.convert_type(&self.element);
403
404        let converted_length = match &self.length.extra {
405            TupleLenAst::Ident => {
406                let name = *self.length.fragment();
407                if state.is_in_function {
408                    let const_param = state.len_param_idx(name);
409                    UnknownLen::param(const_param)
410                } else {
411                    let err = AstConversionError::FreeLengthVar(name.to_owned());
412                    state.errors.push(Error::conversion(err, &self.length));
413                    state.new_len(None)
414                }
415            }
416            TupleLenAst::Some => state.new_len(Some(&self.length)),
417            TupleLenAst::Dynamic => UnknownLen::Dynamic,
418        };
419
420        Slice::new(element, converted_length)
421    }
422}
423
424impl<'a> ObjectAst<'a> {
425    fn convert<Prim: PrimitiveType>(
426        &self,
427        state: &mut AstConversionState<'_, 'a, Prim>,
428    ) -> Object<Prim> {
429        let mut fields = HashMap::new();
430        for (field_name, ty) in &self.fields {
431            let field_name_str = *field_name.fragment();
432            if fields.contains_key(field_name_str) {
433                let err = AstConversionError::DuplicateField(field_name_str.to_owned());
434                state.errors.push(Error::conversion(err, field_name));
435            } else {
436                fields.insert(field_name_str.to_owned(), state.convert_type(ty));
437            }
438        }
439        Object::from_map(fields)
440    }
441}
442
443impl<'a> FunctionAst<'a> {
444    fn convert<Prim: PrimitiveType>(
445        &self,
446        state: &mut AstConversionState<'_, 'a, Prim>,
447    ) -> Function<Prim> {
448        let args = self.args.extra.convert(state);
449        let return_type = state.convert_type(&self.return_type);
450        Function::new(args, return_type)
451    }
452
453    /// Tries to convert this type into a [`Function`].
454    pub fn try_convert<Prim>(&self) -> Result<Function<Prim>, Errors<'a, Prim>>
455    where
456        Prim: PrimitiveType,
457    {
458        let mut errors = Errors::new();
459        let mut state = AstConversionState::without_env(&mut errors);
460        state.is_in_function = true;
461
462        let output = self.convert(&mut state);
463        if errors.is_empty() {
464            Ok(output)
465        } else {
466            Err(errors)
467        }
468    }
469}
470
471/// Shared parsing code for `TypeAst` and `FunctionAst`.
472fn parse_inner<'a, Ast>(
473    parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
474    input: InputSpan<'a>,
475) -> NomResult<'a, Ast> {
476    let (rest, ast) = parser(input)?;
477    if !rest.fragment().is_empty() {
478        let err = ParseErrorKind::Leftovers.with_span(&rest.into());
479        return Err(NomErr::Failure(err));
480    }
481    Ok((rest, ast))
482}
483
484/// Shared `TryFrom<&str>` logic for `TypeAst` and `FunctionAst`.
485fn from_str<'a, Ast>(
486    parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
487    def: &'a str,
488) -> Result<Ast, SpannedError<&'a str>> {
489    let input = InputSpan::new(def);
490    let (_, ast) = parse_inner(parser, input).map_err(|err| match err {
491        NomErr::Incomplete(_) => ParseErrorKind::Incomplete.with_span(&input.into()),
492        NomErr::Error(e) | NomErr::Failure(e) => e,
493    })?;
494    Ok(ast)
495}
496
497impl<'a> TypeAst<'a> {
498    /// Parses type AST from a string.
499    pub fn try_from(def: &'a str) -> Result<SpannedTypeAst<'a>, SpannedError<&'a str>> {
500        from_str(TypeAst::parse, def)
501    }
502}
503
504impl<'a, Prim: PrimitiveType> TryFrom<&SpannedTypeAst<'a>> for Type<Prim> {
505    type Error = Errors<'a, Prim>;
506
507    fn try_from(ast: &SpannedTypeAst<'a>) -> Result<Self, Self::Error> {
508        let mut errors = Errors::new();
509        let mut state = AstConversionState::without_env(&mut errors);
510
511        let output = state.convert_type(ast);
512        if errors.is_empty() {
513            Ok(output)
514        } else {
515            Err(errors)
516        }
517    }
518}
519
520impl<'a> TryFrom<&'a str> for FunctionAst<'a> {
521    type Error = SpannedError<&'a str>;
522
523    fn try_from(def: &'a str) -> Result<Self, Self::Error> {
524        from_str(FunctionAst::parse, def)
525    }
526}
527
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::arith::Num;

    #[test]
    fn converting_raw_fn_type() {
        let input = InputSpan::new("(['T; N], ('T) -> Bool) -> Bool");
        let (_, fn_type) = FunctionAst::parse(input).unwrap();
        let fn_type = fn_type.try_convert::<Num>().unwrap();

        assert_eq!(fn_type.to_string(), *input.fragment());
    }

    #[test]
    fn converting_fn_type_with_constraint() {
        let input = InputSpan::new("for<'T: Lin> (['T; N], ('T) -> Bool) -> Bool");
        let (_, ast) = TypeAst::parse(input).unwrap();
        let fn_type = <Type>::try_from(&ast).unwrap();

        assert_eq!(fn_type.to_string(), *input.fragment());
    }

    #[test]
    fn parsing_basic_types() -> anyhow::Result<()> {
        let num_type = <Type>::try_from(&TypeAst::try_from("Num")?)?;
        assert_eq!(num_type, Type::NUM);

        let bool_type = <Type>::try_from(&TypeAst::try_from("Bool")?)?;
        assert_eq!(bool_type, Type::BOOL);

        let tuple_type = <Type>::try_from(&TypeAst::try_from("(Num, (Bool, Bool))")?)?;
        assert_eq!(
            tuple_type,
            Type::from((Type::NUM, Type::Tuple(vec![Type::BOOL; 2].into()),))
        );

        let slice_type = <Type>::try_from(&TypeAst::try_from("[(Num, Bool)]")?)?;
        let slice_type = match &slice_type {
            Type::Tuple(tuple) => tuple.as_slice().unwrap(),
            _ => panic!("Unexpected type: {:?}", slice_type),
        };

        assert_eq!(*slice_type.element(), Type::from((Type::NUM, Type::BOOL)));
        assert_matches!(
            slice_type.len().components(),
            (Some(UnknownLen::Dynamic), 0)
        );
        Ok(())
    }

    #[test]
    fn parsing_functional_type() -> anyhow::Result<()> {
        let ty = <Type>::try_from(&TypeAst::try_from("(['T; N], ('T) -> 'U) -> 'U")?)?;
        let ty = match ty {
            Type::Function(fn_type) => *fn_type,
            _ => panic!("Unexpected type: {:?}", ty),
        };

        assert_eq!(ty.params.as_ref().unwrap().len_params.len(), 1);
        assert_eq!(ty.params.as_ref().unwrap().type_params.len(), 2);
        assert_eq!(ty.return_type, Type::param(1));
        Ok(())
    }

    #[test]
    fn parsing_functional_type_with_varargs() -> anyhow::Result<()> {
        let ty = <Type>::try_from(&TypeAst::try_from("(...[Num; N]) -> Num")?)?;
        let ty = match ty {
            Type::Function(fn_type) => *fn_type,
            _ => panic!("Unexpected type: {:?}", ty),
        };

        assert_eq!(ty.params.as_ref().unwrap().len_params.len(), 1);
        assert!(ty.params.as_ref().unwrap().type_params.is_empty());
        let args_slice = ty.args.as_slice().unwrap();
        assert_eq!(*args_slice.element(), Type::NUM);
        assert_eq!(args_slice.len(), UnknownLen::param(0).into());
        Ok(())
    }

    #[test]
    fn parsing_incomplete_type() {
        // The inputs below are truncated prefixes of this type. Check that the complete type
        // parses, so that errors for the prefixes stem from truncation rather than from
        // unsupported syntax.
        const COMPLETE_TYPE: &str = "(['T; N], ('T) -> Bool) -> Bool";
        TypeAst::try_from(COMPLETE_TYPE).unwrap();

        const INCOMPLETE_TYPES: &[&str] = &[
            "(",
            "(['T; ",
            "(['T; N], (",
            "(['T; N], ('T)",
            "(['T; N], ('T) -",
            "(['T; N], ('T) ->",
            "(['T; N], ('T) -> Bool) ->",
        ];

        for &input in INCOMPLETE_TYPES {
            assert!(COMPLETE_TYPE.starts_with(input), "{}", input);
            // TODO: some of reported errors are difficult to interpret; should clarify.
            TypeAst::try_from(input).unwrap_err();
        }
    }

    #[test]
    fn parsing_type_with_object_constraint() -> anyhow::Result<()> {
        let type_def = "for<'T: { x: Num } + Lin> ('T) -> Bool";
        let ty = TypeAst::try_from(type_def)?;
        let ty = <Type>::try_from(&ty)?;
        let ty = match ty {
            Type::Function(fn_type) => *fn_type,
            _ => panic!("Unexpected type: {:?}", ty),
        };

        let type_params = &ty.params.as_ref().unwrap().type_params;
        assert_eq!(type_params.len(), 1);
        let (_, type_params) = &type_params[0];
        assert!(type_params.object.is_some());
        assert!(type_params.simple.get_by_name("Lin").is_some());

        assert_eq!(ty.to_string(), type_def);
        Ok(())
    }
}