1use nom::Err as NomErr;
4
5use std::{
6 collections::{HashMap, HashSet},
7 convert::TryFrom,
8 fmt,
9};
10
11use crate::{
12 arith::{CompleteConstraints, Constraint, ConstraintSet},
13 ast::{
14 ConstraintsAst, FunctionAst, ObjectAst, SliceAst, SpannedTypeAst, TupleAst, TupleLenAst,
15 TypeAst, TypeConstraintsAst,
16 },
17 error::{Error, Errors},
18 types::{ParamConstraints, ParamQuantifier},
19 DynConstraints, Function, Object, PrimitiveType, Slice, Tuple, Type, TypeEnvironment,
20 UnknownLen,
21};
22use arithmetic_parser::{ErrorKind as ParseErrorKind, InputSpan, NomResult, Spanned, SpannedError};
23
/// Errors that can occur when converting a parsed type AST into a type.
///
/// Each variant that carries a `String` payload stores the name (of a param, type,
/// constraint or field) that caused the error.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AstConversionError {
    /// A `for<...>` quantifier is attached to a function type nested inside another
    /// function type; quantifiers are only allowed on the top-level function.
    EmbeddedQuantifier,
    /// A length param (e.g., `N` in `[Num; N]`) is used outside of a function type.
    FreeLengthVar(String),
    /// A type param (e.g., `'T`) is used outside of a function type.
    FreeTypeVar(String),
    /// A length listed in the function quantifier is not used in the function.
    UnusedLength(String),
    /// A type param constrained in the function quantifier is not used in the function.
    UnusedTypeParam(String),
    /// An identifier does not name a known primitive type.
    UnknownType(String),
    /// A constraint name is not recognized.
    UnknownConstraint(String),
    /// The `_` placeholder type is used when converting a standalone type.
    InvalidSomeType,
    /// The `_` placeholder length is used when converting a standalone type.
    InvalidSomeLength,
    /// An object type lists the same field more than once.
    DuplicateField(String),
    /// A constraint that is not object-safe is used in a `dyn` type.
    NotObjectSafe(String),
}

impl fmt::Display for AstConversionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants without a payload render a fixed message.
            Self::EmbeddedQuantifier => {
                formatter.write_str("`for` quantifier for a function that is not top-level")
            }
            Self::InvalidSomeType => {
                formatter.write_str("`_` type is disallowed when parsing standalone type")
            }
            Self::InvalidSomeLength => {
                formatter.write_str("`_` length is disallowed when parsing standalone type")
            }

            // Variants carrying a name interpolate it into the message.
            Self::FreeLengthVar(name) => formatter.write_fmt(format_args!(
                "Length param `{}` is not scoped by function definition",
                name
            )),
            Self::FreeTypeVar(name) => formatter.write_fmt(format_args!(
                "Type param `{}` is not scoped by function definition",
                name
            )),
            Self::UnusedLength(name) => {
                formatter.write_fmt(format_args!("Unused length param `{}`", name))
            }
            Self::UnusedTypeParam(name) => {
                formatter.write_fmt(format_args!("Unused type param `{}`", name))
            }
            Self::UnknownType(name) => formatter.write_fmt(format_args!("Unknown type `{}`", name)),
            Self::UnknownConstraint(name) => {
                formatter.write_fmt(format_args!("Unknown constraint `{}`", name))
            }
            Self::DuplicateField(name) => {
                formatter.write_fmt(format_args!("Duplicate field `{}` in object type", name))
            }
            Self::NotObjectSafe(name) => {
                formatter.write_fmt(format_args!("Constraint `{}` is not object-safe", name))
            }
        }
    }
}

impl std::error::Error for AstConversionError {}
142
/// Mutable state threaded through the conversion of type ASTs into types.
#[derive(Debug)]
pub(crate) struct AstConversionState<'r, 'a, Prim: PrimitiveType> {
    /// Type environment used to allocate fresh type / length variables (e.g., for `_`
    /// placeholders). `None` when converting a standalone type, in which case `_` is an error.
    env: Option<&'r mut TypeEnvironment<Prim>>,
    /// Constraints that may be referenced by name in the converted type.
    known_constraints: ConstraintSet<Prim>,
    /// Sink for conversion errors; conversion keeps going after an error is recorded.
    errors: &'r mut Errors<'a, Prim>,
    /// Indices of length params in the current top-level function, assigned in order of
    /// first occurrence; cleared once that function is converted.
    len_params: HashMap<&'a str, usize>,
    /// Indices of type params in the current top-level function, assigned in order of
    /// first occurrence; cleared once that function is converted.
    type_params: HashMap<&'a str, usize>,
    /// Whether conversion is currently inside a function type (params are only valid there).
    is_in_function: bool,
}
153
154impl<'r, 'a, Prim: PrimitiveType> AstConversionState<'r, 'a, Prim> {
155 pub fn new(env: &'r mut TypeEnvironment<Prim>, errors: &'r mut Errors<'a, Prim>) -> Self {
156 let known_constraints = env.known_constraints.clone();
157 Self {
158 env: Some(env),
159 known_constraints,
160 errors,
161 len_params: HashMap::new(),
162 type_params: HashMap::new(),
163 is_in_function: false,
164 }
165 }
166
167 fn without_env(errors: &'r mut Errors<'a, Prim>) -> Self {
168 Self {
169 env: None,
170 known_constraints: Prim::well_known_constraints(),
171 errors,
172 len_params: HashMap::new(),
173 type_params: HashMap::new(),
174 is_in_function: false,
175 }
176 }
177
178 fn type_param_idx(&mut self, param_name: &'a str) -> usize {
179 let type_param_count = self.type_params.len();
180 *self
181 .type_params
182 .entry(param_name)
183 .or_insert(type_param_count)
184 }
185
186 fn len_param_idx(&mut self, param_name: &'a str) -> usize {
187 let len_param_count = self.len_params.len();
188 *self.len_params.entry(param_name).or_insert(len_param_count)
189 }
190
191 fn new_type(&mut self, span: Option<&SpannedTypeAst<'a>>) -> Type<Prim> {
192 let errors = &mut *self.errors;
193 self.env.as_deref_mut().map_or_else(
194 || {
195 if let Some(span) = span {
196 let err = AstConversionError::InvalidSomeType;
197 errors.push(Error::conversion(err, span));
198 }
199 Type::free_var(0)
202 },
203 |env| env.substitutions.new_type_var(),
204 )
205 }
206
207 fn new_len(&mut self, span: Option<&Spanned<'a, TupleLenAst>>) -> UnknownLen {
208 let errors = &mut *self.errors;
209 self.env.as_deref_mut().map_or_else(
210 || {
211 if let Some(span) = span {
212 let err = AstConversionError::InvalidSomeLength;
213 errors.push(Error::conversion(err, span));
214 }
215 UnknownLen::free_var(0)
218 },
219 |env| env.substitutions.new_len_var(),
220 )
221 }
222
223 fn resolve_constraint(&self, name: &str) -> Option<(Box<dyn Constraint<Prim>>, bool)> {
224 self.known_constraints
225 .get_by_name(name)
226 .map(|(constraint, is_object_safe)| (constraint.clone_boxed(), is_object_safe))
227 }
228
229 pub(crate) fn convert_type(&mut self, ty: &SpannedTypeAst<'a>) -> Type<Prim> {
230 match &ty.extra {
231 TypeAst::Some => self.new_type(Some(ty)),
232 TypeAst::Any => Type::Any,
233 TypeAst::Dyn(constraints) => Type::Dyn(constraints.convert_dyn(self)),
234 TypeAst::Ident => {
235 let ident = *ty.fragment();
236 if let Ok(prim_type) = Prim::from_str(ident) {
237 Type::Prim(prim_type)
238 } else {
239 let err = AstConversionError::UnknownType(ident.to_owned());
240 self.errors.push(Error::conversion(err, ty));
241 self.new_type(None)
242 }
243 }
244
245 TypeAst::Param => {
246 let name = &ty.fragment()[1..];
247 if self.is_in_function {
248 let idx = self.type_param_idx(name);
249 Type::param(idx)
250 } else {
251 let err = AstConversionError::FreeTypeVar(name.to_owned());
252 self.errors.push(Error::conversion(err, ty));
253 self.new_type(None)
254 }
255 }
256
257 TypeAst::Function(function) => self.convert_fn(function, None),
258 TypeAst::FunctionWithConstraints {
259 function,
260 constraints,
261 } => self.convert_fn(&function.extra, Some(constraints)),
262
263 TypeAst::Tuple(tuple) => tuple.convert(self).into(),
264 TypeAst::Slice(slice) => slice.convert(self).into(),
265 TypeAst::Object(object) => object.convert(self).into(),
266 }
267 }
268
269 fn convert_fn(
270 &mut self,
271 function: &FunctionAst<'a>,
272 constraints: Option<&Spanned<'a, ConstraintsAst<'a>>>,
273 ) -> Type<Prim> {
274 if self.is_in_function {
275 if let Some(constraints) = constraints {
276 let err = AstConversionError::EmbeddedQuantifier;
277 self.errors.push(Error::conversion(err, constraints));
278 }
279 function.convert(self).into()
280 } else {
281 self.is_in_function = true;
282 let mut converted_fn = function.convert(self);
283 let constraints =
284 constraints.map_or_else(ParamConstraints::default, |c| c.extra.convert(self));
285 ParamQuantifier::set_params(&mut converted_fn, constraints);
286
287 self.is_in_function = false;
288 self.type_params.clear();
289 self.len_params.clear();
290 converted_fn.into()
291 }
292 }
293}
294
295impl<'a> TypeConstraintsAst<'a> {
296 fn convert<Prim: PrimitiveType>(
297 &self,
298 state: &mut AstConversionState<'_, 'a, Prim>,
299 ) -> CompleteConstraints<Prim> {
300 self.do_convert(state, false)
301 }
302
303 fn convert_dyn<Prim: PrimitiveType>(
304 &self,
305 state: &mut AstConversionState<'_, 'a, Prim>,
306 ) -> DynConstraints<Prim> {
307 DynConstraints {
308 inner: self.do_convert(state, true),
309 }
310 }
311
312 fn do_convert<Prim: PrimitiveType>(
313 &self,
314 state: &mut AstConversionState<'_, 'a, Prim>,
315 require_object_safety: bool,
316 ) -> CompleteConstraints<Prim> {
317 let mut constraints = CompleteConstraints::default();
318 if let Some(object) = &self.object {
319 constraints.object = Some(object.convert(state));
320 }
321
322 self.terms.iter().fold(constraints, |mut acc, input| {
323 let input_str = *input.fragment();
324 if let Some((constraint, is_object_safe)) = state.resolve_constraint(input_str) {
325 if require_object_safety && !is_object_safe {
326 let err = AstConversionError::NotObjectSafe(input_str.to_owned());
327 state.errors.push(Error::conversion(err, input));
328 } else {
329 acc.simple.insert_boxed(constraint);
330 }
331 } else {
332 let err = AstConversionError::UnknownConstraint(input_str.to_owned());
333 state.errors.push(Error::conversion(err, input));
334 }
335 acc
336 })
337 }
338}
339
340impl<'a> ConstraintsAst<'a> {
341 fn convert<Prim: PrimitiveType>(
342 &self,
343 state: &mut AstConversionState<'_, 'a, Prim>,
344 ) -> ParamConstraints<Prim> {
345 let mut static_lengths = HashSet::with_capacity(self.static_lengths.len());
346 for dyn_length in &self.static_lengths {
347 let name = *dyn_length.fragment();
348 if let Some(index) = state.len_params.get(name) {
349 static_lengths.insert(*index);
350 } else {
351 let err = AstConversionError::UnusedLength(name.to_owned());
352 state.errors.push(Error::conversion(err, dyn_length));
353 }
354 }
355
356 let mut type_params = HashMap::with_capacity(self.type_params.len());
357 for (param, constraints) in &self.type_params {
358 let name = *param.fragment();
359 if let Some(index) = state.type_params.get(name) {
360 type_params.insert(*index, constraints.convert(state));
361 } else {
362 let err = AstConversionError::UnusedTypeParam(name.to_owned());
363 state.errors.push(Error::conversion(err, param));
364 }
365 }
366
367 ParamConstraints {
368 type_params,
369 static_lengths,
370 }
371 }
372}
373
374impl<'a> TupleAst<'a> {
375 fn convert<Prim: PrimitiveType>(
376 &self,
377 state: &mut AstConversionState<'_, 'a, Prim>,
378 ) -> Tuple<Prim> {
379 let start = self
380 .start
381 .iter()
382 .map(|element| state.convert_type(element))
383 .collect();
384 let middle = self
385 .middle
386 .as_ref()
387 .map(|middle| middle.extra.convert(state));
388 let end = self
389 .end
390 .iter()
391 .map(|element| state.convert_type(element))
392 .collect();
393 Tuple::from_parts(start, middle, end)
394 }
395}
396
397impl<'a> SliceAst<'a> {
398 fn convert<Prim: PrimitiveType>(
399 &self,
400 state: &mut AstConversionState<'_, 'a, Prim>,
401 ) -> Slice<Prim> {
402 let element = state.convert_type(&self.element);
403
404 let converted_length = match &self.length.extra {
405 TupleLenAst::Ident => {
406 let name = *self.length.fragment();
407 if state.is_in_function {
408 let const_param = state.len_param_idx(name);
409 UnknownLen::param(const_param)
410 } else {
411 let err = AstConversionError::FreeLengthVar(name.to_owned());
412 state.errors.push(Error::conversion(err, &self.length));
413 state.new_len(None)
414 }
415 }
416 TupleLenAst::Some => state.new_len(Some(&self.length)),
417 TupleLenAst::Dynamic => UnknownLen::Dynamic,
418 };
419
420 Slice::new(element, converted_length)
421 }
422}
423
424impl<'a> ObjectAst<'a> {
425 fn convert<Prim: PrimitiveType>(
426 &self,
427 state: &mut AstConversionState<'_, 'a, Prim>,
428 ) -> Object<Prim> {
429 let mut fields = HashMap::new();
430 for (field_name, ty) in &self.fields {
431 let field_name_str = *field_name.fragment();
432 if fields.contains_key(field_name_str) {
433 let err = AstConversionError::DuplicateField(field_name_str.to_owned());
434 state.errors.push(Error::conversion(err, field_name));
435 } else {
436 fields.insert(field_name_str.to_owned(), state.convert_type(ty));
437 }
438 }
439 Object::from_map(fields)
440 }
441}
442
443impl<'a> FunctionAst<'a> {
444 fn convert<Prim: PrimitiveType>(
445 &self,
446 state: &mut AstConversionState<'_, 'a, Prim>,
447 ) -> Function<Prim> {
448 let args = self.args.extra.convert(state);
449 let return_type = state.convert_type(&self.return_type);
450 Function::new(args, return_type)
451 }
452
453 pub fn try_convert<Prim>(&self) -> Result<Function<Prim>, Errors<'a, Prim>>
455 where
456 Prim: PrimitiveType,
457 {
458 let mut errors = Errors::new();
459 let mut state = AstConversionState::without_env(&mut errors);
460 state.is_in_function = true;
461
462 let output = self.convert(&mut state);
463 if errors.is_empty() {
464 Ok(output)
465 } else {
466 Err(errors)
467 }
468 }
469}
470
471fn parse_inner<'a, Ast>(
473 parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
474 input: InputSpan<'a>,
475) -> NomResult<'a, Ast> {
476 let (rest, ast) = parser(input)?;
477 if !rest.fragment().is_empty() {
478 let err = ParseErrorKind::Leftovers.with_span(&rest.into());
479 return Err(NomErr::Failure(err));
480 }
481 Ok((rest, ast))
482}
483
484fn from_str<'a, Ast>(
486 parser: fn(InputSpan<'a>) -> NomResult<'a, Ast>,
487 def: &'a str,
488) -> Result<Ast, SpannedError<&'a str>> {
489 let input = InputSpan::new(def);
490 let (_, ast) = parse_inner(parser, input).map_err(|err| match err {
491 NomErr::Incomplete(_) => ParseErrorKind::Incomplete.with_span(&input.into()),
492 NomErr::Error(e) | NomErr::Failure(e) => e,
493 })?;
494 Ok(ast)
495}
496
497impl<'a> TypeAst<'a> {
498 pub fn try_from(def: &'a str) -> Result<SpannedTypeAst<'a>, SpannedError<&'a str>> {
500 from_str(TypeAst::parse, def)
501 }
502}
503
504impl<'a, Prim: PrimitiveType> TryFrom<&SpannedTypeAst<'a>> for Type<Prim> {
505 type Error = Errors<'a, Prim>;
506
507 fn try_from(ast: &SpannedTypeAst<'a>) -> Result<Self, Self::Error> {
508 let mut errors = Errors::new();
509 let mut state = AstConversionState::without_env(&mut errors);
510
511 let output = state.convert_type(ast);
512 if errors.is_empty() {
513 Ok(output)
514 } else {
515 Err(errors)
516 }
517 }
518}
519
520impl<'a> TryFrom<&'a str> for FunctionAst<'a> {
521 type Error = SpannedError<&'a str>;
522
523 fn try_from(def: &'a str) -> Result<Self, Self::Error> {
524 from_str(FunctionAst::parse, def)
525 }
526}
527
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::arith::Num;

    #[test]
    fn converting_raw_fn_type() {
        let input = InputSpan::new("(['T; N], ('T) -> Bool) -> Bool");
        let (_, fn_type) = FunctionAst::parse(input).unwrap();
        let fn_type = fn_type.try_convert::<Num>().unwrap();

        assert_eq!(fn_type.to_string(), *input.fragment());
    }

    #[test]
    fn converting_fn_type_with_constraint() {
        let input = InputSpan::new("for<'T: Lin> (['T; N], ('T) -> Bool) -> Bool");
        let (_, ast) = TypeAst::parse(input).unwrap();
        let fn_type = <Type>::try_from(&ast).unwrap();

        assert_eq!(fn_type.to_string(), *input.fragment());
    }

    #[test]
    fn parsing_basic_types() -> anyhow::Result<()> {
        let num_type = <Type>::try_from(&TypeAst::try_from("Num")?)?;
        assert_eq!(num_type, Type::NUM);

        let bool_type = <Type>::try_from(&TypeAst::try_from("Bool")?)?;
        assert_eq!(bool_type, Type::BOOL);

        let tuple_type = <Type>::try_from(&TypeAst::try_from("(Num, (Bool, Bool))")?)?;
        assert_eq!(
            tuple_type,
            Type::from((Type::NUM, Type::Tuple(vec![Type::BOOL; 2].into()),))
        );

        let slice_type = <Type>::try_from(&TypeAst::try_from("[(Num, Bool)]")?)?;
        let slice_type = match &slice_type {
            Type::Tuple(tuple) => tuple.as_slice().unwrap(),
            _ => panic!("Unexpected type: {:?}", slice_type),
        };

        assert_eq!(*slice_type.element(), Type::from((Type::NUM, Type::BOOL)));
        assert_matches!(
            slice_type.len().components(),
            (Some(UnknownLen::Dynamic), 0)
        );
        Ok(())
    }

    #[test]
    fn parsing_functional_type() -> anyhow::Result<()> {
        let ty = <Type>::try_from(&TypeAst::try_from("(['T; N], ('T) -> 'U) -> 'U")?)?;
        let ty = match ty {
            Type::Function(fn_type) => *fn_type,
            _ => panic!("Unexpected type: {:?}", ty),
        };

        assert_eq!(ty.params.as_ref().unwrap().len_params.len(), 1);
        assert_eq!(ty.params.as_ref().unwrap().type_params.len(), 2);
        assert_eq!(ty.return_type, Type::param(1));
        Ok(())
    }

    #[test]
    fn parsing_functional_type_with_varargs() -> anyhow::Result<()> {
        let ty = <Type>::try_from(&TypeAst::try_from("(...[Num; N]) -> Num")?)?;
        let ty = match ty {
            Type::Function(fn_type) => *fn_type,
            _ => panic!("Unexpected type: {:?}", ty),
        };

        assert_eq!(ty.params.as_ref().unwrap().len_params.len(), 1);
        assert!(ty.params.as_ref().unwrap().type_params.is_empty());
        let args_slice = ty.args.as_slice().unwrap();
        assert_eq!(*args_slice.element(), Type::NUM);
        assert_eq!(args_slice.len(), UnknownLen::param(0).into());
        Ok(())
    }

    #[test]
    fn parsing_incomplete_type() {
        // Every input is a proper prefix of `(['T; N], ('T) -> Bool) -> Bool`, written in
        // the same `(args) -> ret` function syntax as the other tests in this module, so each
        // one checks that a truncated definition is rejected. (Prefixes that happen to form
        // a complete type, such as `(['T; N], ('T) -> Bool)`, are deliberately excluded.)
        const INCOMPLETE_TYPES: &[&str] = &[
            "(",
            "(['T; ",
            "(['T; N], (",
            "(['T; N], ('T)",
            "(['T; N], ('T) ->",
            "(['T; N], ('T) -> Bool) -",
            "(['T; N], ('T) -> Bool) ->",
        ];

        for &input in INCOMPLETE_TYPES {
            TypeAst::try_from(input).unwrap_err();
        }
    }

    #[test]
    fn parsing_type_with_object_constraint() -> anyhow::Result<()> {
        let type_def = "for<'T: { x: Num } + Lin> ('T) -> Bool";
        let ty = TypeAst::try_from(type_def)?;
        let ty = <Type>::try_from(&ty)?;
        let ty = match ty {
            Type::Function(fn_type) => *fn_type,
            _ => panic!("Unexpected type: {:?}", ty),
        };

        let type_params = &ty.params.as_ref().unwrap().type_params;
        assert_eq!(type_params.len(), 1);
        let (_, type_params) = &type_params[0];
        assert!(type_params.object.is_some());
        assert!(type_params.simple.get_by_name("Lin").is_some());

        assert_eq!(ty.to_string(), type_def);
        Ok(())
    }
}