aiken_lang/tipo/
infer.rs

1use super::{
2    TypeInfo, ValueConstructor, ValueConstructorVariant,
3    environment::{EntityKind, Environment},
4    error::{Error, UnifyErrorSituation, Warning},
5    expr::ExprTyper,
6    hydrator::Hydrator,
7};
8use crate::{
9    IdGenerator,
10    ast::{
11        Annotation, ArgBy, ArgName, ArgVia, DataType, Decorator, DecoratorKind, Definition,
12        Function, ModuleConstant, ModuleKind, RecordConstructor, RecordConstructorArg, Tracing,
13        TypeAlias, TypedArg, TypedDataType, TypedDefinition, TypedModule, TypedValidator,
14        UntypedArg, UntypedDefinition, UntypedModule, UntypedPattern, UntypedValidator, Use,
15        Validator,
16    },
17    expr::{TypedExpr, UntypedAssignmentKind, UntypedExpr},
18    parser::token::Token,
19    tipo::{Span, Type, TypeVar, expr::infer_function},
20};
21use std::{
22    borrow::Borrow,
23    collections::{BTreeMap, BTreeSet, HashMap},
24    fmt,
25    ops::Deref,
26    rc::Rc,
27};
28
29impl UntypedModule {
30    #[allow(clippy::too_many_arguments)]
31    #[allow(clippy::result_large_err)]
32    pub fn infer(
33        mut self,
34        id_gen: &IdGenerator,
35        kind: ModuleKind,
36        package: &str,
37        modules: &HashMap<String, TypeInfo>,
38        tracing: Tracing,
39        warnings: &mut Vec<Warning>,
40        env: Option<&str>,
41    ) -> Result<TypedModule, Error> {
42        let module_name = self.name.clone();
43        let docs = std::mem::take(&mut self.docs);
44        let mut environment =
45            Environment::new(id_gen.clone(), &module_name, &kind, modules, warnings, env);
46
47        let mut type_names = HashMap::with_capacity(self.definitions.len());
48        let mut value_names = HashMap::with_capacity(self.definitions.len());
49        let mut hydrators = HashMap::with_capacity(self.definitions.len());
50
51        // Register any modules, types, and values being imported
52        // We process imports first so that anything imported can be referenced
53        // anywhere in the module.
54        for def in self.definitions() {
55            environment.register_import(def)?;
56        }
57
58        // Register types so they can be used in constructors and functions
59        // earlier in the module.
60        environment.register_types(
61            self.definitions.iter().collect(),
62            &module_name,
63            &mut hydrators,
64            &mut type_names,
65        )?;
66
67        // Register values so they can be used in functions earlier in the module.
68        for def in self.definitions() {
69            environment.register_values(
70                def,
71                &module_name,
72                &mut hydrators,
73                &mut value_names,
74                kind,
75            )?;
76        }
77
78        // Infer the types of each definition in the module
79        // We first infer all the constants so they can be used in functions defined
80        // anywhere in the module.
81        let mut definitions = Vec::with_capacity(self.definitions.len());
82        let mut consts = vec![];
83        let mut not_consts = vec![];
84
85        for def in self.definitions().cloned() {
86            match def {
87                Definition::ModuleConstant { .. } => consts.push(def),
88                Definition::Validator { .. } if kind.is_validator() => not_consts.push(def),
89                Definition::Validator { .. } => (),
90                Definition::Fn { .. }
91                | Definition::Test { .. }
92                | Definition::Benchmark { .. }
93                | Definition::TypeAlias { .. }
94                | Definition::DataType { .. }
95                | Definition::Use { .. } => not_consts.push(def),
96            }
97        }
98
99        for def in consts.into_iter().chain(not_consts) {
100            let definition =
101                infer_definition(def, &module_name, &mut hydrators, &mut environment, tracing)?;
102
103            definitions.push(definition);
104        }
105
106        // Generalise functions now that the entire module has been inferred
107        let definitions = definitions
108            .into_iter()
109            .map(|def| environment.generalise_definition(def, &module_name))
110            .collect();
111
112        // Generate warnings for unused items
113        environment.warnings.retain(|warning| match warning {
114            Warning::UnusedVariable { location, name } => !environment
115                .validator_params
116                .contains(&(name.to_string(), *location)),
117            _ => true,
118        });
119        environment.convert_unused_to_warnings();
120
121        // Remove private and imported types and values to create the public interface
122        environment.module_values.retain(|_, info| info.public);
123
124        // Ensure no exported values have private types in their type signature
125        for value in environment.module_values.values() {
126            if let Some(leaked) = value.tipo.find_private_type() {
127                return Err(Error::PrivateTypeLeak {
128                    location: value.variant.location(),
129                    leaked_location: match &leaked {
130                        Type::App { name, .. } => {
131                            environment.module_types.get(name).map(|info| info.location)
132                        }
133                        _ => None,
134                    },
135                    leaked,
136                });
137            }
138        }
139
140        environment
141            .module_types
142            .retain(|_, info| info.public && info.module == module_name);
143
144        let own_types = environment.module_types.keys().collect::<BTreeSet<_>>();
145
146        environment
147            .module_types_constructors
148            .retain(|k, _| own_types.contains(k));
149
150        environment
151            .accessors
152            .retain(|_, accessors| accessors.public);
153
154        let Environment {
155            module_types: types,
156            module_types_constructors: types_constructors,
157            module_values: values,
158            accessors,
159            annotations,
160            ..
161        } = environment;
162
163        Ok(TypedModule {
164            docs,
165            name: module_name.clone(),
166            definitions,
167            kind,
168            lines: self.lines,
169            type_info: TypeInfo {
170                name: module_name,
171                types,
172                types_constructors,
173                values,
174                accessors,
175                annotations,
176                kind,
177                package: package.to_string(),
178            },
179        })
180    }
181}
182
183#[allow(clippy::result_large_err)]
184fn infer_definition(
185    def: UntypedDefinition,
186    module_name: &String,
187    hydrators: &mut HashMap<String, Hydrator>,
188    environment: &mut Environment<'_>,
189    tracing: Tracing,
190) -> Result<TypedDefinition, Error> {
191    match def {
192        Definition::Fn(f) => {
193            let top_level_scope = environment.open_new_scope();
194            let ret = Definition::Fn(infer_function(
195                &f,
196                module_name,
197                hydrators,
198                environment,
199                tracing,
200                &top_level_scope,
201            )?);
202            environment.close_scope(top_level_scope);
203            Ok(ret)
204        }
205
206        Definition::Validator(Validator {
207            doc,
208            location,
209            end_position,
210            handlers,
211            mut fallback,
212            params,
213            name,
214        }) => {
215            let params_length = params.len();
216
217            let top_level_scope = environment.open_new_scope();
218
219            let def = environment.in_new_scope(|environment| {
220                let fallback_name = TypedValidator::handler_name(&name, &fallback.name);
221
222                put_params_in_scope(&fallback_name, environment, &params);
223
224                let mut typed_handlers = vec![];
225
226                for mut handler in handlers {
227                    let typed_fun = environment.in_new_scope(|environment| {
228                        let temp_params = params.iter().cloned().chain(handler.arguments);
229                        handler.arguments = temp_params.collect();
230
231                        let handler_name = TypedValidator::handler_name(&name, &handler.name);
232
233                        let old_name = handler.name;
234                        handler.name = handler_name;
235
236                        let mut typed_fun = infer_function(
237                            &handler,
238                            module_name,
239                            hydrators,
240                            environment,
241                            tracing,
242                            &top_level_scope,
243                        )?;
244
245                        typed_fun.name = old_name;
246
247                        if !typed_fun.return_type.is_bool() {
248                            return Err(Error::ValidatorMustReturnBool {
249                                return_type: typed_fun.return_type.clone(),
250                                location: typed_fun.location,
251                            });
252                        }
253
254                        typed_fun.arguments.drain(0..params_length);
255
256                        if !typed_fun.has_valid_purpose_name() {
257                            return Err(Error::UnknownPurpose {
258                                location: typed_fun
259                                    .location
260                                    .map(|start, _end| (start, start + typed_fun.name.len())),
261                                available_purposes: TypedValidator::available_handler_names(),
262                            });
263                        }
264
265                        if typed_fun.arguments.len() != typed_fun.validator_arity() {
266                            return Err(Error::IncorrectValidatorArity {
267                                count: typed_fun.arguments.len() as u32,
268                                expected: typed_fun.validator_arity() as u32,
269                                location: typed_fun.location,
270                            });
271                        }
272
273                        if typed_fun.is_spend() && !typed_fun.arguments[0].tipo.is_option() {
274                            return Err(Error::CouldNotUnify {
275                                location: typed_fun.arguments[0].location,
276                                expected: Type::option(typed_fun.arguments[0].tipo.clone()),
277                                given: typed_fun.arguments[0].tipo.clone(),
278                                situation: None,
279                                rigid_type_names: Default::default(),
280                            });
281                        }
282
283                        for arg in typed_fun.arguments.iter_mut() {
284                            if arg.tipo.is_unbound() {
285                                arg.tipo = Type::data();
286                            }
287                        }
288
289                        Ok(typed_fun)
290                    })?;
291
292                    typed_handlers.push(typed_fun);
293                }
294
295                // NOTE: Duplicates are handled when registering handler names. So if we have N
296                // typed handlers, they are different. The -1 represents takes out the fallback
297                // handler name.
298                let is_exhaustive =
299                    typed_handlers.len() >= TypedValidator::available_handler_names().len() - 1;
300
301                if is_exhaustive
302                    && fallback != UntypedValidator::default_fallback(fallback.location)
303                {
304                    return Err(Error::UnexpectedValidatorFallback {
305                        fallback: fallback.location,
306                    });
307                }
308
309                let (typed_params, typed_fallback) = environment.in_new_scope(|environment| {
310                    let temp_params = params.iter().cloned().chain(fallback.arguments);
311                    fallback.arguments = temp_params.collect();
312
313                    let old_name = fallback.name;
314                    fallback.name = fallback_name;
315
316                    let mut typed_fallback = infer_function(
317                        &fallback,
318                        module_name,
319                        hydrators,
320                        environment,
321                        tracing,
322                        &top_level_scope,
323                    )?;
324
325                    typed_fallback.name = old_name;
326
327                    if !typed_fallback.return_type.is_bool() {
328                        return Err(Error::ValidatorMustReturnBool {
329                            return_type: typed_fallback.return_type.clone(),
330                            location: typed_fallback.location,
331                        });
332                    }
333
334                    let typed_params = typed_fallback
335                        .arguments
336                        .drain(0..params_length)
337                        .map(|mut arg| {
338                            if arg.tipo.is_unbound() {
339                                arg.tipo = Type::data();
340                            }
341
342                            arg
343                        })
344                        .collect();
345
346                    if typed_fallback.arguments.len() != 1 {
347                        return Err(Error::IncorrectValidatorArity {
348                            count: typed_fallback.arguments.len() as u32,
349                            expected: 1,
350                            location: typed_fallback.location,
351                        });
352                    }
353
354                    for arg in typed_fallback.arguments.iter_mut() {
355                        if arg.tipo.is_unbound() {
356                            arg.tipo = Type::data();
357                        }
358                    }
359
360                    Ok((typed_params, typed_fallback))
361                })?;
362
363                Ok(Definition::Validator(Validator {
364                    doc,
365                    end_position,
366                    handlers: typed_handlers,
367                    fallback: typed_fallback,
368                    name,
369                    location,
370                    params: typed_params,
371                }))
372            })?;
373
374            environment.close_scope(top_level_scope);
375
376            Ok(def)
377        }
378
379        Definition::Test(f) => {
380            let top_level_scope = environment.open_new_scope();
381            let (typed_via, annotation) = match f.arguments.first() {
382                Some(arg) => {
383                    if f.arguments.len() > 1 {
384                        return Err(Error::IncorrectTestArity {
385                            count: f.arguments.len(),
386                            location: f
387                                .arguments
388                                .get(1)
389                                .expect("arguments.len() > 1")
390                                .arg
391                                .location,
392                        });
393                    }
394
395                    extract_via_information(&f, arg, hydrators, environment, tracing, infer_fuzzer)
396                        .map(|(typed_via, annotation)| (Some(typed_via), annotation))
397                }
398                None => Ok((None, None)),
399            }?;
400
401            let typed_f = infer_function(
402                &f.into(),
403                module_name,
404                hydrators,
405                environment,
406                tracing,
407                &top_level_scope,
408            )?;
409
410            let is_bool = environment.unify(
411                typed_f.return_type.clone(),
412                Type::bool(),
413                typed_f.location,
414                false,
415            );
416
417            let is_void = environment.unify(
418                typed_f.return_type.clone(),
419                Type::void(),
420                typed_f.location,
421                false,
422            );
423
424            environment.close_scope(top_level_scope);
425
426            if is_bool.or(is_void).is_err() {
427                return Err(Error::IllegalTestType {
428                    location: typed_f.location,
429                });
430            }
431
432            Ok(Definition::Test(Function {
433                doc: typed_f.doc,
434                location: typed_f.location,
435                name: typed_f.name,
436                public: typed_f.public,
437                arguments: match typed_via {
438                    Some((via, tipo)) => {
439                        let arg = typed_f
440                            .arguments
441                            .first()
442                            .expect("has exactly one argument")
443                            .to_owned();
444                        vec![ArgVia {
445                            arg: TypedArg {
446                                tipo,
447                                annotation,
448                                ..arg
449                            },
450                            via,
451                        }]
452                    }
453                    None => vec![],
454                },
455                return_annotation: typed_f.return_annotation,
456                return_type: typed_f.return_type,
457                body: typed_f.body,
458                on_test_failure: typed_f.on_test_failure,
459                end_position: typed_f.end_position,
460            }))
461        }
462
463        Definition::Benchmark(f) => {
464            let top_level_scope = environment.open_new_scope();
465            let err_incorrect_arity = || {
466                Err(Error::IncorrectBenchmarkArity {
467                    location: f
468                        .location
469                        .map(|start, end| (start + Token::Benchmark.to_string().len() + 1, end)),
470                })
471            };
472
473            let (typed_via, annotation) = match f.arguments.first() {
474                None => return err_incorrect_arity(),
475                Some(arg) => {
476                    if f.arguments.len() > 1 {
477                        return err_incorrect_arity();
478                    }
479
480                    extract_via_information(&f, arg, hydrators, environment, tracing, infer_sampler)
481                }
482            }?;
483
484            let typed_f = infer_function(
485                &f.into(),
486                module_name,
487                hydrators,
488                environment,
489                tracing,
490                &top_level_scope,
491            )?;
492
493            let arguments = {
494                let arg = typed_f
495                    .arguments
496                    .first()
497                    .expect("has exactly one argument")
498                    .to_owned();
499
500                vec![ArgVia {
501                    arg: TypedArg {
502                        tipo: typed_via.1,
503                        annotation,
504                        ..arg
505                    },
506                    via: typed_via.0,
507                }]
508            };
509
510            environment.close_scope(top_level_scope);
511
512            Ok(Definition::Benchmark(Function {
513                doc: typed_f.doc,
514                location: typed_f.location,
515                name: typed_f.name,
516                public: typed_f.public,
517                arguments,
518                return_annotation: typed_f.return_annotation,
519                return_type: typed_f.return_type,
520                body: typed_f.body,
521                on_test_failure: typed_f.on_test_failure,
522                end_position: typed_f.end_position,
523            }))
524        }
525
526        Definition::TypeAlias(TypeAlias {
527            doc,
528            location,
529            public,
530            alias,
531            parameters,
532            annotation,
533            tipo: _,
534        }) => {
535            let tipo = environment
536                .get_type_constructor(&None, &alias, location)
537                .expect("Could not find existing type for type alias")
538                .tipo
539                .clone();
540
541            let typed_type_alias = TypeAlias {
542                doc,
543                location,
544                public,
545                alias,
546                parameters,
547                annotation,
548                tipo,
549            };
550
551            Ok(Definition::TypeAlias(typed_type_alias))
552        }
553
554        Definition::DataType(DataType {
555            doc,
556            location,
557            public,
558            opaque,
559            name,
560            parameters,
561            decorators,
562            constructors: untyped_constructors,
563            typed_parameters: _,
564        }) => {
565            let constructors = untyped_constructors
566                .into_iter()
567                .map(|constructor| {
568                    let preregistered_fn = environment
569                        .get_variable(&constructor.name)
570                        .expect("Could not find preregistered type for function");
571
572                    let preregistered_type = preregistered_fn.tipo.clone();
573
574                    let args = preregistered_type.function_types().map_or(
575                        Ok(vec![]),
576                        |(args_types, _return_type)| {
577                            constructor
578                                .arguments
579                                .into_iter()
580                                .zip(&args_types)
581                                .map(|(arg, t)| {
582                                    if t.is_function() {
583                                        return Err(Error::FunctionTypeInData {
584                                            location: arg.location,
585                                        });
586                                    }
587
588                                    if t.is_ml_result() {
589                                        return Err(Error::IllegalTypeInData {
590                                            location: arg.location,
591                                            tipo: t.clone(),
592                                        });
593                                    }
594
595                                    if t.contains_opaque() {
596                                        let parent = environment
597                                            .get_type_constructor_mut(&name, location)?;
598
599                                        Rc::make_mut(&mut parent.tipo).set_opaque(true)
600                                    }
601
602                                    Ok(RecordConstructorArg {
603                                        label: arg.label,
604                                        annotation: arg.annotation,
605                                        location: arg.location,
606                                        doc: arg.doc,
607                                        tipo: t.clone(),
608                                    })
609                                })
610                                .collect()
611                        },
612                    )?;
613
614                    Ok(RecordConstructor {
615                        location: constructor.location,
616                        name: constructor.name,
617                        arguments: args,
618                        decorators: constructor.decorators,
619                        doc: constructor.doc,
620                        sugar: constructor.sugar,
621                    })
622                })
623                .collect::<Result<_, Error>>()?;
624
625            let typed_parameters = environment
626                .get_type_constructor(&None, &name, location)
627                .expect("Could not find preregistered type constructor ")
628                .parameters
629                .clone();
630
631            let typed_data = DataType {
632                doc,
633                location,
634                public,
635                opaque,
636                name,
637                parameters,
638                constructors,
639                decorators,
640                typed_parameters,
641            };
642
643            for constr in &typed_data.constructors {
644                for RecordConstructorArg {
645                    tipo,
646                    location,
647                    doc: _,
648                    label: _,
649                    annotation: _,
650                } in &constr.arguments
651                {
652                    if tipo.is_function() {
653                        return Err(Error::FunctionTypeInData {
654                            location: *location,
655                        });
656                    }
657
658                    if tipo.is_ml_result() {
659                        return Err(Error::IllegalTypeInData {
660                            location: *location,
661                            tipo: tipo.clone(),
662                        });
663                    }
664                }
665            }
666
667            typed_data.check_decorators()?;
668
669            Ok(Definition::DataType(typed_data))
670        }
671
672        Definition::Use(Use {
673            location,
674            module,
675            as_name,
676            unqualified,
677            package: _,
678        }) => {
679            let module_info = environment.find_module(&module, location)?;
680
681            Ok(Definition::Use(Use {
682                location,
683                module,
684                as_name,
685                unqualified,
686                package: module_info.package.clone(),
687            }))
688        }
689
690        Definition::ModuleConstant(ModuleConstant {
691            doc,
692            location,
693            name,
694            annotation,
695            public,
696            value,
697        }) => {
698            let typed_assignment = ExprTyper::new(environment, tracing).infer_assignment(
699                UntypedPattern::Var {
700                    location,
701                    name: name.clone(),
702                },
703                value,
704                UntypedAssignmentKind::Let { backpassing: false },
705                &annotation,
706                location,
707            )?;
708
709            // NOTE: The assignment above is only a convenient way to create the TypedExpression
710            // that will be reduced at compile-time. We must increment its usage to not
711            // automatically trigger a warning since we are virtually creating a block with a
712            // single assignment that is then left unused.
713            //
714            // The usage of the constant is tracked through different means.
715            environment.increment_usage(&name);
716
717            let typed_expr = match typed_assignment {
718                TypedExpr::Assignment { value, .. } => value,
719                _ => unreachable!("infer_assignment inferred something else than an assignment?"),
720            };
721
722            let tipo = typed_expr.tipo();
723
724            if tipo.is_function() && !tipo.is_monomorphic() {
725                return Err(Error::GenericLeftAtBoundary { location });
726            }
727
728            let variant = ValueConstructor {
729                public,
730                variant: ValueConstructorVariant::ModuleConstant {
731                    location,
732                    name: name.to_owned(),
733                    module: module_name.to_owned(),
734                },
735                tipo: tipo.clone(),
736            };
737
738            environment.insert_variable(name.clone(), variant.variant.clone(), tipo.clone());
739
740            environment.insert_module_value(&name, variant);
741
742            if !public {
743                environment.init_usage(name.clone(), EntityKind::PrivateConstant, location);
744            }
745
746            Ok(Definition::ModuleConstant(ModuleConstant {
747                doc,
748                location,
749                name,
750                annotation,
751                public,
752                value: *typed_expr,
753            }))
754        }
755    }
756}
757
758#[allow(clippy::result_large_err, clippy::type_complexity)]
759fn extract_via_information<F>(
760    f: &Function<(), UntypedExpr, ArgVia<UntypedArg, UntypedExpr>>,
761    arg: &ArgVia<UntypedArg, UntypedExpr>,
762    hydrators: &mut HashMap<String, Hydrator>,
763    environment: &mut Environment<'_>,
764    tracing: Tracing,
765    infer_via: F,
766) -> Result<((TypedExpr, Rc<Type>), Option<Annotation>), Error>
767where
768    F: FnOnce(&mut Environment<'_>, Option<Rc<Type>>, &Rc<Type>, &Span) -> Result<Rc<Type>, Error>,
769{
770    let typed_via = ExprTyper::new(environment, tracing).infer(arg.via.clone())?;
771
772    let hydrator: &mut Hydrator = hydrators.get_mut(&f.name).unwrap();
773
774    let provided_inner_type = arg
775        .arg
776        .annotation
777        .as_ref()
778        .map(|ann| hydrator.type_from_annotation(ann, environment))
779        .transpose()?;
780
781    let inferred_inner_type = infer_via(
782        environment,
783        provided_inner_type.clone(),
784        &typed_via.tipo(),
785        &arg.via.location(),
786    )?;
787
788    // Ensure that the annotation, if any, matches the type inferred from the
789    // Fuzzer.
790    if let Some(provided_inner_type) = provided_inner_type {
791        environment
792            .unify(
793                inferred_inner_type.clone(),
794                provided_inner_type.clone(),
795                arg.via.location(),
796                false,
797            )
798            .map_err(|err| {
799                err.with_unify_error_situation(UnifyErrorSituation::FuzzerAnnotationMismatch)
800            })?;
801    }
802
803    // Replace the pre-registered type for the test function, to allow inferring
804    // the function body with the right type arguments.
805    let scope = environment
806        .scope
807        .get_mut(&f.name)
808        .expect("Could not find preregistered type for test");
809    if let Type::Fn {
810        ret,
811        alias,
812        args: _,
813    } = scope.tipo.as_ref()
814    {
815        scope.tipo = Rc::new(Type::Fn {
816            ret: ret.clone(),
817            args: vec![inferred_inner_type.clone()],
818            alias: alias.clone(),
819        })
820    }
821
822    Ok(((typed_via, inferred_inner_type), arg.arg.annotation.clone()))
823}
824
825#[allow(clippy::result_large_err)]
826fn infer_fuzzer(
827    environment: &mut Environment<'_>,
828    expected_inner_type: Option<Rc<Type>>,
829    tipo: &Rc<Type>,
830    location: &Span,
831) -> Result<Rc<Type>, Error> {
832    let could_not_unify = || Error::CouldNotUnify {
833        location: *location,
834        expected: Type::fuzzer(
835            expected_inner_type
836                .clone()
837                .unwrap_or_else(|| Type::generic_var(0)),
838        ),
839        given: tipo.clone(),
840        situation: None,
841        rigid_type_names: HashMap::new(),
842    };
843
844    match tipo.borrow() {
845        Type::Fn {
846            ret,
847            args: _,
848            alias: _,
849        } => match ret.borrow() {
850            Type::App {
851                module,
852                name,
853                args,
854                public: _,
855                contains_opaque: _,
856                alias: _,
857            } if module.is_empty() && name == "Option" && args.len() == 1 => {
858                match args.first().expect("args.len() == 1").borrow() {
859                    Type::Tuple { elems, .. } if elems.len() == 2 => {
860                        let wrapped = elems.get(1).expect("Tuple has two elements");
861
862                        // Disallow generics and functions as fuzzer targets. Only allow plain
863                        // concrete types.
864                        is_valid_fuzzer(wrapped, location)?;
865
866                        // NOTE: Although we've drilled through the Fuzzer structure to get here,
867                        // we still need to enforce that:
868                        //
869                        // 1. The Fuzzer is a function with a single argument of type PRNG
870                        // 2. It returns not only a wrapped type, but also a new PRNG
871                        //
872                        // All-in-all, we could bundle those verification through the
873                        // `infer_fuzzer` function, but instead, we can also just piggyback on
874                        // `unify` now that we have figured out the type carried by the fuzzer.
875                        environment.unify(
876                            tipo.clone(),
877                            Type::fuzzer(wrapped.clone()),
878                            *location,
879                            false,
880                        )?;
881
882                        Ok(wrapped.clone())
883                    }
884                    _ => Err(could_not_unify()),
885                }
886            }
887            _ => Err(could_not_unify()),
888        },
889
890        Type::Var { tipo, alias } => match &*tipo.deref().borrow() {
891            TypeVar::Link { tipo } => infer_fuzzer(
892                environment,
893                expected_inner_type,
894                &Type::with_alias(tipo.clone(), alias.clone()),
895                location,
896            ),
897            _ => Err(Error::GenericLeftAtBoundary {
898                location: *location,
899            }),
900        },
901
902        Type::App { .. } | Type::Tuple { .. } | Type::Pair { .. } => Err(could_not_unify()),
903    }
904}
905
906#[allow(clippy::result_large_err)]
907fn infer_sampler(
908    environment: &mut Environment<'_>,
909    expected_inner_type: Option<Rc<Type>>,
910    tipo: &Rc<Type>,
911    location: &Span,
912) -> Result<Rc<Type>, Error> {
913    let could_not_unify = || Error::CouldNotUnify {
914        location: *location,
915        expected: Type::sampler(
916            expected_inner_type
917                .clone()
918                .unwrap_or_else(|| Type::generic_var(0)),
919        ),
920        given: tipo.clone(),
921        situation: None,
922        rigid_type_names: HashMap::new(),
923    };
924
925    match tipo.borrow() {
926        Type::Fn {
927            ret,
928            args,
929            alias: _,
930        } => {
931            if args.len() == 1 && args[0].is_int() {
932                infer_fuzzer(environment, expected_inner_type, ret, &Span::empty())
933            } else {
934                Err(could_not_unify())
935            }
936        }
937
938        Type::Var { tipo, alias } => match &*tipo.deref().borrow() {
939            TypeVar::Link { tipo } => infer_sampler(
940                environment,
941                expected_inner_type,
942                &Type::with_alias(tipo.clone(), alias.clone()),
943                location,
944            ),
945            _ => Err(Error::GenericLeftAtBoundary {
946                location: *location,
947            }),
948        },
949
950        Type::App { .. } | Type::Tuple { .. } | Type::Pair { .. } => Err(could_not_unify()),
951    }
952}
953
954#[allow(clippy::result_large_err)]
955fn is_valid_fuzzer(tipo: &Type, location: &Span) -> Result<(), Error> {
956    match tipo {
957        Type::App {
958            name: _name,
959            module: _module,
960            args,
961            public: _,
962            contains_opaque: _,
963            alias: _,
964        } => args
965            .iter()
966            .try_for_each(|arg| is_valid_fuzzer(arg, location)),
967
968        Type::Tuple { elems, alias: _ } => elems
969            .iter()
970            .try_for_each(|arg| is_valid_fuzzer(arg, location)),
971
972        Type::Var { tipo, alias: _ } => match &*tipo.deref().borrow() {
973            TypeVar::Link { tipo } => is_valid_fuzzer(tipo, location),
974            _ => Err(Error::GenericLeftAtBoundary {
975                location: *location,
976            }),
977        },
978        Type::Fn { .. } => Err(Error::IllegalTypeInData {
979            location: *location,
980            tipo: Rc::new(tipo.clone()),
981        }),
982        Type::Pair { fst, snd, .. } => {
983            is_valid_fuzzer(fst, location)?;
984            is_valid_fuzzer(snd, location)?;
985            Ok(())
986        }
987    }
988}
989
990fn put_params_in_scope<'a>(
991    name: &'_ str,
992    environment: &'a mut Environment,
993    params: &'a [UntypedArg],
994) {
995    let preregistered_fn = environment
996        .get_variable(name)
997        .expect("Could not find preregistered type for function");
998
999    let preregistered_type = preregistered_fn.tipo.clone();
1000
1001    let (args_types, _return_type) = preregistered_type
1002        .function_types()
1003        .expect("Preregistered type for fn was not a fn");
1004
1005    for (ix, (arg, t)) in params
1006        .iter()
1007        .zip(args_types[0..params.len()].iter())
1008        .enumerate()
1009    {
1010        match arg.arg_name(ix) {
1011            ArgName::Named {
1012                name,
1013                label: _,
1014                location: _,
1015            } if arg.is_validator_param => {
1016                environment.insert_variable(
1017                    name.to_string(),
1018                    ValueConstructorVariant::LocalVariable {
1019                        location: arg.location,
1020                    },
1021                    t.clone(),
1022                );
1023
1024                if let ArgBy::ByPattern(ref pattern) = arg.by {
1025                    pattern.collect_identifiers(&mut |identifier| {
1026                        environment.validator_params.insert(identifier);
1027                    })
1028                }
1029
1030                environment.init_usage(name, EntityKind::Variable, arg.location);
1031            }
1032            ArgName::Named { .. } | ArgName::Discarded { .. } => (),
1033        };
1034    }
1035}
1036
/// Where a decorator is attached, used to decide whether it is allowed there.
#[derive(Debug, PartialEq)]
pub enum DecoratorContext {
    /// A data type that does not declare more than one constructor.
    Record,
    /// A data type that declares more than one constructor.
    Enum,
    /// An individual constructor of a data type.
    Constructor,
}

impl fmt::Display for DecoratorContext {
    /// Writes the lowercase name of the context, as used in error messages.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            DecoratorContext::Record => "record",
            DecoratorContext::Enum => "enum",
            DecoratorContext::Constructor => "constructor",
        };

        f.write_str(label)
    }
}
1053
1054impl TypedDataType {
1055    #[allow(clippy::result_large_err)]
1056    fn check_decorators(&self) -> Result<(), Error> {
1057        // First determine if this is a record or enum type
1058        let is_enum = self.constructors.len() > 1;
1059
1060        let context = if is_enum {
1061            DecoratorContext::Enum
1062        } else {
1063            DecoratorContext::Record
1064        };
1065
1066        validate_decorators_in_context(&self.decorators, context, None)?;
1067
1068        let mut seen = BTreeMap::new();
1069
1070        // Validate constructor decorators
1071        for (index, constructor) in self.constructors.iter().enumerate() {
1072            validate_decorators_in_context(
1073                &constructor.decorators,
1074                DecoratorContext::Constructor,
1075                None,
1076            )?;
1077
1078            let (tag, location) = constructor
1079                .decorators
1080                .iter()
1081                .find_map(|decorator| {
1082                    if let DecoratorKind::Tag { value, .. } = &decorator.kind {
1083                        Some((value.parse().unwrap(), &decorator.location))
1084                    } else {
1085                        None
1086                    }
1087                })
1088                .unwrap_or((index, &constructor.location));
1089
1090            if let Some(first) = seen.insert(tag, location) {
1091                return Err(Error::DecoratorTagOverlap {
1092                    tag,
1093                    first: *first,
1094                    second: *location,
1095                });
1096            }
1097        }
1098
1099        Ok(())
1100    }
1101}
1102
1103#[allow(clippy::result_large_err)]
1104fn validate_decorators_in_context(
1105    decorators: &[Decorator],
1106    context: DecoratorContext,
1107    tipo: Option<&Type>,
1108) -> Result<(), Error> {
1109    // Check for conflicts between decorators
1110    for (i, d1) in decorators.iter().enumerate() {
1111        // Validate context
1112        if !d1.kind.allowed_contexts().contains(&context) {
1113            return Err(Error::DecoratorValidation {
1114                location: d1.location,
1115                message: format!("this decorator not allowed in a {context} context"),
1116            });
1117        }
1118
1119        // Validate type constraints if applicable
1120        if let Some(t) = tipo {
1121            d1.kind.validate_type(&context, t, d1.location)?;
1122        }
1123
1124        // Check for conflicts with other decorators
1125        for d2 in decorators.iter().skip(i + 1) {
1126            if d1.kind.conflicts_with(&d2.kind) {
1127                return Err(Error::ConflictingDecorators {
1128                    location: d1.location,
1129                    conflicting_location: d2.location,
1130                });
1131            }
1132        }
1133    }
1134
1135    Ok(())
1136}
1137
1138impl DecoratorKind {
1139    fn allowed_contexts(&self) -> &[DecoratorContext] {
1140        match self {
1141            DecoratorKind::Tag { .. } => &[DecoratorContext::Record, DecoratorContext::Constructor],
1142            DecoratorKind::List => &[DecoratorContext::Record],
1143        }
1144    }
1145
1146    #[allow(clippy::result_large_err)]
1147    fn validate_type(
1148        &self,
1149        _context: &DecoratorContext,
1150        _tipo: &Type,
1151        _loc: Span,
1152    ) -> Result<(), Error> {
1153        match self {
1154            DecoratorKind::Tag { .. } => Ok(()),
1155            DecoratorKind::List => Ok(()),
1156        }
1157    }
1158
1159    fn conflicts_with(&self, other: &DecoratorKind) -> bool {
1160        match (self, other) {
1161            (DecoratorKind::Tag { .. }, DecoratorKind::List) => true,
1162            (DecoratorKind::Tag { .. }, DecoratorKind::Tag { .. }) => true,
1163            (DecoratorKind::List, DecoratorKind::Tag { .. }) => true,
1164            (DecoratorKind::List, DecoratorKind::List) => true,
1165        }
1166    }
1167}