Skip to main content

nickel_lang_parser/ast/
mod.rs

1//! The Nickel AST, as ingested by the (future) bytecode compiler.
2//!
//! Since the AST is built once for each Nickel expression and is then compiled away to bytecode,
//! the total number of allocated nodes is reasonably bounded by the input program size. Thus, for
//! performance reasons, we allocate nodes using an arena and keep them alive until the end of
//! compilation. In return, we get fast allocation and de-allocation, and we can easily reference
//! other nodes and data structures using native references.
8//!
9//! The corresponding lifetime of all the nodes - and thus of the arena as well - is consistently
10//! called `'ast`.
11
12use std::{cmp::Ordering, ffi::OsStr, fmt, path::Path};
13
14use malachite::base::num::basic::traits::Zero;
15
16use crate::{
17    error::ParseError,
18    identifier::{Ident, LocIdent},
19    impl_display_from_bytecode_pretty,
20    position::TermPos,
21    traverse::*,
22};
23
/// The underlying type representing Nickel numbers. Currently, numbers are arbitrary precision
/// rationals.
///
/// Basic arithmetic operations are exact, without loss of precision (within the limits of available
/// memory).
///
/// Raising to a power that doesn't fit in a signed 64-bit number will lead to converting both
/// operands to 64-bit floats, performing the floating-point power operation, and converting back
/// to rationals, which can incur a loss of precision.
///
/// Conversion to string and serialization can also lose precision[^number-serialization].
///
/// [^number-serialization]: Conversion to string and serialization first try to convert the
///     rational to an exact signed or unsigned 64-bit integer. If this succeeds, such operations
///     don't lose precision. Otherwise, the number is converted to the nearest 64-bit float and then
///     serialized/printed, which can incur a loss of information.
pub type Number = malachite::rational::Rational;
39
40pub mod alloc;
41pub mod builder;
42pub mod combine;
43pub mod pattern;
44pub mod pretty;
45pub mod primop;
46pub mod record;
47pub mod typ;
48
49pub use alloc::AstAlloc;
50use pattern::*;
51use primop::PrimOp;
52use record::*;
53use serde::{Serialize, Serializer};
54use typ::*;
55
/// Supported input formats.
#[derive(Default, Clone, Copy, Eq, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
pub enum InputFormat {
    // NOTE: variants are annotated with plain `//` comments on purpose: with the `clap` feature,
    // `///` doc comments on variants would be picked up by `clap::ValueEnum` as CLI help text.

    // Nickel source (`.ncl`), the default.
    #[default]
    Nickel,
    // JSON (`.json`).
    Json,
    // YAML (`.yaml` or `.yml`).
    Yaml,
    // TOML (`.toml`).
    Toml,
    // Nix (`.nix`), only with the `nix-experimental` feature.
    #[cfg(feature = "nix-experimental")]
    Nix,
    // Plain text (`.txt`).
    Text,
}

impl InputFormat {
    /// Guesses an [InputFormat] from the extension of `path`.
    ///
    /// The extension is compared exactly, and thus case-sensitively. Returns `None` when the path
    /// has no extension, when the extension isn't valid UTF-8, or when it isn't recognized.
    pub fn from_path(path: impl AsRef<Path>) -> Option<InputFormat> {
        let extension = path.as_ref().extension()?.to_str()?;

        let detected = match extension {
            "ncl" => InputFormat::Nickel,
            "json" => InputFormat::Json,
            "yaml" | "yml" => InputFormat::Yaml,
            "toml" => InputFormat::Toml,
            #[cfg(feature = "nix-experimental")]
            "nix" => InputFormat::Nix,
            "txt" => InputFormat::Text,
            _ => return None,
        };

        Some(detected)
    }

    /// Returns the name of this format. This is also the exact spelling accepted by the
    /// [std::str::FromStr] implementation.
    pub fn to_str(&self) -> &'static str {
        match *self {
            Self::Nickel => "Nickel",
            Self::Json => "Json",
            Self::Yaml => "Yaml",
            Self::Toml => "Toml",
            Self::Text => "Text",
            #[cfg(feature = "nix-experimental")]
            Self::Nix => "Nix",
        }
    }
}

impl fmt::Display for InputFormat {
    /// Writes the name returned by [InputFormat::to_str].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.to_str())
    }
}

impl std::str::FromStr for InputFormat {
    type Err = ();

    /// Parses a format name, as produced by [InputFormat::to_str]. Matching is case-sensitive.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "Nickel" => Ok(Self::Nickel),
            "Json" => Ok(Self::Json),
            "Yaml" => Ok(Self::Yaml),
            "Toml" => Ok(Self::Toml),
            "Text" => Ok(Self::Text),
            #[cfg(feature = "nix-experimental")]
            "Nix" => Ok(Self::Nix),
            _ => Err(()),
        }
    }
}
120
/// A chunk of a string with interpolated expressions inside. Can be either a string literal or an
/// interpolated expression.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum StringChunk<E> {
    /// A string literal.
    Literal(String),

    /// An interpolated expression.
    Expr(
        E,     /* the expression */
        usize, /* the indentation level (see parser::utils::strip_indent) */
    ),
}

impl<E> StringChunk<E> {
    /// Builds an interpolated-expression chunk with an indentation level of `0`.
    ///
    /// Only meant for tests. It can't be `cfg(test)`, because some of those tests live in other
    /// crates.
    #[doc(hidden)]
    pub fn expr(e: E) -> Self {
        Self::Expr(e, 0)
    }

    /// Concatenates `chunks` into a single string, provided they are all literals.
    ///
    /// Returns `None` as soon as an interpolated expression is met. An empty sequence of chunks
    /// yields `Some` of the empty string.
    pub fn try_chunks_as_static_str<'a, I>(chunks: I) -> Option<String>
    where
        I: IntoIterator<Item = &'a StringChunk<E>>,
        E: 'a,
    {
        let mut static_str = String::new();

        for chunk in chunks {
            match chunk {
                StringChunk::Literal(lit) => static_str.push_str(lit),
                StringChunk::Expr(..) => return None,
            }
        }

        Some(static_str)
    }
}
/// The merge priority of a value, as set by the `default`, `force` and `priority` annotations.
///
/// Priorities are totally ordered: [MergePriority::Bottom] is below every other priority,
/// [MergePriority::Top] is above every other priority, and [MergePriority::Neutral] sits among the
/// numerals as if it were `0`.
#[derive(Debug, Clone, Default)]
pub enum MergePriority {
    /// The priority of default values that are overridden by everything else.
    Bottom,

    /// The priority by default, when no priority annotation (`default`, `force`, `priority`) is
    /// provided.
    ///
    /// Acts as the value `MergePriority::Numeral(0)` with respect to ordering and equality
    /// testing. The only way to discriminate this variant is to pattern match on it.
    #[default]
    Neutral,

    /// A numeral priority.
    Numeral(Number),

    /// The priority of values that override everything else and can't be overridden.
    Top,
}
178
179impl PartialOrd for MergePriority {
180    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
181        Some(self.cmp(other))
182    }
183}
184
185impl PartialEq for MergePriority {
186    fn eq(&self, other: &Self) -> bool {
187        match (self, other) {
188            (MergePriority::Bottom, MergePriority::Bottom)
189            | (MergePriority::Neutral, MergePriority::Neutral)
190            | (MergePriority::Top, MergePriority::Top) => true,
191            (MergePriority::Numeral(p1), MergePriority::Numeral(p2)) => p1 == p2,
192            (MergePriority::Neutral, MergePriority::Numeral(p))
193            | (MergePriority::Numeral(p), MergePriority::Neutral)
194                if p == &Number::ZERO =>
195            {
196                true
197            }
198            _ => false,
199        }
200    }
201}
202
203impl Eq for MergePriority {}
204
205impl Ord for MergePriority {
206    fn cmp(&self, other: &Self) -> Ordering {
207        match (self, other) {
208            // Equalities
209            (MergePriority::Bottom, MergePriority::Bottom)
210            | (MergePriority::Top, MergePriority::Top)
211            | (MergePriority::Neutral, MergePriority::Neutral) => Ordering::Equal,
212            (MergePriority::Numeral(p1), MergePriority::Numeral(p2)) => p1.cmp(p2),
213
214            // Top and bottom.
215            (MergePriority::Bottom, _) | (_, MergePriority::Top) => Ordering::Less,
216            (MergePriority::Top, _) | (_, MergePriority::Bottom) => Ordering::Greater,
217
218            // Neutral and numeral.
219            (MergePriority::Neutral, MergePriority::Numeral(n)) => Number::ZERO.cmp(n),
220            (MergePriority::Numeral(n), MergePriority::Neutral) => n.cmp(&Number::ZERO),
221        }
222    }
223}
224
225impl fmt::Display for MergePriority {
226    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
227        match self {
228            MergePriority::Bottom => write!(f, "default"),
229            MergePriority::Neutral => write!(f, "{}", Number::ZERO),
230            MergePriority::Numeral(p) => write!(f, "{p}"),
231            MergePriority::Top => write!(f, "force"),
232        }
233    }
234}
235
236impl Serialize for MergePriority {
237    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
238    where
239        S: Serializer,
240    {
241        serializer.serialize_str(&self.to_string())
242    }
243}
244
/// Possible origins of a merge operation.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Default)]
pub enum MergeKind {
    /// A standard, user-written merge operation (or a merge operation descending from a
    /// user-written merge operation).
    #[default]
    Standard,
    /// A merge generated by the parser when reconstructing piecewise definitions, for example:
    ///
    /// ```nickel
    /// { foo = def1, foo = def2 }
    /// ```
    PiecewiseDef,
}
259
/// A flavor for record operations. By design, we want empty optional values to be transparent for
/// record operations, because they would otherwise make many operations fail spuriously (e.g.
/// trying to map over such an empty value). So they are most of the time silently ignored.
///
/// However, it's sometimes useful and even necessary to take them into account. This behavior is
/// controlled by [RecordOpKind].
#[derive(Clone, Debug, PartialEq, Eq, Copy, Default)]
pub enum RecordOpKind {
    /// Silently ignore optional fields that don't have a value (the default).
    #[default]
    IgnoreEmptyOpt,
    /// Take all fields into account, including optional fields without a value.
    ConsiderAllFields,
}
272
/// A node of the Nickel AST.
///
/// Nodes are built by the parser and then mostly traversed immutably. Such nodes are optimized for
/// sharing (hence immutability) and for size, as the size of an enum can grow quite quickly in
/// Rust. In particular, any data that is bigger than a few words isn't usually owned but rather a
/// reference to some arena-allocated data.
///
/// Using an arena has another advantage: the data is allocated in the same order as the AST is
/// built. This means that even if there are reference indirections, the children of a node are
/// most likely close to the node itself in memory, which should be good for cache locality.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub enum Node<'ast> {
    /// The null value. This is also the default node.
    #[default]
    Null,

    /// A boolean value.
    Bool(bool),

    /// A number.
    ///
    /// A number is an arbitrary-precision rational in Nickel. It's not small and thus we put it
    /// behind a reference to avoid size bloat.
    Number(&'ast Number),

    /// A string literal.
    String(&'ast str),

    /// A string containing interpolated expressions, represented as a list of either literals or
    /// expressions.
    ///
    /// As opposed to `nickel_lang_core::term::Term::StrChunks`, the chunks are stored in the original order:
    /// `"hello%{var}"` will give `["hello", var]`.
    StringChunks(&'ast [StringChunk<Ast<'ast>>]),

    /// A function.
    Fun {
        /// The patterns binding the arguments of the function.
        args: &'ast [Pattern<'ast>],
        /// The body of the function.
        body: &'ast Ast<'ast>,
    },

    /// A let block.
    Let {
        /// The bindings introduced by the block.
        bindings: &'ast [LetBinding<'ast>],
        /// The expression following the bindings.
        body: &'ast Ast<'ast>,
        /// Whether this is a recursive let block (`let rec`).
        rec: bool,
    },

    /// An application to one or more arguments.
    App {
        /// The expression being applied.
        head: &'ast Ast<'ast>,
        /// The arguments of the application.
        args: &'ast [Ast<'ast>],
    },

    /// A variable.
    Var(LocIdent),

    /// An enum variant (an algebraic datatype).
    ///
    /// Variants have at most one argument: variants with no arguments are often called simply
    /// enum tags. Note that one can just use a record as an argument to emulate variants with
    /// multiple arguments.
    EnumVariant {
        /// The tag of the variant.
        tag: LocIdent,
        /// The argument of the variant, or `None` for a bare enum tag.
        arg: Option<&'ast Ast<'ast>>,
    },

    /// A record.
    Record(&'ast Record<'ast>),

    /// An if-then-else expression.
    IfThenElse {
        /// The condition.
        cond: &'ast Ast<'ast>,
        /// The `then` branch.
        then_branch: &'ast Ast<'ast>,
        /// The `else` branch.
        else_branch: &'ast Ast<'ast>,
    },

    /// A match expression. This expression is still to be applied to an argument to match on.
    Match(Match<'ast>),

    /// An array.
    Array(&'ast [Ast<'ast>]),

    /// An n-ary primitive operation application. As opposed to a traditional function application:
    ///
    /// 1. The function part is necessarily a primitive operation.
    /// 2. The arguments are forced before entering the primitive operation.
    PrimOpApp {
        /// The primitive operation being applied.
        op: &'ast PrimOp,
        /// The arguments of the operation.
        args: &'ast [Ast<'ast>],
    },

    /// A term with a type and/or contract annotation.
    Annotated {
        /// The type and/or contract annotation.
        annot: &'ast Annotation<'ast>,
        /// The annotated term.
        inner: &'ast Ast<'ast>,
    },

    /// An import.
    Import(Import<'ast>),

    /// A type in term position, such as in `let my_contract = Number -> Number in ...`.
    ///
    /// During evaluation, this will get turned into a contract.
    Type(&'ast Type<'ast>),

    /// A term that couldn't be parsed properly. Used by the LSP to handle partially valid
    /// programs.
    ParseError(&'ast ParseError),
}
383
/// An individual binding in a let block.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LetBinding<'ast> {
    /// The pattern on the left-hand side of `=`.
    pub pattern: Pattern<'ast>,
    /// The documentation and type/contract annotations attached to the binding.
    pub metadata: LetMetadata<'ast>,
    /// The bound value, on the right-hand side of `=`.
    pub value: Ast<'ast>,
}
391
/// The metadata that can be attached to a let. It's a subset of [record::FieldMetadata].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LetMetadata<'ast> {
    /// The documentation of the binding, if any.
    pub doc: Option<&'ast str>,
    /// The type and/or contract annotations of the binding.
    pub annotation: Annotation<'ast>,
}
398
399impl<'ast> From<LetMetadata<'ast>> for FieldMetadata<'ast> {
400    fn from(let_metadata: LetMetadata<'ast>) -> Self {
401        FieldMetadata {
402            annotation: let_metadata.annotation,
403            doc: let_metadata.doc,
404            ..Default::default()
405        }
406    }
407}
408
409impl<'ast> TryFrom<FieldMetadata<'ast>> for LetMetadata<'ast> {
410    type Error = ();
411
412    fn try_from(field_metadata: FieldMetadata<'ast>) -> Result<Self, Self::Error> {
413        if let FieldMetadata {
414            doc,
415            annotation,
416            opt: false,
417            not_exported: false,
418            priority: MergePriority::Neutral,
419        } = field_metadata
420        {
421            Ok(LetMetadata { doc, annotation })
422        } else {
423            Err(())
424        }
425    }
426}
427
428impl<'ast> Node<'ast> {
429    /// Tries to extract a static literal from string chunks.
430    ///
431    /// This methods returns a `Some(..)` when the term is a [Node::StringChunks] and all the
432    /// chunks are [StringChunk::Literal]
433    pub fn try_str_chunk_as_static_str(&self) -> Option<String> {
434        match self {
435            Node::StringChunks(chunks) => StringChunk::try_chunks_as_static_str(*chunks),
436            _ => None,
437        }
438    }
439
440    /// Attaches a position to this node turning it into an [Ast].
441    pub fn spanned(self, pos: TermPos) -> Ast<'ast> {
442        Ast { node: self, pos }
443    }
444}
445
/// A Nickel AST. Contains a root node and a span.
///
//TODO: we don't expect to access the span much on the happy path. Should we add an indirection
//through a reference?
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Ast<'ast> {
    /// The root node.
    pub node: Node<'ast>,
    /// The position of this AST in the source (possibly `TermPos::None`).
    pub pos: TermPos,
}
455
impl<'ast> Ast<'ast> {
    /// Sets a new position for this AST node.
    ///
    /// Only the position of this (root) node is replaced. Child nodes keep their own positions.
    pub fn with_pos(self, pos: TermPos) -> Self {
        Ast { pos, ..self }
    }

    /// Removes the position from this AST node and (recursively) child nodes.
    ///
    /// This is mainly useful for tests, where we often want to compare syntax
    /// trees without their locations.
    ///
    /// Positions are erased in two bottom-up passes. The first pass resets the position of every
    /// [Type]. The second pass resets the position of every [Ast] node and of the field
    /// definitions of every record.
    #[cfg(test)]
    pub fn without_pos(self, alloc: &'ast AstAlloc) -> Self {
        self.traverse(
            alloc,
            // The type of the closure argument selects the `TraverseAlloc<Type>` implementation.
            &mut |t: Type| -> Result<_, std::convert::Infallible> {
                Ok(Type {
                    pos: TermPos::None,
                    ..t
                })
            },
            TraverseOrder::BottomUp,
        )
        // Can't fail: the error type is `Infallible`.
        .unwrap()
        .traverse(
            alloc,
            &mut |t: Ast<'_>| -> Result<_, std::convert::Infallible> {
                let node = match t.node {
                    // Field definitions carry their own position, distinct from the `Ast`
                    // position reset below, so it's cleared separately.
                    Node::Record(r) => Node::Record(alloc.alloc(Record {
                        field_defs: alloc.alloc_many(r.field_defs.iter().map(|fd| FieldDef {
                            pos: TermPos::None,
                            ..fd.clone()
                        })),
                        ..r.clone()
                    })),
                    n => n,
                };
                Ok(Ast {
                    pos: TermPos::None,
                    node,
                })
            },
            TraverseOrder::BottomUp,
        )
        // Can't fail: the error type is `Infallible`.
        .unwrap()
    }
}
502
503impl Default for Ast<'_> {
504    fn default() -> Self {
505        Ast {
506            node: Node::Null,
507            pos: TermPos::None,
508        }
509    }
510}
511
/// A branch of a match expression.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct MatchBranch<'ast> {
    /// The pattern on the left-hand side of `=>`.
    pub pattern: Pattern<'ast>,
    /// A potential guard, which is an additional side-condition defined as `if cond`. The value
    /// stored in this field is the boolean condition itself.
    pub guard: Option<Ast<'ast>>,
    /// The body of the branch, on the right-hand side of `=>`.
    pub body: Ast<'ast>,
}
523
/// Content of a match expression.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Match<'ast> {
    /// Branches of the match expression. Each branch holds a pattern, an optional guard and a
    /// body (see [MatchBranch]).
    pub branches: &'ast [MatchBranch<'ast>],
}
531
/// A type and/or contract annotation.
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct Annotation<'ast> {
    /// The type annotation (using `:`), if any.
    pub typ: Option<Type<'ast>>,

    /// The contract annotations (using `|`). Empty when there are none.
    pub contracts: &'ast [Type<'ast>],
}
541
542impl Annotation<'_> {
543    /// Returns a string representation of the contracts (without the static type annotation) as a
544    /// comma-separated list.
545    pub fn contracts_to_string(&self) -> Option<String> {
546        todo!("requires pretty printing first")
547        //(!self.contracts.is_empty()).then(|| {
548        //    self.contracts
549        //        .iter()
550        //        .map(|typ| format!("{typ}"))
551        //        .collect::<Vec<_>>()
552        //        .join(",")
553        //})
554    }
555
556    /// Returns `true` if this annotation is empty, i.e. hold neither a type annotation nor
557    /// contracts annotations.
558    pub fn is_empty(&self) -> bool {
559        self.typ.is_none() && self.contracts.is_empty()
560    }
561}
562
/// Specifies where something should be imported from.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Import<'ast> {
    /// An import designated by a file path.
    Path {
        /// The path of the imported file.
        path: &'ast OsStr,
        /// The input format of the imported file.
        format: InputFormat,
    },
    /// An import of a package, identified by `id`.
    ///
    /// Importing packages requires a `PackageMap` to translate the location
    /// to a path. The format is always Nickel.
    Package { id: Ident },
}
574
575impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Ast<'ast> {
576    /// Traverse through all [Ast] in the tree.
577    ///
578    /// This also recurses into the terms that are contained in [typ::Type] subtrees.
579    fn traverse<F, E>(
580        self,
581        alloc: &'ast AstAlloc,
582        f: &mut F,
583        order: TraverseOrder,
584    ) -> Result<Ast<'ast>, E>
585    where
586        F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
587    {
588        let ast = match order {
589            TraverseOrder::TopDown => f(self)?,
590            TraverseOrder::BottomUp => self,
591        };
592        let pos = ast.pos;
593
594        let result = match &ast.node {
595            Node::Fun { args, body } => {
596                let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?;
597                let body = alloc.alloc((*body).clone().traverse(alloc, f, order)?);
598
599                Ast {
600                    node: Node::Fun { args, body },
601                    pos,
602                }
603            }
604            Node::Let {
605                bindings,
606                body,
607                rec,
608            } => {
609                let bindings = traverse_alloc_many(alloc, bindings.iter().cloned(), f, order)?;
610                let body = alloc.alloc((*body).clone().traverse(alloc, f, order)?);
611
612                Ast {
613                    node: Node::Let {
614                        bindings,
615                        body,
616                        rec: *rec,
617                    },
618                    pos,
619                }
620            }
621            Node::App { head, args } => {
622                let head = alloc.alloc((*head).clone().traverse(alloc, f, order)?);
623                let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?;
624
625                Ast {
626                    node: Node::App { head, args },
627                    pos,
628                }
629            }
630            Node::Match(data) => {
631                let branches = traverse_alloc_many(alloc, data.branches.iter().cloned(), f, order)?;
632
633                Ast {
634                    node: Node::Match(Match { branches }),
635                    pos,
636                }
637            }
638            Node::PrimOpApp { op, args } => {
639                let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?;
640
641                Ast {
642                    node: Node::PrimOpApp { op, args },
643                    pos,
644                }
645            }
646            Node::Record(record) => {
647                let field_defs =
648                    traverse_alloc_many(alloc, record.field_defs.iter().cloned(), f, order)?;
649
650                Ast {
651                    node: Node::Record(alloc.alloc(Record {
652                        field_defs,
653                        includes: record.includes,
654                        open: record.open,
655                    })),
656                    pos,
657                }
658            }
659            Node::Array(elts) => {
660                let elts = traverse_alloc_many(alloc, elts.iter().cloned(), f, order)?;
661
662                Ast {
663                    node: Node::Array(elts),
664                    pos,
665                }
666            }
667            Node::StringChunks(chunks) => {
668                let chunks_res: Result<Vec<StringChunk<Ast<'ast>>>, E> = chunks
669                    .iter()
670                    .cloned()
671                    .map(|chunk| match chunk {
672                        chunk @ StringChunk::Literal(_) => Ok(chunk),
673                        StringChunk::Expr(ast, indent) => {
674                            Ok(StringChunk::Expr(ast.traverse(alloc, f, order)?, indent))
675                        }
676                    })
677                    .collect();
678
679                Ast {
680                    node: Node::StringChunks(alloc.alloc_many(chunks_res?)),
681                    pos,
682                }
683            }
684            Node::Annotated { annot, inner } => {
685                let annot = alloc.alloc((*annot).clone().traverse(alloc, f, order)?);
686                let inner = alloc.alloc((*inner).clone().traverse(alloc, f, order)?);
687
688                Ast {
689                    node: Node::Annotated { annot, inner },
690                    pos,
691                }
692            }
693            Node::Type(typ) => {
694                let typ = alloc.alloc((*typ).clone().traverse(alloc, f, order)?);
695
696                Ast {
697                    node: Node::Type(typ),
698                    pos,
699                }
700            }
701            Node::IfThenElse {
702                cond,
703                then_branch,
704                else_branch,
705            } => {
706                let cond = alloc.alloc((*cond).clone().traverse(alloc, f, order)?);
707                let then_branch = alloc.alloc((*then_branch).clone().traverse(alloc, f, order)?);
708                let else_branch = alloc.alloc((*else_branch).clone().traverse(alloc, f, order)?);
709
710                Ast {
711                    node: Node::IfThenElse {
712                        cond,
713                        then_branch,
714                        else_branch,
715                    },
716                    pos,
717                }
718            }
719            _ => ast,
720        };
721
722        match order {
723            TraverseOrder::TopDown => Ok(result),
724            TraverseOrder::BottomUp => f(result),
725        }
726    }
727
728    fn traverse_ref<S, U>(
729        &'ast self,
730        f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
731        state: &S,
732    ) -> Option<U> {
733        let child_state = match f(self, state) {
734            TraverseControl::Continue => None,
735            TraverseControl::ContinueWithScope(s) => Some(s),
736            TraverseControl::SkipBranch => {
737                return None;
738            }
739            TraverseControl::Return(ret) => {
740                return Some(ret);
741            }
742        };
743        let state = child_state.as_ref().unwrap_or(state);
744
745        match self.node {
746            Node::Null
747            | Node::Bool(_)
748            | Node::Number(_)
749            | Node::String(_)
750            | Node::Var(_)
751            | Node::Import(_)
752            | Node::ParseError(_) => None,
753            Node::IfThenElse {
754                cond,
755                then_branch,
756                else_branch,
757            } => cond
758                .traverse_ref(f, state)
759                .or_else(|| then_branch.traverse_ref(f, state))
760                .or_else(|| else_branch.traverse_ref(f, state)),
761            Node::EnumVariant { tag: _, arg } => arg?.traverse_ref(f, state),
762            Node::StringChunks(chunks) => chunks.iter().find_map(|chk| {
763                if let StringChunk::Expr(term, _) = chk {
764                    term.traverse_ref(f, state)
765                } else {
766                    None
767                }
768            }),
769            Node::Fun { args, body } => args
770                .iter()
771                .find_map(|arg| arg.traverse_ref(f, state))
772                .or_else(|| body.traverse_ref(f, state)),
773            Node::PrimOpApp { op: _, args } => {
774                args.iter().find_map(|arg| arg.traverse_ref(f, state))
775            }
776            Node::Let {
777                bindings,
778                body,
779                rec: _,
780            } => bindings
781                .iter()
782                .find_map(|binding| binding.traverse_ref(f, state))
783                .or_else(|| body.traverse_ref(f, state)),
784            Node::App { head, args } => head
785                .traverse_ref(f, state)
786                .or_else(|| args.iter().find_map(|arg| arg.traverse_ref(f, state))),
787            Node::Record(data) => data
788                .field_defs
789                .iter()
790                .find_map(|field_def| field_def.traverse_ref(f, state)),
791            Node::Match(data) => data.branches.iter().find_map(
792                |MatchBranch {
793                     pattern,
794                     guard,
795                     body,
796                 }| {
797                    pattern
798                        .traverse_ref(f, state)
799                        .or_else(|| {
800                            if let Some(cond) = guard.as_ref() {
801                                cond.traverse_ref(f, state)
802                            } else {
803                                None
804                            }
805                        })
806                        .or_else(|| body.traverse_ref(f, state))
807                },
808            ),
809            Node::Array(elts) => elts.iter().find_map(|t| t.traverse_ref(f, state)),
810            Node::Annotated { annot, inner } => annot
811                .traverse_ref(f, state)
812                .or_else(|| inner.traverse_ref(f, state)),
813            Node::Type(typ) => typ.traverse_ref(f, state),
814        }
815    }
816}
817
impl<'ast> TraverseAlloc<'ast, Type<'ast>> for Ast<'ast> {
    /// Traverses the [Type]s contained in this AST.
    ///
    /// This is built on top of the [Ast] traversal. Every [Node::Type] node reached by that
    /// traversal has its type rebuilt by traversing it with `f`, using the same `order`.
    ///
    /// NOTE(review): the enclosing [Ast] traversal also descends into the terms nested inside a
    /// [Node::Type]. Depending on how [Type] traverses its own subterms, some nested types might
    /// be visited more than once. Confirm against the `Type` implementations in `typ`.
    fn traverse<F, E>(
        self,
        alloc: &'ast AstAlloc,
        f: &mut F,
        order: TraverseOrder,
    ) -> Result<Ast<'ast>, E>
    where
        F: FnMut(Type<'ast>) -> Result<Type<'ast>, E>,
    {
        self.traverse(
            alloc,
            // The closure takes an `Ast`, which selects the `TraverseAlloc<Ast>` implementation.
            &mut |ast: Ast<'ast>| match &ast.node {
                Node::Type(typ) => {
                    let typ = alloc.alloc((*typ).clone().traverse(alloc, f, order)?);
                    Ok(Ast {
                        node: Node::Type(typ),
                        pos: ast.pos,
                    })
                }
                _ => Ok(ast),
            },
            order,
        )
    }

    /// Searches the [Type]s contained in this AST.
    ///
    /// Every [Node::Type] node reached by the [Ast] reference traversal is searched with `f`. The
    /// resulting `Option<U>` is converted into a [TraverseControl] with `into()`. Presumably
    /// `Some` stops the search and `None` continues it; confirm against that `From` impl.
    fn traverse_ref<S, U>(
        &'ast self,
        f: &mut dyn FnMut(&'ast Type<'ast>, &S) -> TraverseControl<S, U>,
        state: &S,
    ) -> Option<U> {
        self.traverse_ref(
            &mut |ast: &'ast Ast<'ast>, state: &S| match &ast.node {
                Node::Type(typ) => typ.traverse_ref(f, state).into(),
                _ => TraverseControl::Continue,
            },
            state,
        )
    }
}
858
859impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Annotation<'ast> {
860    fn traverse<F, E>(
861        self,
862        alloc: &'ast AstAlloc,
863        f: &mut F,
864        order: TraverseOrder,
865    ) -> Result<Self, E>
866    where
867        F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
868    {
869        let typ = self
870            .typ
871            .map(|typ| typ.traverse(alloc, f, order))
872            .transpose()?;
873        let contracts = traverse_alloc_many(alloc, self.contracts.iter().cloned(), f, order)?;
874
875        Ok(Annotation { typ, contracts })
876    }
877
878    fn traverse_ref<S, U>(
879        &'ast self,
880        f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
881        scope: &S,
882    ) -> Option<U> {
883        self.typ
884            .iter()
885            .chain(self.contracts.iter())
886            .find_map(|c| c.traverse_ref(f, scope))
887    }
888}
889
890impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for LetBinding<'ast> {
891    fn traverse<F, E>(
892        self,
893        alloc: &'ast AstAlloc,
894        f: &mut F,
895        order: TraverseOrder,
896    ) -> Result<Self, E>
897    where
898        F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
899    {
900        let pattern = self.pattern.traverse(alloc, f, order)?;
901
902        let metadata = LetMetadata {
903            annotation: self.metadata.annotation.traverse(alloc, f, order)?,
904            doc: self.metadata.doc,
905        };
906
907        let value = self.value.traverse(alloc, f, order)?;
908
909        Ok(LetBinding {
910            pattern,
911            metadata,
912            value,
913        })
914    }
915
916    fn traverse_ref<S, U>(
917        &'ast self,
918        f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
919        scope: &S,
920    ) -> Option<U> {
921        self.metadata
922            .annotation
923            .traverse_ref(f, scope)
924            .or_else(|| self.value.traverse_ref(f, scope))
925    }
926}
927
928impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for MatchBranch<'ast> {
929    fn traverse<F, E>(
930        self,
931        alloc: &'ast AstAlloc,
932        f: &mut F,
933        order: TraverseOrder,
934    ) -> Result<Self, E>
935    where
936        F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
937    {
938        let pattern = self.pattern.traverse(alloc, f, order)?;
939        let body = self.body.traverse(alloc, f, order)?;
940        let guard = self
941            .guard
942            .map(|guard| guard.traverse(alloc, f, order))
943            .transpose()?;
944
945        Ok(MatchBranch {
946            pattern,
947            guard,
948            body,
949        })
950    }
951
952    fn traverse_ref<S, U>(
953        &'ast self,
954        f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
955        scope: &S,
956    ) -> Option<U> {
957        self.pattern
958            .traverse_ref(f, scope)
959            .or_else(|| self.body.traverse_ref(f, scope))
960            .or_else(|| {
961                self.guard
962                    .as_ref()
963                    .and_then(|guard| guard.traverse_ref(f, scope))
964            })
965    }
966}
967
968impl<'ast> From<Node<'ast>> for Ast<'ast> {
969    fn from(node: Node<'ast>) -> Self {
970        Ast {
971            node,
972            pos: TermPos::None,
973        }
974    }
975}
976
/// Similar to `TryFrom`, but takes an additional allocator. It is meant for conversions from and
/// to [crate::ast::Ast] that require threading an explicit allocator.
///
/// We chose a different name than `try_from` for the method to avoid confusing the compiler,
/// which would otherwise have difficulties disambiguating calls like `Ast::try_from`. This holds
/// even though the signature differs from the standard `TryFrom` (two arguments vs one).
pub trait TryConvert<'ast, T>
where
    Self: Sized,
{
    /// The error returned when the conversion fails.
    type Error;

    /// Tries to convert `from` into `Self`, with `alloc` available for any allocation the
    /// conversion needs.
    fn try_convert(alloc: &'ast AstAlloc, from: T) -> Result<Self, Self::Error>;
}
991
// As the macro name suggests, these presumably implement `Display` for `Node` and `Ast` through
// the AST pretty-printer (see the `pretty` module).
// NOTE(review): confirm against the macro definition.
impl_display_from_bytecode_pretty!(Node<'_>);
impl_display_from_bytecode_pretty!(Ast<'_>);