Skip to main content

numbat/
typed_ast.rs

1use std::sync::Arc;
2
3use compact_str::{CompactString, ToCompactString, format_compact};
4use indexmap::IndexMap;
5use itertools::Itertools;
6
7use crate::arithmetic::Exponent;
8pub use crate::ast::{BinaryOperator, TypeExpression, UnaryOperator};
9use crate::ast::{ProcedureKind, TypeAnnotation, TypeParameterBound};
10use crate::dimension::DimensionRegistry;
11use crate::pretty_print::escape_numbat_string;
12use crate::traversal::{ForAllExpressions, ForAllTypeSchemes};
13use crate::type_variable::TypeVariable;
14use crate::typechecker::TypeCheckError;
15use crate::typechecker::qualified_type::QualifiedType;
16use crate::typechecker::type_scheme::TypeScheme;
17use crate::{BaseRepresentation, BaseRepresentationFactor, markup as m};
18use crate::{
19    decorator::Decorator, markup::Markup, number::Number, prefix::Prefix,
20    prefix_parser::AcceptsPrefix, pretty_print::PrettyPrint, span::Span,
21};
22use num_traits::{CheckedAdd, CheckedMul};
23
/// One factor of a dimension type ([`DType`]).
///
/// A factor is either a type variable, a user-written type parameter, or a named
/// base dimension; a `DType` pairs each factor with an exponent.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DTypeFactor {
    /// A type variable (named, or quantified inside a type scheme).
    TVar(TypeVariable),
    /// A user-written type parameter, identified by its name.
    TPar(CompactString),
    /// A base dimension, identified by its name.
    BaseDimension(CompactString),
}
31
32impl DTypeFactor {
33    pub fn name(&self) -> &str {
34        match self {
35            DTypeFactor::TVar(TypeVariable::Named(name)) => name,
36            DTypeFactor::TVar(TypeVariable::Quantified(_)) => unreachable!(),
37            DTypeFactor::TPar(name) => name,
38            DTypeFactor::BaseDimension(name) => name,
39        }
40    }
41}
42
/// A dimension-type factor together with its exponent.
type DtypeFactorPower = (DTypeFactor, Exponent);

/// A dimension type, represented as a product of factors raised to exponents.
///
/// The factor list is kept in canonical form by the constructors: sorted, with
/// equal factors merged (exponents added) and zero-exponent factors removed. The
/// derived `PartialEq` compares the factor lists directly and therefore relies on
/// this canonical form.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DType {
    // Always in canonical form. Shared via `Arc` so clones are cheap; mutation goes
    // through `Arc::make_mut`.
    factors: Arc<Vec<DtypeFactorPower>>,
}
50
51impl DType {
52    pub fn factors(&self) -> &[DtypeFactorPower] {
53        &self.factors
54    }
55
56    pub fn into_factors(self) -> Arc<Vec<DtypeFactorPower>> {
57        self.factors
58    }
59
60    pub fn from_factors(factors: Arc<Vec<DtypeFactorPower>>) -> DType {
61        let mut dtype = DType { factors };
62        dtype.canonicalize();
63        dtype
64    }
65
66    pub fn scalar() -> DType {
67        DType::from_factors(Arc::new(vec![]))
68    }
69
70    pub fn is_scalar(&self) -> bool {
71        self == &Self::scalar()
72    }
73
74    pub fn to_readable_type(&self, registry: &DimensionRegistry) -> m::Markup {
75        if self.is_scalar() {
76            return m::type_identifier("Scalar");
77        }
78
79        let mut names = vec![];
80
81        if self.factors.len() == 1 && self.factors[0].1 == Exponent::from_integer(1) {
82            names.push(self.factors[0].0.name().to_compact_string());
83        }
84
85        let base_representation = self.to_base_representation();
86        names.extend(registry.get_derived_entry_names_for(&base_representation));
87        match &names[..] {
88            [] => self.pretty_print(),
89            [single] => m::type_identifier(single.to_compact_string()),
90            multiple => Itertools::intersperse(
91                multiple.iter().cloned().map(m::type_identifier),
92                m::dimmed(" or "),
93            )
94            .sum(),
95        }
96    }
97
98    /// Is the current dimension type the Time dimension?
99    ///
100    /// This is special helper that's useful when dealing with DateTimes
101    pub fn is_time_dimension(&self) -> bool {
102        *self == DType::base_dimension("Time")
103    }
104
105    pub fn from_type_variable(v: TypeVariable) -> DType {
106        DType::from_factors(Arc::new(vec![(
107            DTypeFactor::TVar(v),
108            Exponent::from_integer(1),
109        )]))
110    }
111
112    pub fn from_type_parameter(name: CompactString) -> DType {
113        DType::from_factors(Arc::new(vec![(
114            DTypeFactor::TPar(name),
115            Exponent::from_integer(1),
116        )]))
117    }
118
119    pub fn deconstruct_as_single_type_variable(&self) -> Option<TypeVariable> {
120        match &self.factors[..] {
121            [(DTypeFactor::TVar(v), exponent)] if exponent == &Exponent::from_integer(1) => {
122                Some(v.clone())
123            }
124            _ => None,
125        }
126    }
127
128    pub fn from_tgen(i: usize) -> DType {
129        DType::from_factors(Arc::new(vec![(
130            DTypeFactor::TVar(TypeVariable::Quantified(i)),
131            Exponent::from_integer(1),
132        )]))
133    }
134
135    pub fn base_dimension(name: &str) -> DType {
136        DType::from_factors(Arc::new(vec![(
137            DTypeFactor::BaseDimension(name.into()),
138            Exponent::from_integer(1),
139        )]))
140    }
141
142    fn canonicalize(&mut self) {
143        self.try_canonicalize()
144            .expect("overflow in dimension type exponent computation");
145    }
146
147    /// Canonicalize with overflow checking. Returns `None` if an overflow occurs.
148    fn try_canonicalize(&mut self) -> Option<()> {
149        // Move all type-variable and tgen factors to the front, sort by name
150        Arc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) {
151            (DTypeFactor::TVar(v1), DTypeFactor::TVar(v2)) => v1.cmp(v2),
152            (DTypeFactor::TVar(_), _) => std::cmp::Ordering::Less,
153
154            (DTypeFactor::BaseDimension(d1), DTypeFactor::BaseDimension(d2)) => d1.cmp(d2),
155            (DTypeFactor::BaseDimension(_), DTypeFactor::TVar(_)) => std::cmp::Ordering::Greater,
156            (DTypeFactor::BaseDimension(_), DTypeFactor::TPar(_)) => std::cmp::Ordering::Less,
157
158            (DTypeFactor::TPar(p1), DTypeFactor::TPar(p2)) => p1.cmp(p2),
159            (DTypeFactor::TPar(_), _) => std::cmp::Ordering::Greater,
160        });
161
162        // Merge powers of equal factors:
163        let mut new_factors: Vec<DtypeFactorPower> = Vec::new();
164        for (f, n) in self.factors.iter() {
165            if let Some((last_f, last_n)) = new_factors.last_mut()
166                && f == last_f
167            {
168                *last_n = last_n.checked_add(n)?;
169                continue;
170            }
171            new_factors.push((f.clone(), *n));
172        }
173
174        // Remove factors with zero exponent:
175        new_factors.retain(|(_, n)| *n != Exponent::from_integer(0));
176
177        self.factors = Arc::new(new_factors);
178        Some(())
179    }
180
181    /// Like `from_factors`, but returns `None` if an overflow occurs during canonicalization.
182    pub fn try_from_factors(factors: Arc<Vec<DtypeFactorPower>>) -> Option<DType> {
183        let mut dtype = DType { factors };
184        dtype.try_canonicalize()?;
185        Some(dtype)
186    }
187
188    pub fn multiply(&self, other: &DType) -> DType {
189        let mut factors = self.factors.clone();
190        Arc::make_mut(&mut factors).extend(other.factors.iter().cloned());
191        DType::from_factors(factors)
192    }
193
194    /// Like `multiply`, but returns `None` if an overflow occurs.
195    pub fn try_multiply(&self, other: &DType) -> Option<DType> {
196        let mut factors = self.factors.clone();
197        Arc::make_mut(&mut factors).extend(other.factors.iter().cloned());
198        DType::try_from_factors(factors)
199    }
200
201    pub fn power(&self, n: Exponent) -> DType {
202        let factors = self
203            .factors
204            .iter()
205            .map(|(f, m)| (f.clone(), n * m))
206            .collect();
207        DType::from_factors(Arc::new(factors))
208    }
209
210    /// Like `power`, but returns `None` if the exponent computation overflows.
211    pub fn try_power(&self, n: Exponent) -> Option<DType> {
212        let factors: Option<Vec<_>> = self
213            .factors
214            .iter()
215            .map(|(f, m)| n.checked_mul(m).map(|exp| (f.clone(), exp)))
216            .collect();
217        factors.and_then(|f| DType::try_from_factors(Arc::new(f)))
218    }
219
220    pub fn inverse(&self) -> DType {
221        self.power(-Exponent::from_integer(1))
222    }
223
224    pub fn divide(&self, other: &DType) -> DType {
225        self.multiply(&other.inverse())
226    }
227
228    /// Like `divide`, but returns `None` if an overflow occurs.
229    pub fn try_divide(&self, other: &DType) -> Option<DType> {
230        self.try_multiply(&other.inverse())
231    }
232
233    pub fn type_variables(&self, including_type_parameters: bool) -> Vec<TypeVariable> {
234        let mut vars: Vec<_> = self
235            .factors
236            .iter()
237            .filter_map(|(f, _)| match f {
238                DTypeFactor::TVar(v) => Some(v.clone()),
239                DTypeFactor::TPar(v) => {
240                    if including_type_parameters {
241                        Some(TypeVariable::new(v))
242                    } else {
243                        None
244                    }
245                }
246                DTypeFactor::BaseDimension(_) => None,
247            })
248            .collect();
249        vars.sort();
250        vars.dedup();
251        vars
252    }
253
254    pub fn contains(&self, name: &TypeVariable, including_type_parameters: bool) -> bool {
255        self.type_variables(including_type_parameters)
256            .contains(name)
257    }
258
259    pub fn split_first_factor(&self) -> Option<(&DtypeFactorPower, &[DtypeFactorPower])> {
260        self.factors.split_first()
261    }
262
263    fn instantiate(&self, type_variables: &[TypeVariable]) -> DType {
264        let mut factors = Vec::new();
265
266        for (f, n) in self.factors.iter() {
267            match f {
268                DTypeFactor::TVar(TypeVariable::Quantified(i)) => {
269                    factors.push((DTypeFactor::TVar(type_variables[*i].clone()), *n));
270                }
271                _ => {
272                    factors.push((f.clone(), *n));
273                }
274            }
275        }
276        Self::from_factors(Arc::new(factors))
277    }
278
279    pub fn to_base_representation(&self) -> BaseRepresentation {
280        let mut factors = vec![];
281        for (f, n) in self.factors.iter() {
282            match f {
283                DTypeFactor::BaseDimension(name) => {
284                    factors.push(BaseRepresentationFactor(name.clone(), *n));
285                }
286                DTypeFactor::TVar(TypeVariable::Named(name)) => {
287                    factors.push(BaseRepresentationFactor(name.clone(), *n));
288                }
289                DTypeFactor::TVar(TypeVariable::Quantified(id)) => {
290                    // Quantified type variables can appear during constraint solving
291                    // before they're fully resolved
292                    factors.push(BaseRepresentationFactor(format!("?{id}").into(), *n));
293                }
294                DTypeFactor::TPar(name) => {
295                    factors.push(BaseRepresentationFactor(name.clone(), *n));
296                }
297            }
298        }
299        BaseRepresentation::from_factors(factors)
300    }
301}
302
303impl PrettyPrint for DType {
304    fn pretty_print(&self) -> Markup {
305        self.to_base_representation().pretty_print()
306    }
307}
308
309impl std::fmt::Display for DType {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        write!(f, "{}", self.pretty_print())
312    }
313}
314
315impl From<BaseRepresentation> for DType {
316    fn from(base_representation: BaseRepresentation) -> Self {
317        let factors = base_representation
318            .into_iter()
319            .map(|BaseRepresentationFactor(name, exp)| (DTypeFactor::BaseDimension(name), exp))
320            .collect();
321        DType::from_factors(Arc::new(factors))
322    }
323}
324
/// Represents whether a struct is a generic definition or a concrete instance
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructKind {
    /// Generic struct definition with type parameters (e.g., `struct Vec<D: Dim>`).
    ///
    /// Each entry holds a span, the parameter name and its optional bound
    /// (NOTE(review): presumably the span of the parameter declaration — confirm).
    Definition(Vec<(Span, CompactString, Option<TypeParameterBound>)>),
    /// Instantiated struct with concrete type arguments (e.g., `Vec<Length>`)
    Instance(Vec<Type>),
}
333
/// Type information about a struct: its definition site, name, kind (generic
/// definition or instance) and fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructInfo {
    /// Span of the struct definition.
    pub definition_span: Span,
    /// Name of the struct.
    pub name: CompactString,
    /// Whether this is the generic definition or an instantiation with type arguments.
    pub kind: StructKind,
    /// Fields keyed by name, each with a span and its type. `IndexMap` preserves
    /// insertion order, which is the order used when the type is displayed.
    pub fields: IndexMap<CompactString, (Span, Type)>,
}
341
/// A monomorphic type (no quantifiers).
///
/// - `TVar`: Unification variable, to be solved during type inference. Example: when
///   type-checking `fn f(x) = x + 1`, the parameter `x` gets a fresh `TVar(Named("T0"))`.
/// - `TPar`: User-written type parameter in a polymorphic definition. Example: in
///   `fn square<D: Dim>(x: D) -> D²`, the `D` in type annotations becomes `TPar("D")`.
///
/// During generalization, both `TVar` and `TPar` become `TVar(Quantified(i))` in a `TypeScheme`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    /// A type variable (see above).
    TVar(TypeVariable),
    /// A user-written type parameter, by name (see above).
    TPar(CompactString),
    /// A dimension type, e.g. `Length` or `Scalar`.
    Dimension(DType),
    /// The boolean type, displayed as `Bool`.
    Boolean,
    /// The string type.
    String,
    /// The date/time type.
    DateTime,
    /// A function type: parameter types and return type.
    Fn(Vec<Type>, Box<Type>),
    /// A struct type (generic definition or instance).
    Struct(Box<StructInfo>),
    /// A list whose elements all have the given type.
    List(Box<Type>),
}
362
363impl std::fmt::Display for Type {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        match self {
366            Type::TVar(TypeVariable::Named(name)) => write!(f, "{name}"),
367            Type::TVar(TypeVariable::Quantified(_)) => {
368                unreachable!("Quantified types should not be printed")
369            }
370            Type::TPar(name) => write!(f, "{name}"),
371            Type::Dimension(d) => d.fmt(f),
372            Type::Boolean => write!(f, "Bool"),
373            Type::String => write!(f, "String"),
374            Type::DateTime => write!(f, "DateTime"),
375            Type::Fn(param_types, return_type) => {
376                write!(
377                    f,
378                    "Fn[({ps}) -> {return_type}]",
379                    ps = param_types.iter().map(|p| p.to_string()).join(", ")
380                )
381            }
382            Type::Struct(info) => {
383                write!(f, "{}", info.name)?;
384                if let StructKind::Instance(type_args) = &info.kind
385                    && !type_args.is_empty()
386                {
387                    write!(
388                        f,
389                        "<{}>",
390                        type_args.iter().map(|t| t.to_string()).join(", ")
391                    )?;
392                }
393                write!(
394                    f,
395                    " {{{}}}",
396                    info.fields
397                        .iter()
398                        .map(|(n, (_, t))| n.to_string() + ": " + &t.to_string())
399                        .join(", ")
400                )
401            }
402            Type::List(element_type) => write!(f, "List<{element_type}>"),
403        }
404    }
405}
406
407impl PrettyPrint for Type {
408    fn pretty_print(&self) -> Markup {
409        match self {
410            Type::TVar(TypeVariable::Named(name)) => m::type_identifier(name.to_compact_string()),
411            Type::TVar(TypeVariable::Quantified(_)) => {
412                unreachable!("Quantified types should not be printed")
413            }
414            Type::TPar(name) => m::type_identifier(name.clone()),
415            Type::Dimension(d) => d.pretty_print(),
416            Type::Boolean => m::type_identifier("Bool"),
417            Type::String => m::type_identifier("String"),
418            Type::DateTime => m::type_identifier("DateTime"),
419            Type::Fn(param_types, return_type) => {
420                m::type_identifier("Fn")
421                    + m::operator("[(")
422                    + Itertools::intersperse(
423                        param_types.iter().map(|t| t.pretty_print()),
424                        m::operator(",") + m::space(),
425                    )
426                    .sum()
427                    + m::operator(")")
428                    + m::space()
429                    + m::operator("->")
430                    + m::space()
431                    + return_type.pretty_print()
432                    + m::operator("]")
433            }
434            Type::Struct(info) => {
435                let mut markup = m::type_identifier(info.name.clone());
436                if let StructKind::Instance(type_args) = &info.kind
437                    && !type_args.is_empty()
438                {
439                    markup += m::operator("<");
440                    markup += Itertools::intersperse(
441                        type_args.iter().map(|t| t.pretty_print()),
442                        m::operator(",") + m::space(),
443                    )
444                    .sum();
445                    markup += m::operator(">");
446                }
447                markup
448            }
449            Type::List(element_type) => {
450                m::type_identifier("List")
451                    + m::operator("<")
452                    + element_type.pretty_print()
453                    + m::operator(">")
454            }
455        }
456    }
457}
458
459impl Type {
460    pub fn to_readable_type(&self, registry: &DimensionRegistry) -> Markup {
461        match self {
462            Type::Dimension(d) => d.to_readable_type(registry),
463            Type::Struct(info) => {
464                let mut markup = m::type_identifier(info.name.clone());
465                if let StructKind::Instance(type_args) = &info.kind
466                    && !type_args.is_empty()
467                {
468                    markup += m::operator("<");
469                    markup += Itertools::intersperse(
470                        type_args.iter().map(|t| t.to_readable_type(registry)),
471                        m::operator(",") + m::space(),
472                    )
473                    .sum();
474                    markup += m::operator(">");
475                }
476                markup
477            }
478            Type::List(element_type) => {
479                m::type_identifier("List")
480                    + m::operator("<")
481                    + element_type.to_readable_type(registry)
482                    + m::operator(">")
483            }
484            _ => self.pretty_print(),
485        }
486    }
487
488    pub fn scalar() -> Type {
489        Type::Dimension(DType::scalar())
490    }
491
492    pub fn is_dtype(&self) -> bool {
493        matches!(self, Type::Dimension(..))
494    }
495
496    pub fn is_fn_type(&self) -> bool {
497        matches!(self, Type::Fn(..))
498    }
499
500    /// Check if two types have incompatible type constructors (i.e., they can never be unified).
501    /// For example, a Dimension type can never equal a Fn type, even if one contains type variables.
502    /// This is used for early detection of type mismatches in the constraint solver.
503    pub(crate) fn has_incompatible_constructor(&self, other: &Type) -> bool {
504        use Type::*;
505
506        // A Dimension type containing only a type variable can be unified with any type
507        let is_dimension_with_only_tvar = |t: &Type| matches!(t, Dimension(d) if d.deconstruct_as_single_type_variable().is_some());
508
509        match (self, other) {
510            // Type variables and type parameters can potentially match anything
511            (TVar(_), _) | (_, TVar(_)) | (TPar(_), _) | (_, TPar(_)) => false,
512
513            // A Dimension containing only a type variable can match anything
514            (t1, t2) if is_dimension_with_only_tvar(t1) || is_dimension_with_only_tvar(t2) => false,
515
516            // Same constructors might be compatible (need further unification)
517            (Dimension(_), Dimension(_))
518            | (Boolean, Boolean)
519            | (String, String)
520            | (DateTime, DateTime)
521            | (Fn(_, _), Fn(_, _))
522            | (Struct(_), Struct(_))
523            | (List(_), List(_)) => false,
524
525            // Different concrete constructors are incompatible
526            _ => true,
527        }
528    }
529
530    pub(crate) fn type_variables(&self, including_type_parameters: bool) -> Vec<TypeVariable> {
531        match self {
532            Type::TVar(v) => vec![v.clone()],
533            Type::TPar(n) => {
534                if including_type_parameters {
535                    vec![TypeVariable::new(n)]
536                } else {
537                    vec![]
538                }
539            }
540            Type::Dimension(d) => d.type_variables(including_type_parameters),
541            Type::Boolean | Type::String | Type::DateTime => vec![],
542            Type::Fn(param_types, return_type) => {
543                let mut vars = return_type.type_variables(including_type_parameters);
544                for param_type in param_types {
545                    vars.extend(param_type.type_variables(including_type_parameters));
546                }
547                vars.sort();
548                vars.dedup();
549                vars
550            }
551            Type::Struct(info) => {
552                let mut vars = vec![];
553                for (_, (_, t)) in &info.fields {
554                    vars.extend(t.type_variables(including_type_parameters));
555                }
556                vars
557            }
558            Type::List(element_type) => element_type.type_variables(including_type_parameters),
559        }
560    }
561
562    pub(crate) fn contains(&self, x: &TypeVariable, including_type_parameters: bool) -> bool {
563        self.type_variables(including_type_parameters).contains(x)
564    }
565
566    /// A type is called 'closed' if it does not change under substitutions (contains no unification variables)
567    pub(crate) fn is_closed(&self) -> bool {
568        self.type_variables(false).is_empty()
569    }
570
571    pub(crate) fn instantiate(&self, type_variables: &[TypeVariable]) -> Type {
572        match self {
573            Type::TVar(TypeVariable::Quantified(i)) => Type::TVar(type_variables[*i].clone()),
574            Type::TVar(v) => Type::TVar(v.clone()),
575            Type::TPar(n) => Type::TPar(n.clone()),
576            Type::Dimension(d) => Type::Dimension(d.instantiate(type_variables)),
577            Type::Boolean | Type::String | Type::DateTime => self.clone(),
578            Type::Fn(param_types, return_type) => Type::Fn(
579                param_types
580                    .iter()
581                    .map(|t| t.instantiate(type_variables))
582                    .collect(),
583                Box::new(return_type.instantiate(type_variables)),
584            ),
585            Type::Struct(info) => {
586                let instantiated_fields = info
587                    .fields
588                    .iter()
589                    .map(|(name, (span, field_type))| {
590                        (
591                            name.clone(),
592                            (*span, field_type.instantiate(type_variables)),
593                        )
594                    })
595                    .collect();
596                let instantiated_kind = match &info.kind {
597                    StructKind::Definition(params) => StructKind::Definition(params.clone()),
598                    StructKind::Instance(type_args) => StructKind::Instance(
599                        type_args
600                            .iter()
601                            .map(|t| t.instantiate(type_variables))
602                            .collect(),
603                    ),
604                };
605                Type::Struct(Box::new(StructInfo {
606                    definition_span: info.definition_span,
607                    name: info.name.clone(),
608                    kind: instantiated_kind,
609                    fields: instantiated_fields,
610                }))
611            }
612            Type::List(element_type) => {
613                Type::List(Box::new(element_type.instantiate(type_variables)))
614            }
615        }
616    }
617
618    pub(crate) fn is_scalar(&self) -> bool {
619        match self {
620            Type::Dimension(d) => d.is_scalar(),
621            _ => false,
622        }
623    }
624}
625
/// One piece of a string expression: fixed text or an interpolated expression.
#[derive(Debug, Clone, PartialEq)]
pub enum StringPart<'a> {
    /// Literal text of the string.
    Fixed(CompactString),
    /// An interpolated expression (`{expr}`), optionally followed by format specifiers
    /// that are kept as raw source text.
    Interpolation {
        span: Span,
        expr: Box<Expression<'a>>,
        format_specifiers: Option<&'a str>,
    },
}
635
636impl PrettyPrint for StringPart<'_> {
637    fn pretty_print(&self) -> Markup {
638        match self {
639            StringPart::Fixed(s) => m::string(escape_numbat_string(s)),
640            StringPart::Interpolation {
641                span: _,
642                expr,
643                format_specifiers,
644            } => {
645                let mut markup = m::operator("{") + expr.pretty_print();
646
647                if let Some(format_specifiers) = format_specifiers {
648                    markup += m::text(format_specifiers.to_compact_string());
649                }
650
651                markup += m::operator("}");
652
653                markup
654            }
655        }
656    }
657}
658
659impl PrettyPrint for &Vec<StringPart<'_>> {
660    fn pretty_print(&self) -> Markup {
661        m::operator("\"") + self.iter().map(|p| p.pretty_print()).sum() + m::operator("\"")
662    }
663}
664
/// An expression of the typed AST. Most variants carry the type scheme inferred
/// for them.
#[derive(Debug, Clone, PartialEq)]
pub enum Expression<'a> {
    /// A number value.
    Scalar {
        span: Span,
        value: Number,
        type_scheme: TypeScheme,
    },
    /// A reference to a named identifier.
    Identifier {
        span: Span,
        name: &'a str,
        type_scheme: TypeScheme,
    },
    /// A unit with its prefix. `full_name` is stored alongside `name`
    /// (NOTE(review): presumably the unit's canonical/full name — confirm).
    UnitIdentifier {
        span: Span,
        prefix: Prefix,
        name: CompactString,
        full_name: CompactString,
        type_scheme: TypeScheme,
    },
    /// A unary operator applied to `expr`.
    UnaryOperator {
        span: Span,
        op: UnaryOperator,
        expr: Box<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A binary operator applied to `lhs` and `rhs`. `op_span` is optional
    /// (NOTE(review): presumably `None` for operators with no source token — confirm).
    BinaryOperator {
        op_span: Option<Span>,
        op: BinaryOperator,
        lhs: Box<Expression<'a>>,
        rhs: Box<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A special binary operator that has a DateTime as one (or both) of the operands
    BinaryOperatorForDate {
        op_span: Option<Span>,
        op: BinaryOperator,
        /// LHS must evaluate to a DateTime
        lhs: Box<Expression<'a>>,
        /// RHS can evaluate to a DateTime or a quantity of type Time
        rhs: Box<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A 'proper' function call
    FunctionCall {
        full_span: Span,
        ident_span: Span,
        name: &'a str,
        args: Vec<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A call via a function object
    CallableCall {
        full_span: Span,
        callable: Box<Expression<'a>>,
        args: Vec<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A boolean value.
    Boolean(Span, bool),
    /// A conditional expression with `then` and `else` branches.
    Condition {
        span: Span,
        condition: Box<Expression<'a>>,
        then_expr: Box<Expression<'a>>,
        else_expr: Box<Expression<'a>>,
    },
    /// A string made of fixed parts and interpolations.
    String(Span, Vec<StringPart<'a>>),
    /// Construction of a struct value from named field expressions.
    InstantiateStruct {
        span: Span,
        fields: Vec<(&'a str, Expression<'a>)>,
        struct_info: StructInfo,
    },
    /// Access of `field_name` on `expr`, carrying both the struct's and the field's type.
    AccessField {
        full_span: Span,
        ident_span: Span,
        expr: Box<Expression<'a>>,
        field_name: &'a str,
        struct_type: TypeScheme,
        field_type: TypeScheme,
    },
    /// A list of element expressions.
    List {
        span: Span,
        elements: Vec<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A typed hole, with its span and type scheme.
    TypedHole(Span, TypeScheme),
}
750
751impl Expression<'_> {
752    pub fn full_span(&self) -> Span {
753        match self {
754            Expression::Scalar { span, .. } => *span,
755            Expression::Identifier { span, .. } => *span,
756            Expression::UnitIdentifier { span, .. } => *span,
757            Expression::UnaryOperator { span, expr, .. } => span.extend(&expr.full_span()),
758            Expression::BinaryOperator {
759                op_span, lhs, rhs, ..
760            } => {
761                let mut span = lhs.full_span().extend(&rhs.full_span());
762                if let Some(op_span) = op_span {
763                    span = span.extend(op_span);
764                }
765                span
766            }
767            Expression::BinaryOperatorForDate {
768                op_span, lhs, rhs, ..
769            } => {
770                let mut span = lhs.full_span().extend(&rhs.full_span());
771                if let Some(op_span) = op_span {
772                    span = span.extend(op_span);
773                }
774                span
775            }
776            Expression::FunctionCall { full_span, .. } => *full_span,
777            Expression::CallableCall { full_span, .. } => *full_span,
778            Expression::Boolean(span, _) => *span,
779            Expression::Condition {
780                span, else_expr, ..
781            } => span.extend(&else_expr.full_span()),
782            Expression::String(span, _) => *span,
783            Expression::InstantiateStruct { span, .. } => *span,
784            Expression::AccessField { full_span, .. } => *full_span,
785            Expression::List { span, .. } => *span,
786            Expression::TypedHole(span, _) => *span,
787        }
788    }
789}
790
/// A type-checked variable definition.
#[derive(Debug, Clone, PartialEq)]
pub struct DefineVariable<'a> {
    /// Name of the variable.
    pub name: &'a str,
    /// Decorators attached to the definition.
    pub decorators: Vec<Decorator<'a>>,
    /// The defining expression.
    pub expr: Expression<'a>,
    /// The user-written type annotation, if any.
    pub type_annotation: Option<TypeAnnotation>,
    /// The inferred type scheme of the variable.
    pub type_scheme: TypeScheme,
    /// User-facing rendering of the variable's type (the annotation if present,
    /// otherwise derived from `type_scheme`).
    pub readable_type: Markup,
}
800
/// A statement of the typed AST.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement<'a> {
    /// A bare expression.
    Expression(Expression<'a>),
    /// A variable definition.
    DefineVariable(DefineVariable<'a>),
    /// A function definition. `body` is optional
    /// (NOTE(review): presumably `None` for functions implemented outside Numbat — confirm).
    DefineFunction {
        function_name: &'a str,
        decorators: Vec<Decorator<'a>>,
        type_parameters: Vec<(&'a str, Option<TypeParameterBound>)>,
        parameters: Vec<(
            Span,                   // span of the parameter
            &'a str,                // parameter name
            Option<TypeAnnotation>, // parameter type annotation
            Markup,                 // readable parameter type
        )>,
        body: Option<Expression<'a>>,
        local_variables: Vec<DefineVariable<'a>>,
        fn_type: TypeScheme,
        return_type_annotation: Option<TypeAnnotation>,
        readable_return_type: Markup,
    },
    /// A dimension definition: its name and the type expressions from its definition
    /// (NOTE(review): presumably empty for a new base dimension — confirm).
    DefineDimension(&'a str, Vec<TypeExpression>),
    /// A base unit definition.
    DefineBaseUnit {
        name: &'a str,
        identifier_span: Span,
        decorators: Vec<Decorator<'a>>,
        type_annotation: Option<TypeAnnotation>,
        type_scheme: TypeScheme,
    },
    /// A derived unit definition, given by an expression.
    DefineDerivedUnit {
        name: &'a str,
        identifier_span: Span,
        expr: Expression<'a>,
        decorators: Vec<Decorator<'a>>,
        type_annotation: Option<TypeAnnotation>,
        type_scheme: TypeScheme,
        readable_type: Markup,
    },
    /// A call of a built-in procedure of the given kind.
    ProcedureCall {
        kind: ProcedureKind,
        span: Span,
        args: Vec<Expression<'a>>,
    },
    /// A struct definition.
    DefineStruct(StructInfo),
}
845
846impl Statement<'_> {
847    pub fn as_expression(&self) -> Option<&Expression<'_>> {
848        if let Self::Expression(v) = self {
849            Some(v)
850        } else {
851            None
852        }
853    }
854
855    pub(crate) fn generalize_types(&mut self, dtype_variables: &[TypeVariable]) {
856        self.for_all_type_schemes(&mut |type_: &mut TypeScheme| type_.generalize(dtype_variables));
857    }
858
859    fn create_readable_type(
860        registry: &DimensionRegistry,
861        type_: &TypeScheme,
862        annotation: &Option<TypeAnnotation>,
863        with_quantifiers: bool,
864    ) -> Markup {
865        if let Some(annotation) = annotation {
866            annotation.pretty_print()
867        } else {
868            type_.to_readable_type(registry, with_quantifiers)
869        }
870    }
871
872    pub(crate) fn update_readable_types(&mut self, registry: &DimensionRegistry) {
873        match self {
874            Statement::Expression(_) => {}
875            Statement::DefineVariable(DefineVariable {
876                type_annotation,
877                type_scheme,
878                readable_type,
879                ..
880            }) => {
881                *readable_type =
882                    Self::create_readable_type(registry, type_scheme, type_annotation, true);
883            }
884            Statement::DefineFunction {
885                type_parameters,
886                parameters,
887                local_variables,
888                fn_type,
889                return_type_annotation,
890                readable_return_type,
891                ..
892            } => {
893                let (fn_type, _) =
894                    fn_type.instantiate_for_printing(Some(type_parameters.iter().map(|(n, _)| *n)));
895
896                for DefineVariable {
897                    type_annotation,
898                    type_scheme,
899                    readable_type,
900                    ..
901                } in local_variables
902                {
903                    *readable_type =
904                        Self::create_readable_type(registry, type_scheme, type_annotation, false);
905                }
906
907                let Type::Fn(parameter_types, return_type) = fn_type.inner else {
908                    unreachable!("Expected a function type")
909                };
910
911                *readable_return_type = Self::create_readable_type(
912                    registry,
913                    &TypeScheme::concrete(*return_type),
914                    return_type_annotation,
915                    false,
916                );
917
918                for ((_, _, type_annotation, readable_parameter_type), parameter_type) in
919                    parameters.iter_mut().zip(parameter_types.iter())
920                {
921                    *readable_parameter_type = Self::create_readable_type(
922                        registry,
923                        &TypeScheme::concrete(parameter_type.clone()),
924                        type_annotation,
925                        false,
926                    );
927                }
928            }
929            Statement::DefineDimension(_, _) => {}
930            Statement::DefineBaseUnit { .. } => {}
931            Statement::DefineDerivedUnit {
932                type_annotation,
933                type_scheme,
934                readable_type,
935                ..
936            } => {
937                *readable_type =
938                    Self::create_readable_type(registry, type_scheme, type_annotation, false);
939            }
940            Statement::ProcedureCall { .. } => {}
941            Statement::DefineStruct(_) => {}
942        }
943    }
944
945    pub(crate) fn exponents_for(&mut self, tv: &TypeVariable) -> Vec<Exponent> {
946        // TODO: things to not need to be mutable in this function
947        let mut exponents = vec![];
948        self.for_all_type_schemes(&mut |type_: &mut TypeScheme| {
949            if let Type::Dimension(dtype) = type_.unsafe_as_concrete() {
950                for (factor, exp) in dtype.factors.iter() {
951                    if factor == &DTypeFactor::TVar(tv.clone()) {
952                        exponents.push(*exp)
953                    }
954                }
955            }
956        });
957        exponents
958    }
959
960    pub(crate) fn find_typed_hole(
961        &self,
962    ) -> Result<Option<(Span, TypeScheme)>, Box<TypeCheckError>> {
963        let mut hole = None;
964        let mut found_multiple_holes = false;
965        self.for_all_expressions(&mut |expr| {
966            if let Expression::TypedHole(span, type_) = expr {
967                if hole.is_some() {
968                    found_multiple_holes = true;
969                }
970                hole = Some((*span, type_.clone()))
971            }
972        });
973
974        if found_multiple_holes {
975            Err(Box::new(TypeCheckError::MultipleTypedHoles(
976                hole.unwrap().0,
977            )))
978        } else {
979            Ok(hole)
980        }
981    }
982
983    /// Returns an iterator over local bindings (name, type) in the statement.
984    /// This includes local variables in `where` clauses and function parameters.
985    pub(crate) fn local_bindings(&self) -> Vec<(&str, TypeScheme)> {
986        match self {
987            Statement::DefineFunction {
988                parameters,
989                local_variables,
990                fn_type,
991                ..
992            } => {
993                let mut bindings = Vec::new();
994
995                if let TypeScheme::Concrete(Type::Fn(param_types, _))
996                | TypeScheme::Quantified(
997                    _,
998                    crate::typechecker::qualified_type::QualifiedType {
999                        inner: Type::Fn(param_types, _),
1000                        ..
1001                    },
1002                ) = fn_type
1003                {
1004                    for ((_, param_name, _, _), param_type) in
1005                        parameters.iter().zip(param_types.iter())
1006                    {
1007                        bindings
1008                            .push((*param_name, TypeScheme::make_quantified(param_type.clone())));
1009                    }
1010                }
1011
1012                for DefineVariable {
1013                    name, type_scheme, ..
1014                } in local_variables
1015                {
1016                    bindings.push((*name, type_scheme.clone()));
1017                }
1018
1019                bindings
1020            }
1021            _ => Vec::new(),
1022        }
1023    }
1024}
1025
1026impl Expression<'_> {
1027    pub fn get_type(&self) -> Type {
1028        match self {
1029            Expression::Scalar { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1030            Expression::Identifier { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1031            Expression::UnitIdentifier { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1032            Expression::UnaryOperator { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1033            Expression::BinaryOperator { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1034            Expression::BinaryOperatorForDate { type_scheme, .. } => {
1035                type_scheme.unsafe_as_concrete()
1036            }
1037            Expression::FunctionCall { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1038            Expression::CallableCall { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1039            Expression::Boolean(_, _) => Type::Boolean,
1040            Expression::Condition { then_expr, .. } => then_expr.get_type(),
1041            Expression::String(_, _) => Type::String,
1042            Expression::InstantiateStruct { struct_info, .. } => {
1043                Type::Struct(Box::new(struct_info.clone()))
1044            }
1045            Expression::AccessField { field_type, .. } => field_type.unsafe_as_concrete(),
1046            Expression::List { type_scheme, .. } => {
1047                Type::List(Box::new(type_scheme.unsafe_as_concrete()))
1048            }
1049            Expression::TypedHole(_, type_) => type_.unsafe_as_concrete(),
1050        }
1051    }
1052
1053    pub fn get_type_scheme(&self) -> TypeScheme {
1054        match self {
1055            Expression::Scalar { type_scheme, .. } => type_scheme.clone(),
1056            Expression::Identifier { type_scheme, .. } => type_scheme.clone(),
1057            Expression::UnitIdentifier { type_scheme, .. } => type_scheme.clone(),
1058            Expression::UnaryOperator { type_scheme, .. } => type_scheme.clone(),
1059            Expression::BinaryOperator { type_scheme, .. } => type_scheme.clone(),
1060            Expression::BinaryOperatorForDate { type_scheme, .. } => type_scheme.clone(),
1061            Expression::FunctionCall { type_scheme, .. } => type_scheme.clone(),
1062            Expression::CallableCall { type_scheme, .. } => type_scheme.clone(),
1063            Expression::Boolean(_, _) => TypeScheme::make_quantified(Type::Boolean),
1064            Expression::Condition { then_expr, .. } => then_expr.get_type_scheme(),
1065            Expression::String(_, _) => TypeScheme::make_quantified(Type::String),
1066            Expression::InstantiateStruct { struct_info, .. } => {
1067                TypeScheme::make_quantified(Type::Struct(Box::new(struct_info.clone())))
1068            }
1069            Expression::AccessField { field_type, .. } => field_type.clone(),
1070            Expression::List { type_scheme, .. } => match type_scheme {
1071                TypeScheme::Concrete(t) => TypeScheme::Concrete(Type::List(Box::new(t.clone()))),
1072                TypeScheme::Quantified(ngen, qt) => TypeScheme::Quantified(
1073                    *ngen,
1074                    crate::typechecker::qualified_type::QualifiedType {
1075                        inner: Type::List(Box::new(qt.inner.clone())),
1076                        bounds: qt.bounds.clone(),
1077                    },
1078                ),
1079            },
1080            Expression::TypedHole(_, type_) => type_.clone(),
1081        }
1082    }
1083}
1084
1085fn accepts_prefix_markup(accepts_prefix: &Option<AcceptsPrefix>) -> Markup {
1086    if let Some(accepts_prefix) = accepts_prefix {
1087        m::operator(":")
1088            + m::space()
1089            + match accepts_prefix {
1090                AcceptsPrefix {
1091                    short: true,
1092                    long: true,
1093                } => m::keyword("both"),
1094                AcceptsPrefix {
1095                    short: true,
1096                    long: false,
1097                } => m::keyword("short"),
1098                AcceptsPrefix {
1099                    short: false,
1100                    long: true,
1101                } => m::keyword("long"),
1102                AcceptsPrefix {
1103                    short: false,
1104                    long: false,
1105                } => m::keyword("none"),
1106            }
1107    } else {
1108        m::empty()
1109    }
1110}
1111
1112fn decorator_markup(decorators: &Vec<Decorator>) -> Markup {
1113    let mut markup_decorators = m::empty();
1114    for decorator in decorators {
1115        markup_decorators = markup_decorators
1116            + match decorator {
1117                Decorator::MetricPrefixes => m::decorator("@metric_prefixes"),
1118                Decorator::BinaryPrefixes => m::decorator("@binary_prefixes"),
1119                Decorator::Abbreviation => m::decorator("@abbreviation"),
1120                Decorator::Aliases(names) => {
1121                    m::decorator("@aliases")
1122                        + m::operator("(")
1123                        + Itertools::intersperse(
1124                            names.iter().map(|(name, accepts_prefix, _)| {
1125                                m::unit(name.to_compact_string())
1126                                    + accepts_prefix_markup(accepts_prefix)
1127                            }),
1128                            m::operator(", "),
1129                        )
1130                        .sum()
1131                        + m::operator(")")
1132                }
1133                Decorator::Url(url) => {
1134                    m::decorator("@url")
1135                        + m::operator("(")
1136                        + m::string(url.clone())
1137                        + m::operator(")")
1138                }
1139                Decorator::Name(name) => {
1140                    m::decorator("@name")
1141                        + m::operator("(")
1142                        + m::string(name.clone())
1143                        + m::operator(")")
1144                }
1145                Decorator::Description(description) => {
1146                    m::decorator("@description")
1147                        + m::operator("(")
1148                        + m::string(description.clone())
1149                        + m::operator(")")
1150                }
1151                Decorator::Example(example_code, example_description) => {
1152                    m::decorator("@example")
1153                        + m::operator("(")
1154                        + m::string(example_code.clone())
1155                        + if let Some(example_description) = example_description {
1156                            m::operator(", ") + m::string(example_description.clone())
1157                        } else {
1158                            m::empty()
1159                        }
1160                        + m::operator(")")
1161                }
1162            }
1163            + m::nl();
1164    }
1165    markup_decorators
1166}
1167
1168pub fn pretty_print_function_signature<'a>(
1169    function_name: &str,
1170    fn_type: &QualifiedType,
1171    type_parameters: &[TypeVariable],
1172    parameters: impl Iterator<
1173        Item = (
1174            &'a str, // parameter name
1175            Markup,  // readable parameter type
1176        ),
1177    >,
1178    readable_return_type: &Markup,
1179) -> Markup {
1180    let markup_type_parameters = if type_parameters.is_empty() {
1181        m::empty()
1182    } else {
1183        m::operator("<")
1184            + Itertools::intersperse(
1185                type_parameters.iter().map(|tv| {
1186                    m::type_identifier(tv.unsafe_name().to_compact_string())
1187                        + if fn_type.bounds.is_dtype_bound(tv) {
1188                            m::operator(":") + m::space() + m::type_identifier("Dim")
1189                        } else {
1190                            m::empty()
1191                        }
1192                }),
1193                m::operator(", "),
1194            )
1195            .sum()
1196            + m::operator(">")
1197    };
1198
1199    let markup_parameters = Itertools::intersperse(
1200        parameters.map(|(name, parameter_type)| {
1201            m::identifier(name.to_compact_string()) + m::operator(":") + m::space() + parameter_type
1202        }),
1203        m::operator(", "),
1204    )
1205    .sum();
1206
1207    let markup_return_type =
1208        m::space() + m::operator("->") + m::space() + readable_return_type.clone();
1209
1210    m::keyword("fn")
1211        + m::space()
1212        + m::identifier(function_name.to_compact_string())
1213        + markup_type_parameters
1214        + m::operator("(")
1215        + markup_parameters
1216        + m::operator(")")
1217        + markup_return_type
1218}
1219
1220impl PrettyPrint for Statement<'_> {
1221    fn pretty_print(&self) -> Markup {
1222        match self {
1223            Statement::DefineVariable(DefineVariable {
1224                name,
1225                expr,
1226                readable_type,
1227                ..
1228            }) => {
1229                m::keyword("let")
1230                    + m::space()
1231                    + m::identifier(name.to_compact_string())
1232                    + m::operator(":")
1233                    + m::space()
1234                    + readable_type.clone()
1235                    + m::space()
1236                    + m::operator("=")
1237                    + m::space()
1238                    + expr.pretty_print()
1239            }
1240            Statement::DefineFunction {
1241                function_name,
1242                type_parameters,
1243                parameters,
1244                body,
1245                local_variables,
1246                fn_type,
1247                readable_return_type,
1248                ..
1249            } => {
1250                let (fn_type, type_parameters) =
1251                    fn_type.instantiate_for_printing(Some(type_parameters.iter().map(|(n, _)| *n)));
1252
1253                let mut pretty_local_variables = None;
1254                let mut first = true;
1255                if !local_variables.is_empty() {
1256                    let mut plv = m::empty();
1257                    for DefineVariable {
1258                        name,
1259                        expr,
1260                        readable_type,
1261                        ..
1262                    } in local_variables
1263                    {
1264                        let introducer_keyword = if first {
1265                            first = false;
1266                            m::space() + m::space() + m::keyword("where")
1267                        } else {
1268                            m::space() + m::space() + m::space() + m::space() + m::keyword("and")
1269                        };
1270
1271                        plv += m::nl()
1272                            + introducer_keyword
1273                            + m::space()
1274                            + m::identifier(name.to_compact_string())
1275                            + m::operator(":")
1276                            + m::space()
1277                            + readable_type.clone()
1278                            + m::space()
1279                            + m::operator("=")
1280                            + m::space()
1281                            + expr.pretty_print();
1282                    }
1283                    pretty_local_variables = Some(plv);
1284                }
1285
1286                pretty_print_function_signature(
1287                    function_name,
1288                    &fn_type,
1289                    &type_parameters,
1290                    parameters
1291                        .iter()
1292                        .map(|(_, name, _, type_)| (*name, type_.clone())),
1293                    readable_return_type,
1294                ) + body
1295                    .as_ref()
1296                    .map(|e| m::space() + m::operator("=") + m::space() + e.pretty_print())
1297                    .unwrap_or_default()
1298                    + pretty_local_variables.unwrap_or_default()
1299            }
1300            Statement::Expression(expr) => expr.pretty_print(),
1301            Statement::DefineDimension(identifier, dexprs) if dexprs.is_empty() => {
1302                m::keyword("dimension")
1303                    + m::space()
1304                    + m::type_identifier(identifier.to_compact_string())
1305            }
1306            Statement::DefineDimension(identifier, dexprs) => {
1307                m::keyword("dimension")
1308                    + m::space()
1309                    + m::type_identifier(identifier.to_compact_string())
1310                    + m::space()
1311                    + m::operator("=")
1312                    + m::space()
1313                    + Itertools::intersperse(
1314                        dexprs.iter().map(|d| d.pretty_print()),
1315                        m::space() + m::operator("=") + m::space(),
1316                    )
1317                    .sum()
1318            }
1319            Statement::DefineBaseUnit {
1320                name,
1321                decorators,
1322                type_annotation,
1323                type_scheme,
1324                ..
1325            } => {
1326                decorator_markup(decorators)
1327                    + m::keyword("unit")
1328                    + m::space()
1329                    + m::unit(name.to_compact_string())
1330                    + m::operator(":")
1331                    + m::space()
1332                    + type_annotation
1333                        .as_ref()
1334                        .map(|a: &TypeAnnotation| a.pretty_print())
1335                        .unwrap_or(type_scheme.pretty_print())
1336            }
1337            Statement::DefineDerivedUnit {
1338                name,
1339                expr,
1340                decorators,
1341                readable_type,
1342                ..
1343            } => {
1344                decorator_markup(decorators)
1345                    + m::keyword("unit")
1346                    + m::space()
1347                    + m::unit(name.to_compact_string())
1348                    + m::operator(":")
1349                    + m::space()
1350                    + readable_type.clone()
1351                    + m::space()
1352                    + m::operator("=")
1353                    + m::space()
1354                    + expr.pretty_print()
1355            }
1356            Statement::ProcedureCall { kind, args, .. } => {
1357                let identifier = match kind {
1358                    ProcedureKind::Print => "print",
1359                    ProcedureKind::Assert => "assert",
1360                    ProcedureKind::AssertEq => "assert_eq",
1361                    ProcedureKind::Type => "type",
1362                };
1363                m::identifier(identifier)
1364                    + m::operator("(")
1365                    + Itertools::intersperse(
1366                        args.iter().map(|a| a.pretty_print()),
1367                        m::operator(",") + m::space(),
1368                    )
1369                    .sum()
1370                    + m::operator(")")
1371            }
1372            Statement::DefineStruct(StructInfo { name, fields, .. }) => {
1373                m::keyword("struct")
1374                    + m::space()
1375                    + m::type_identifier(name.clone())
1376                    + m::space()
1377                    + m::operator("{")
1378                    + if fields.is_empty() {
1379                        m::empty()
1380                    } else {
1381                        m::space()
1382                            + Itertools::intersperse(
1383                                fields.iter().map(|(n, (_, t))| {
1384                                    m::identifier(n.clone())
1385                                        + m::operator(":")
1386                                        + m::space()
1387                                        + t.pretty_print()
1388                                }),
1389                                m::operator(",") + m::space(),
1390                            )
1391                            .sum()
1392                            + m::space()
1393                    }
1394                    + m::operator("}")
1395            }
1396        }
1397    }
1398}
1399
1400fn pretty_scalar(n: Number) -> Markup {
1401    m::value(n.pretty_print())
1402}
1403
1404fn with_parens(expr: &Expression) -> Markup {
1405    match expr {
1406        Expression::Scalar { .. }
1407        | Expression::Identifier { .. }
1408        | Expression::UnitIdentifier { .. }
1409        | Expression::FunctionCall { .. }
1410        | Expression::CallableCall { .. }
1411        | Expression::Boolean(..)
1412        | Expression::String(..)
1413        | Expression::InstantiateStruct { .. }
1414        | Expression::AccessField { .. }
1415        | Expression::List { .. }
1416        | Expression::TypedHole(_, _) => expr.pretty_print(),
1417        Expression::UnaryOperator { .. }
1418        | Expression::BinaryOperator { .. }
1419        | Expression::BinaryOperatorForDate { .. }
1420        | Expression::Condition { .. } => m::operator("(") + expr.pretty_print() + m::operator(")"),
1421    }
1422}
1423
1424/// Add parens, if needed -- liberal version, can not be used for exponentiation.
1425fn with_parens_liberal(expr: &Expression) -> Markup {
1426    match expr {
1427        Expression::BinaryOperator {
1428            op: BinaryOperator::Mul,
1429            lhs,
1430            rhs,
1431            ..
1432        } if matches!(**lhs, Expression::Scalar { .. })
1433            && matches!(**rhs, Expression::UnitIdentifier { .. }) =>
1434        {
1435            expr.pretty_print()
1436        }
1437        _ => with_parens(expr),
1438    }
1439}
1440
1441fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -> Markup {
1442    match op {
1443        BinaryOperator::ConvertTo => {
1444            // never needs parens, it has the lowest precedence:
1445            lhs.pretty_print() + op.pretty_print() + rhs.pretty_print()
1446        }
1447        BinaryOperator::Mul => match (lhs, rhs) {
1448            (
1449                Expression::Scalar { value: s, .. },
1450                Expression::UnitIdentifier {
1451                    prefix, full_name, ..
1452                },
1453            ) => {
1454                // Fuse multiplication of a scalar and a unit to a quantity
1455                pretty_scalar(*s)
1456                    + m::space()
1457                    + m::unit(format_compact!("{}{}", prefix.as_string_long(), full_name))
1458            }
1459            (Expression::Scalar { value: s, .. }, Expression::Identifier { name, .. }) => {
1460                // Fuse multiplication of a scalar and identifier
1461                pretty_scalar(*s) + m::space() + m::identifier(name.to_compact_string())
1462            }
1463            _ => {
1464                let add_parens_if_needed = |expr: &Expression| {
1465                    if matches!(
1466                        expr,
1467                        Expression::BinaryOperator {
1468                            op: BinaryOperator::Power,
1469                            ..
1470                        } | Expression::BinaryOperator {
1471                            op: BinaryOperator::Mul,
1472                            ..
1473                        }
1474                    ) {
1475                        expr.pretty_print()
1476                    } else {
1477                        with_parens_liberal(expr)
1478                    }
1479                };
1480
1481                add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
1482            }
1483        },
1484        BinaryOperator::Div => {
1485            let lhs_add_parens_if_needed = |expr: &Expression| {
1486                if matches!(
1487                    expr,
1488                    Expression::BinaryOperator {
1489                        op: BinaryOperator::Power,
1490                        ..
1491                    } | Expression::BinaryOperator {
1492                        op: BinaryOperator::Mul,
1493                        ..
1494                    }
1495                ) {
1496                    expr.pretty_print()
1497                } else {
1498                    with_parens_liberal(expr)
1499                }
1500            };
1501            let rhs_add_parens_if_needed = |expr: &Expression| {
1502                if matches!(
1503                    expr,
1504                    Expression::BinaryOperator {
1505                        op: BinaryOperator::Power,
1506                        ..
1507                    }
1508                ) {
1509                    expr.pretty_print()
1510                } else {
1511                    with_parens_liberal(expr)
1512                }
1513            };
1514
1515            lhs_add_parens_if_needed(lhs) + op.pretty_print() + rhs_add_parens_if_needed(rhs)
1516        }
1517        BinaryOperator::Add => {
1518            let add_parens_if_needed = |expr: &Expression| {
1519                if matches!(
1520                    expr,
1521                    Expression::BinaryOperator {
1522                        op: BinaryOperator::Power,
1523                        ..
1524                    } | Expression::BinaryOperator {
1525                        op: BinaryOperator::Mul,
1526                        ..
1527                    } | Expression::BinaryOperator {
1528                        op: BinaryOperator::Add,
1529                        ..
1530                    }
1531                ) {
1532                    expr.pretty_print()
1533                } else {
1534                    with_parens_liberal(expr)
1535                }
1536            };
1537
1538            add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
1539        }
1540        BinaryOperator::Sub => {
1541            let add_parens_if_needed = |expr: &Expression| {
1542                if matches!(
1543                    expr,
1544                    Expression::BinaryOperator {
1545                        op: BinaryOperator::Power,
1546                        ..
1547                    } | Expression::BinaryOperator {
1548                        op: BinaryOperator::Mul,
1549                        ..
1550                    }
1551                ) {
1552                    expr.pretty_print()
1553                } else {
1554                    with_parens_liberal(expr)
1555                }
1556            };
1557
1558            add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
1559        }
1560        BinaryOperator::Power if matches!(rhs, Expression::Scalar { value, .. } if value.to_f64() == 2.0) => {
1561            with_parens(lhs) + m::operator("²")
1562        }
1563        BinaryOperator::Power if matches!(rhs, Expression::Scalar { value, .. } if value.to_f64() == 3.0) => {
1564            with_parens(lhs) + m::operator("³")
1565        }
1566        _ => with_parens(lhs) + op.pretty_print() + with_parens(rhs),
1567    }
1568}
1569
1570impl PrettyPrint for Expression<'_> {
1571    fn pretty_print(&self) -> Markup {
1572        use Expression::*;
1573
1574        match self {
1575            Scalar { value, .. } => pretty_scalar(*value),
1576            Identifier { name, .. } => m::identifier(name.to_compact_string()),
1577            UnitIdentifier {
1578                prefix, full_name, ..
1579            } => m::unit(format_compact!("{}{}", prefix.as_string_long(), full_name)),
1580            UnaryOperator {
1581                op: self::UnaryOperator::Negate,
1582                expr,
1583                ..
1584            } => m::operator("-") + with_parens(expr),
1585            UnaryOperator {
1586                op: self::UnaryOperator::Factorial(order),
1587                expr,
1588                ..
1589            } => with_parens(expr) + (0..order.get()).map(|_| m::operator("!")).sum(),
1590            UnaryOperator {
1591                op: self::UnaryOperator::LogicalNeg,
1592                expr,
1593                ..
1594            } => m::operator("!") + with_parens(expr),
1595            BinaryOperator { op, lhs, rhs, .. } => pretty_print_binop(op, lhs, rhs),
1596            BinaryOperatorForDate { op, lhs, rhs, .. } => pretty_print_binop(op, lhs, rhs),
1597            FunctionCall { name, args, .. } => {
1598                // Special case: render special temperature conversion functions in their sugar form:
1599                if args.len() == 1 {
1600                    // from_celsius(x) / from_fahrenheit(x) -> "… °C" / "… °F"
1601                    if *name == "from_celsius" {
1602                        return with_parens_liberal(&args[0]) + m::space() + m::unit("°C");
1603                    } else if *name == "from_fahrenheit" {
1604                        return with_parens_liberal(&args[0]) + m::space() + m::unit("°F");
1605                    }
1606                    // °C(x) / celsius(x) / degree_celsius(x) -> "x -> °C"
1607                    // °F(x) / fahrenheit(x) / degree_fahrenheit(x) -> "x -> °F"
1608                    if *name == "°C" || *name == "celsius" || *name == "degree_celsius" {
1609                        return with_parens_liberal(&args[0])
1610                            + m::space()
1611                            + m::operator("->")
1612                            + m::space()
1613                            + m::unit("°C");
1614                    } else if *name == "°F" || *name == "fahrenheit" || *name == "degree_fahrenheit"
1615                    {
1616                        return with_parens_liberal(&args[0])
1617                            + m::space()
1618                            + m::operator("->")
1619                            + m::space()
1620                            + m::unit("°F");
1621                    }
1622                }
1623
1624                m::identifier(name.to_compact_string())
1625                    + m::operator("(")
1626                    + itertools::Itertools::intersperse(
1627                        args.iter().map(|e: &Expression| e.pretty_print()),
1628                        m::operator(",") + m::space(),
1629                    )
1630                    .sum()
1631                    + m::operator(")")
1632            }
1633            CallableCall {
1634                callable: expr,
1635                args,
1636                ..
1637            } => {
1638                // See above
1639                if args.len() == 1
1640                    && let Expression::Identifier { name, .. } = expr.as_ref()
1641                {
1642                    if *name == "°C" || *name == "celsius" || *name == "degree_celsius" {
1643                        return with_parens_liberal(&args[0])
1644                            + m::space()
1645                            + m::operator("->")
1646                            + m::space()
1647                            + m::unit("°C");
1648                    } else if *name == "°F" || *name == "fahrenheit" || *name == "degree_fahrenheit"
1649                    {
1650                        return with_parens_liberal(&args[0])
1651                            + m::space()
1652                            + m::operator("->")
1653                            + m::space()
1654                            + m::unit("°F");
1655                    }
1656                }
1657
1658                expr.pretty_print()
1659                    + m::operator("(")
1660                    + itertools::Itertools::intersperse(
1661                        args.iter().map(|e: &Expression| e.pretty_print()),
1662                        m::operator(",") + m::space(),
1663                    )
1664                    .sum()
1665                    + m::operator(")")
1666            }
1667            Boolean(_, val) => val.pretty_print(),
1668            String(_, parts) => parts.pretty_print(),
1669            Condition {
1670                condition,
1671                then_expr,
1672                else_expr,
1673                ..
1674            } => {
1675                m::keyword("if")
1676                    + m::space()
1677                    + with_parens(condition)
1678                    + m::space()
1679                    + m::keyword("then")
1680                    + m::space()
1681                    + with_parens(then_expr)
1682                    + m::space()
1683                    + m::keyword("else")
1684                    + m::space()
1685                    + with_parens(else_expr)
1686            }
1687            InstantiateStruct {
1688                fields,
1689                struct_info,
1690                ..
1691            } => {
1692                m::type_identifier(struct_info.name.clone())
1693                    + m::space()
1694                    + m::operator("{")
1695                    + if fields.is_empty() {
1696                        m::empty()
1697                    } else {
1698                        m::space()
1699                            + itertools::Itertools::intersperse(
1700                                fields.iter().map(|(n, e)| {
1701                                    m::identifier(n.to_compact_string())
1702                                        + m::operator(":")
1703                                        + m::space()
1704                                        + e.pretty_print()
1705                                }),
1706                                m::operator(",") + m::space(),
1707                            )
1708                            .sum()
1709                            + m::space()
1710                    }
1711                    + m::operator("}")
1712            }
1713            AccessField {
1714                expr, field_name, ..
1715            } => {
1716                expr.pretty_print()
1717                    + m::operator(".")
1718                    + m::identifier(field_name.to_compact_string())
1719            }
1720            List { elements, .. } => {
1721                m::operator("[")
1722                    + itertools::Itertools::intersperse(
1723                        elements.iter().map(|e| e.pretty_print()),
1724                        m::operator(",") + m::space(),
1725                    )
1726                    .sum()
1727                    + m::operator("]")
1728            }
1729            TypedHole(_, _) => m::operator("?"),
1730        }
1731    }
1732}
1733
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ReplaceSpans;
    use crate::markup::{Formatter, PlainTextFormatter};
    use crate::prefix_transformer::Transformer;

    /// Declarations every test snippet is checked against: dimensions, a few
    /// functions, units with aliases/metric prefixes, a struct, and the plain
    /// variables that the test expressions refer to.
    const PRELUDE: &str = "dimension Scalar = 1
                 dimension Length
                 dimension Time
                 dimension Mass

                 fn sin(x: Scalar) -> Scalar
                 fn cos(x: Scalar) -> Scalar
                 fn asin(x: Scalar) -> Scalar
                 fn atan(x: Scalar) -> Scalar
                 fn atan2<T>(x: T, y: T) -> Scalar
                 fn sqrt(x) = x^(1/2)
                 let pi = 2 asin(1)

                 @aliases(m: short)
                 @metric_prefixes
                 unit meter: Length

                 @aliases(s: short)
                 @metric_prefixes
                 unit second: Time

                 @aliases(g: short)
                 @metric_prefixes
                 unit gram: Mass

                 @aliases(rad: short)
                 @metric_prefixes
                 unit radian: Scalar = meter / meter

                 @aliases(°: none)
                 unit degree = 180/pi × radian

                 @aliases(in: short)
                 unit inch = 0.0254 m

                 @metric_prefixes
                 unit points

                 struct Foo {foo: Length, bar: Time}

                 let a = 1
                 let b = 1
                 let c = 1
                 let d = 1
                 let e = 1
                 let f = 1
                 let x = 1
                 let r = 2 m
                 let vol = 3 m^3
                 let density = 1000 kg / m^3
                 let länge = 1
                 let x_2 = 1
                 let µ = 1
                 let _prefixed = 1";

    /// Parses `code` after [`PRELUDE`], runs the prefix transformer and the
    /// typechecker over everything, and returns the last typed statement.
    fn parse(code: &str) -> Statement<'_> {
        let prelude = crate::parser::parse(PRELUDE, 0).unwrap();
        let snippet = crate::parser::parse(code, 0).unwrap();

        let mut transformer = Transformer::new();
        let resolved = transformer
            .transform(prelude.into_iter().chain(snippet))
            .unwrap()
            .replace_spans();

        crate::typechecker::TypeChecker::default()
            .check(&resolved)
            .unwrap()
            .last()
            .unwrap()
            .clone()
    }

    /// Renders a typed statement as plain text (no colors/markup).
    fn pretty_print(stmt: &Statement) -> CompactString {
        let formatter = PlainTextFormatter {};
        formatter.format(&stmt.pretty_print(), false)
    }

    /// Asserts that `input` pretty-prints exactly as `expected`.
    fn equal_pretty(input: &str, expected: &str) {
        println!();
        println!("expected: '{expected}'");
        let rendered = pretty_print(&parse(input));
        println!("actual:   '{rendered}'");
        assert_eq!(rendered, expected);
    }

    #[test]
    fn pretty_print_basic() {
        let cases = [
            ("2+3", "2 + 3"),
            ("2*3", "2 × 3"),
            ("2^3", "2³"),
            ("2km", "2 kilometer"),
            ("2kilometer", "2 kilometer"),
            ("sin(30°)", "sin(30 degree)"),
            ("2*3*4", "2 × 3 × 4"),
            ("2*(3*4)", "2 × 3 × 4"),
            ("2+3+4", "2 + 3 + 4"),
            ("2+(3+4)", "2 + 3 + 4"),
            ("atan(30cm / 2m)", "atan(30 centimeter / 2 meter)"),
            ("1mrad -> °", "1 milliradian ➞ degree"),
            ("2km+2cm -> in", "2 kilometer + 2 centimeter ➞ inch"),
            ("2^3 + 4^5", "2³ + 4^5"),
            ("2^3 - 4^5", "2³ - 4^5"),
            ("2^3 * 4^5", "2³ × 4^5"),
            ("2 * 3 + 4 * 5", "2 × 3 + 4 × 5"),
            ("2 * 3 / 4", "2 × 3 / 4"),
            ("123.123 km² / s²", "123.123 × kilometer² / second²"),
        ];

        for (input, expected) in cases {
            equal_pretty(input, expected);
        }
    }

    /// Asserts that pretty-printing `code` and parsing the result again yields
    /// the same typed statement as parsing `code` directly.
    fn roundtrip_check(code: &str) {
        println!("Roundtrip check for code = '{code}'");
        let original = parse(code);
        let printed = pretty_print(&original);
        println!("     pretty printed code = '{printed}'");
        let reparsed = parse(&printed);
        assert_eq!(original, reparsed);
    }

    #[test]
    fn pretty_print_roundtrip_check() {
        let snippets = [
            "1.0",
            "2",
            "1 + 2",
            "-2.3e-12387",
            "2.3e-12387",
            "18379173",
            "2+3",
            "2+3*5",
            "-3^4+2/(4+2*3)",
            "1-2-3-4-(5-6-7)",
            "1/2/3/4/(5/6/7)",
            "kilogram",
            "2meter/second",
            "a+b*c^d-e*f",
            "sin(x)^3",
            "sin(cos(atan(x)+2))^3",
            "2^3^4^5",
            "(2^3)^(4^5)",
            "sqrt(1.4^2 + 1.5^2) * cos(pi/3)^2",
            "40 kilometer * 9.8meter/second^2 * 150centimeter",
            "4/3 * pi * r³",
            "vol * density -> kilogram",
            "atan(30 centimeter / 2 meter)",
            "500kilometer/second -> centimeter/second",
            "länge * x_2 * µ * _prefixed",
            "2meter^3",
            "(2meter)^3",
            "-sqrt(-30meter^3)",
            "-3^4",
            "(-3)^4",
            "atan2(2,3)",
            "2^3!",
            "-3!",
            "(-3)!",
            "megapoints",
            "Foo { foo: 1 meter, bar: 1 second }",
            "\"foo\"",
            "\"newline: \\n\"",
        ];

        for snippet in snippets {
            roundtrip_check(snippet);
        }
    }

    #[test]
    fn pretty_print_dexpr() {
        let declarations = [
            "unit z: Length",
            "unit z: Length * Time",
            "unit z: Length * Time^2",
            "unit z: Length^-3 * Time^2",
            "unit z: Length / Time",
            "unit z: Length / Time^2",
            "unit z: Length / Time^(-2)",
            "unit z: Length / (Time * Mass)",
            "unit z: Length^5 * Time^4 / (Time^2 * Mass^3)",
        ];

        for declaration in declarations {
            roundtrip_check(declaration);
        }
    }
}