1use std::ffi::OsStr;
13
14use crate::{
15 cache::InputFormat,
16 error::ParseError,
17 identifier::{Ident, LocIdent},
18 impl_display_from_bytecode_pretty,
19 position::TermPos,
20 traverse::*,
21};
22
23pub use crate::term::{MergePriority, Number, StrChunk as StringChunk};
25
26pub mod alloc;
27pub mod builder;
28pub mod combine;
29pub mod compat;
30pub mod pattern;
31pub mod primop;
32pub mod record;
33pub mod typ;
34
35pub use alloc::AstAlloc;
36use pattern::*;
37use primop::PrimOp;
38use record::*;
39use typ::*;
40
/// A node of the AST, without position information (see [`Ast`] for a node paired with its
/// position).
///
/// Sub-terms and other payloads are held through `&'ast` references or slices, so cloning a
/// `Node` only copies those references. The traversal code in this module rebuilds children via
/// an [`AstAlloc`], which presumably owns the `'ast` data (TODO confirm in the `alloc` module).
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub enum Node<'ast> {
    /// The null constant. This is the variant returned by `Node::default()`.
    #[default]
    Null,

    /// A boolean.
    Bool(bool),

    /// A number (the `Number` type re-exported from `crate::term`).
    Number(&'ast Number),

    /// A plain string. Presumably used for strings without interpolated expressions, which are
    /// represented by `StringChunks` instead — confirm against the parser.
    String(&'ast str),

    /// A string made of chunks: each chunk is either a literal (`StringChunk::Literal`) or an
    /// embedded expression with an associated indentation value (`StringChunk::Expr`).
    StringChunks(&'ast [StringChunk<Ast<'ast>>]),

    /// A function, given by the patterns of its arguments and its body.
    Fun {
        args: &'ast [Pattern<'ast>],
        body: &'ast Ast<'ast>,
    },

    /// A let block: a list of bindings and the body they scope over. `rec` distinguishes
    /// recursive lets (presumably `let rec`) from non-recursive ones.
    Let {
        bindings: &'ast [LetBinding<'ast>],
        body: &'ast Ast<'ast>,
        rec: bool,
    },

    /// The application of `head` to the arguments `args`.
    App {
        head: &'ast Ast<'ast>,
        args: &'ast [Ast<'ast>],
    },

    /// A variable.
    Var(LocIdent),

    /// An enum tag, optionally applied to an argument `arg`.
    EnumVariant {
        tag: LocIdent,
        arg: Option<&'ast Ast<'ast>>,
    },

    /// A record.
    Record(&'ast Record<'ast>),

    /// A conditional expression.
    IfThenElse {
        cond: &'ast Ast<'ast>,
        then_branch: &'ast Ast<'ast>,
        else_branch: &'ast Ast<'ast>,
    },

    /// A match expression (a list of branches).
    Match(Match<'ast>),

    /// An array of elements.
    Array(&'ast [Ast<'ast>]),

    /// The application of a primitive operator `op` to the arguments `args`.
    PrimOpApp {
        op: &'ast PrimOp,
        args: &'ast [Ast<'ast>],
    },

    /// A term `inner` together with a type and/or contract annotation `annot`.
    Annotated {
        annot: &'ast Annotation<'ast>,
        inner: &'ast Ast<'ast>,
    },

    /// An import of a file or of a package.
    Import(Import<'ast>),

    /// A type embedded as an AST node.
    Type(&'ast Type<'ast>),

    /// A parse error stored in place of a node. Presumably this lets the parser recover and keep
    /// building the rest of the tree — confirm against the parser.
    ParseError(&'ast ParseError),
}
151
/// A single binding of a let block: `pattern = value`, with optional metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LetBinding<'ast> {
    /// The pattern the value is bound to.
    pub pattern: Pattern<'ast>,
    /// Documentation and annotation attached to the binding.
    pub metadata: LetMetadata<'ast>,
    /// The bound value.
    pub value: Ast<'ast>,
}
159
/// Metadata attached to a let binding.
///
/// This is the subset of `FieldMetadata` that a let binding can carry: the `From` and `TryFrom`
/// conversions between the two types live in this module.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LetMetadata<'ast> {
    /// Optional documentation string.
    pub doc: Option<&'ast str>,
    /// Type and contract annotations.
    pub annotation: Annotation<'ast>,
}
166
167impl<'ast> From<LetMetadata<'ast>> for FieldMetadata<'ast> {
168 fn from(let_metadata: LetMetadata<'ast>) -> Self {
169 FieldMetadata {
170 annotation: let_metadata.annotation,
171 doc: let_metadata.doc,
172 ..Default::default()
173 }
174 }
175}
176
177impl<'ast> TryFrom<FieldMetadata<'ast>> for LetMetadata<'ast> {
178 type Error = ();
179
180 fn try_from(field_metadata: FieldMetadata<'ast>) -> Result<Self, Self::Error> {
181 if let FieldMetadata {
182 doc,
183 annotation,
184 opt: false,
185 not_exported: false,
186 priority: MergePriority::Neutral,
187 } = field_metadata
188 {
189 Ok(LetMetadata { doc, annotation })
190 } else {
191 Err(())
192 }
193 }
194}
195
196impl<'ast> Node<'ast> {
197 pub fn try_str_chunk_as_static_str(&self) -> Option<String> {
202 match self {
203 Node::StringChunks(chunks) => StringChunk::try_chunks_as_static_str(*chunks),
204 _ => None,
205 }
206 }
207
208 pub fn spanned(self, pos: TermPos) -> Ast<'ast> {
210 Ast { node: self, pos }
211 }
212}
213
/// An AST node paired with its position.
///
/// An `Ast` can be built from a [`Node`] with `Node::spanned`, or with `From<Node>`, which
/// attaches `TermPos::None`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Ast<'ast> {
    /// The node itself.
    pub node: Node<'ast>,
    /// The position of the node; `TermPos::None` when no position is attached.
    pub pos: TermPos,
}
223
224impl Ast<'_> {
225 pub fn with_pos(self, pos: TermPos) -> Self {
227 Ast { pos, ..self }
228 }
229}
230
231impl Default for Ast<'_> {
232 fn default() -> Self {
233 Ast {
234 node: Node::Null,
235 pos: TermPos::None,
236 }
237 }
238}
239
/// A branch of a match expression: a pattern, an optional guard and a body.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct MatchBranch<'ast> {
    /// The pattern of the branch.
    pub pattern: Pattern<'ast>,
    /// An optional guard expression. Presumably a condition that must also hold for the branch
    /// to be taken — confirm against the evaluator.
    pub guard: Option<Ast<'ast>>,
    /// The body of the branch.
    pub body: Ast<'ast>,
}
251
/// The content of a match expression: its list of branches.
///
/// The branches are stored as an `&'ast` slice, which is what allows `Match` to be `Copy`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Match<'ast> {
    /// The branches, in the order they appear.
    pub branches: &'ast [MatchBranch<'ast>],
}
259
/// A type and/or contract annotation.
///
/// The default annotation has neither a type nor contracts (see `Annotation::is_empty`).
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct Annotation<'ast> {
    /// The type annotation, if any.
    pub typ: Option<Type<'ast>>,

    /// The contract annotations, possibly empty.
    pub contracts: &'ast [Type<'ast>],
}
269
270impl Annotation<'_> {
271 pub fn contracts_to_string(&self) -> Option<String> {
274 todo!("requires pretty printing first")
275 }
283
284 pub fn is_empty(&self) -> bool {
287 self.typ.is_none() && self.contracts.is_empty()
288 }
289}
290
/// An import, either of a file or of a package.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Import<'ast> {
    /// Import of the file at `path`. `format` is the `InputFormat` the file is presumably parsed
    /// with — confirm against the import resolver.
    Path {
        path: &'ast OsStr,
        format: InputFormat,
    },
    /// Import of a package, identified by `id`.
    Package { id: Ident },
}
302
303impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Ast<'ast> {
304 fn traverse<F, E>(
308 self,
309 alloc: &'ast AstAlloc,
310 f: &mut F,
311 order: TraverseOrder,
312 ) -> Result<Ast<'ast>, E>
313 where
314 F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
315 {
316 let ast = match order {
317 TraverseOrder::TopDown => f(self)?,
318 TraverseOrder::BottomUp => self,
319 };
320 let pos = ast.pos;
321
322 let result = match &ast.node {
323 Node::Fun { args, body } => {
324 let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?;
325 let body = alloc.alloc((*body).clone().traverse(alloc, f, order)?);
326
327 Ast {
328 node: Node::Fun { args, body },
329 pos,
330 }
331 }
332 Node::Let {
333 bindings,
334 body,
335 rec,
336 } => {
337 let bindings = traverse_alloc_many(alloc, bindings.iter().cloned(), f, order)?;
338 let body = alloc.alloc((*body).clone().traverse(alloc, f, order)?);
339
340 Ast {
341 node: Node::Let {
342 bindings,
343 body,
344 rec: *rec,
345 },
346 pos,
347 }
348 }
349 Node::App { head, args } => {
350 let head = alloc.alloc((*head).clone().traverse(alloc, f, order)?);
351 let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?;
352
353 Ast {
354 node: Node::App { head, args },
355 pos,
356 }
357 }
358 Node::Match(data) => {
359 let branches = traverse_alloc_many(alloc, data.branches.iter().cloned(), f, order)?;
360
361 Ast {
362 node: Node::Match(Match { branches }),
363 pos,
364 }
365 }
366 Node::PrimOpApp { op, args } => {
367 let args = traverse_alloc_many(alloc, args.iter().cloned(), f, order)?;
368
369 Ast {
370 node: Node::PrimOpApp { op, args },
371 pos,
372 }
373 }
374 Node::Record(record) => {
375 let field_defs =
376 traverse_alloc_many(alloc, record.field_defs.iter().cloned(), f, order)?;
377
378 Ast {
379 node: Node::Record(alloc.alloc(Record {
380 field_defs,
381 includes: record.includes,
382 open: record.open,
383 })),
384 pos,
385 }
386 }
387 Node::Array(elts) => {
388 let elts = traverse_alloc_many(alloc, elts.iter().cloned(), f, order)?;
389
390 Ast {
391 node: Node::Array(elts),
392 pos,
393 }
394 }
395 Node::StringChunks(chunks) => {
396 let chunks_res: Result<Vec<StringChunk<Ast<'ast>>>, E> = chunks
397 .iter()
398 .cloned()
399 .map(|chunk| match chunk {
400 chunk @ StringChunk::Literal(_) => Ok(chunk),
401 StringChunk::Expr(ast, indent) => {
402 Ok(StringChunk::Expr(ast.traverse(alloc, f, order)?, indent))
403 }
404 })
405 .collect();
406
407 Ast {
408 node: Node::StringChunks(alloc.alloc_many(chunks_res?)),
409 pos,
410 }
411 }
412 Node::Annotated { annot, inner } => {
413 let annot = alloc.alloc((*annot).clone().traverse(alloc, f, order)?);
414 let inner = alloc.alloc((*inner).clone().traverse(alloc, f, order)?);
415
416 Ast {
417 node: Node::Annotated { annot, inner },
418 pos,
419 }
420 }
421 Node::Type(typ) => {
422 let typ = alloc.alloc((*typ).clone().traverse(alloc, f, order)?);
423
424 Ast {
425 node: Node::Type(typ),
426 pos,
427 }
428 }
429 _ => ast,
430 };
431
432 match order {
433 TraverseOrder::TopDown => Ok(result),
434 TraverseOrder::BottomUp => f(result),
435 }
436 }
437
438 fn traverse_ref<S, U>(
439 &'ast self,
440 f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
441 state: &S,
442 ) -> Option<U> {
443 let child_state = match f(self, state) {
444 TraverseControl::Continue => None,
445 TraverseControl::ContinueWithScope(s) => Some(s),
446 TraverseControl::SkipBranch => {
447 return None;
448 }
449 TraverseControl::Return(ret) => {
450 return Some(ret);
451 }
452 };
453 let state = child_state.as_ref().unwrap_or(state);
454
455 match self.node {
456 Node::Null
457 | Node::Bool(_)
458 | Node::Number(_)
459 | Node::String(_)
460 | Node::Var(_)
461 | Node::Import(_)
462 | Node::ParseError(_) => None,
463 Node::IfThenElse {
464 cond,
465 then_branch,
466 else_branch,
467 } => cond
468 .traverse_ref(f, state)
469 .or_else(|| then_branch.traverse_ref(f, state))
470 .or_else(|| else_branch.traverse_ref(f, state)),
471 Node::EnumVariant { tag: _, arg } => arg?.traverse_ref(f, state),
472 Node::StringChunks(chunks) => chunks.iter().find_map(|chk| {
473 if let StringChunk::Expr(term, _) = chk {
474 term.traverse_ref(f, state)
475 } else {
476 None
477 }
478 }),
479 Node::Fun { args, body } => args
480 .iter()
481 .find_map(|arg| arg.traverse_ref(f, state))
482 .or_else(|| body.traverse_ref(f, state)),
483 Node::PrimOpApp { op: _, args } => {
484 args.iter().find_map(|arg| arg.traverse_ref(f, state))
485 }
486 Node::Let {
487 bindings,
488 body,
489 rec: _,
490 } => bindings
491 .iter()
492 .find_map(|binding| binding.traverse_ref(f, state))
493 .or_else(|| body.traverse_ref(f, state)),
494 Node::App { head, args } => head
495 .traverse_ref(f, state)
496 .or_else(|| args.iter().find_map(|arg| arg.traverse_ref(f, state))),
497 Node::Record(data) => data
498 .field_defs
499 .iter()
500 .find_map(|field_def| field_def.traverse_ref(f, state)),
501 Node::Match(data) => data.branches.iter().find_map(
502 |MatchBranch {
503 pattern,
504 guard,
505 body,
506 }| {
507 pattern
508 .traverse_ref(f, state)
509 .or_else(|| {
510 if let Some(cond) = guard.as_ref() {
511 cond.traverse_ref(f, state)
512 } else {
513 None
514 }
515 })
516 .or_else(|| body.traverse_ref(f, state))
517 },
518 ),
519 Node::Array(elts) => elts.iter().find_map(|t| t.traverse_ref(f, state)),
520 Node::Annotated { annot, inner } => annot
521 .traverse_ref(f, state)
522 .or_else(|| inner.traverse_ref(f, state)),
523 Node::Type(typ) => typ.traverse_ref(f, state),
524 }
525 }
526}
527
528impl<'ast> TraverseAlloc<'ast, Type<'ast>> for Ast<'ast> {
529 fn traverse<F, E>(
530 self,
531 alloc: &'ast AstAlloc,
532 f: &mut F,
533 order: TraverseOrder,
534 ) -> Result<Ast<'ast>, E>
535 where
536 F: FnMut(Type<'ast>) -> Result<Type<'ast>, E>,
537 {
538 self.traverse(
539 alloc,
540 &mut |ast: Ast<'ast>| match &ast.node {
541 Node::Type(typ) => {
542 let typ = alloc.alloc((*typ).clone().traverse(alloc, f, order)?);
543 Ok(Ast {
544 node: Node::Type(typ),
545 pos: ast.pos,
546 })
547 }
548 _ => Ok(ast),
549 },
550 order,
551 )
552 }
553
554 fn traverse_ref<S, U>(
555 &'ast self,
556 f: &mut dyn FnMut(&'ast Type<'ast>, &S) -> TraverseControl<S, U>,
557 state: &S,
558 ) -> Option<U> {
559 self.traverse_ref(
560 &mut |ast: &'ast Ast<'ast>, state: &S| match &ast.node {
561 Node::Type(typ) => typ.traverse_ref(f, state).into(),
562 _ => TraverseControl::Continue,
563 },
564 state,
565 )
566 }
567}
568
569impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for Annotation<'ast> {
570 fn traverse<F, E>(
571 self,
572 alloc: &'ast AstAlloc,
573 f: &mut F,
574 order: TraverseOrder,
575 ) -> Result<Self, E>
576 where
577 F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
578 {
579 let typ = self
580 .typ
581 .map(|typ| typ.traverse(alloc, f, order))
582 .transpose()?;
583 let contracts = traverse_alloc_many(alloc, self.contracts.iter().cloned(), f, order)?;
584
585 Ok(Annotation { typ, contracts })
586 }
587
588 fn traverse_ref<S, U>(
589 &'ast self,
590 f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
591 scope: &S,
592 ) -> Option<U> {
593 self.typ
594 .iter()
595 .chain(self.contracts.iter())
596 .find_map(|c| c.traverse_ref(f, scope))
597 }
598}
599
600impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for LetBinding<'ast> {
601 fn traverse<F, E>(
602 self,
603 alloc: &'ast AstAlloc,
604 f: &mut F,
605 order: TraverseOrder,
606 ) -> Result<Self, E>
607 where
608 F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
609 {
610 let pattern = self.pattern.traverse(alloc, f, order)?;
611
612 let metadata = LetMetadata {
613 annotation: self.metadata.annotation.traverse(alloc, f, order)?,
614 doc: self.metadata.doc,
615 };
616
617 let value = self.value.traverse(alloc, f, order)?;
618
619 Ok(LetBinding {
620 pattern,
621 metadata,
622 value,
623 })
624 }
625
626 fn traverse_ref<S, U>(
627 &'ast self,
628 f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
629 scope: &S,
630 ) -> Option<U> {
631 self.metadata
632 .annotation
633 .traverse_ref(f, scope)
634 .or_else(|| self.value.traverse_ref(f, scope))
635 }
636}
637
638impl<'ast> TraverseAlloc<'ast, Ast<'ast>> for MatchBranch<'ast> {
639 fn traverse<F, E>(
640 self,
641 alloc: &'ast AstAlloc,
642 f: &mut F,
643 order: TraverseOrder,
644 ) -> Result<Self, E>
645 where
646 F: FnMut(Ast<'ast>) -> Result<Ast<'ast>, E>,
647 {
648 let pattern = self.pattern.traverse(alloc, f, order)?;
649 let body = self.body.traverse(alloc, f, order)?;
650 let guard = self
651 .guard
652 .map(|guard| guard.traverse(alloc, f, order))
653 .transpose()?;
654
655 Ok(MatchBranch {
656 pattern,
657 guard,
658 body,
659 })
660 }
661
662 fn traverse_ref<S, U>(
663 &'ast self,
664 f: &mut dyn FnMut(&'ast Ast<'ast>, &S) -> TraverseControl<S, U>,
665 scope: &S,
666 ) -> Option<U> {
667 self.pattern
668 .traverse_ref(f, scope)
669 .or_else(|| self.body.traverse_ref(f, scope))
670 .or_else(|| {
671 self.guard
672 .as_ref()
673 .and_then(|guard| guard.traverse_ref(f, scope))
674 })
675 }
676}
677
678impl<'ast> From<Node<'ast>> for Ast<'ast> {
679 fn from(node: Node<'ast>) -> Self {
680 Ast {
681 node,
682 pos: TermPos::None,
683 }
684 }
685}
686
687pub(crate) trait TryConvert<'ast, T>
694where
695 Self: Sized,
696{
697 type Error;
698
699 fn try_convert(alloc: &'ast AstAlloc, from: T) -> Result<Self, Self::Error>;
700}
701
// Display support for `Node` and `Ast` via `impl_display_from_bytecode_pretty!`. Presumably this
// implements `std::fmt::Display` through the bytecode AST pretty-printer — confirm in the macro's
// definition.
impl_display_from_bytecode_pretty!(Node<'_>);
impl_display_from_bytecode_pretty!(Ast<'_>);