1use std::sync::Arc;
2
3use compact_str::{CompactString, ToCompactString, format_compact};
4use indexmap::IndexMap;
5use itertools::Itertools;
6
7use crate::arithmetic::Exponent;
8pub use crate::ast::{BinaryOperator, TypeExpression, UnaryOperator};
9use crate::ast::{ProcedureKind, TypeAnnotation, TypeParameterBound};
10use crate::dimension::DimensionRegistry;
11use crate::pretty_print::escape_numbat_string;
12use crate::traversal::{ForAllExpressions, ForAllTypeSchemes};
13use crate::type_variable::TypeVariable;
14use crate::typechecker::TypeCheckError;
15use crate::typechecker::qualified_type::QualifiedType;
16use crate::typechecker::type_scheme::TypeScheme;
17use crate::{BaseRepresentation, BaseRepresentationFactor, markup as m};
18use crate::{
19 decorator::Decorator, markup::Markup, number::Number, prefix::Prefix,
20 prefix_parser::AcceptsPrefix, pretty_print::PrettyPrint, span::Span,
21};
22use num_traits::{CheckedAdd, CheckedMul};
23
/// One atomic building block of a dimension type (`DType`).
///
/// When a `DType` is canonicalized, factors are ordered type variables first, then base
/// dimensions, then type parameters (each group ordered by its inner value).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DTypeFactor {
    /// A (named or quantified) type variable.
    TVar(TypeVariable),
    /// A named type parameter.
    TPar(CompactString),
    /// A named base dimension.
    BaseDimension(CompactString),
}
31
32impl DTypeFactor {
33 pub fn name(&self) -> &str {
34 match self {
35 DTypeFactor::TVar(TypeVariable::Named(name)) => name,
36 DTypeFactor::TVar(TypeVariable::Quantified(_)) => unreachable!(),
37 DTypeFactor::TPar(name) => name,
38 DTypeFactor::BaseDimension(name) => name,
39 }
40 }
41}
42
/// A dimension-type factor together with the exponent it is raised to.
type DtypeFactorPower = (DTypeFactor, Exponent);
44
/// A dimension type: a product of factors raised to (rational) exponents.
///
/// Instances are only built through `from_factors` / `try_from_factors`, which canonicalize
/// the factor list (sorted, equal factors merged, zero exponents removed).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DType {
    // Shared so that clones are cheap; mutated via `Arc::make_mut` when needed.
    factors: Arc<Vec<DtypeFactorPower>>,
}
50
51impl DType {
52 pub fn factors(&self) -> &[DtypeFactorPower] {
53 &self.factors
54 }
55
56 pub fn into_factors(self) -> Arc<Vec<DtypeFactorPower>> {
57 self.factors
58 }
59
60 pub fn from_factors(factors: Arc<Vec<DtypeFactorPower>>) -> DType {
61 let mut dtype = DType { factors };
62 dtype.canonicalize();
63 dtype
64 }
65
66 pub fn scalar() -> DType {
67 DType::from_factors(Arc::new(vec![]))
68 }
69
70 pub fn is_scalar(&self) -> bool {
71 self == &Self::scalar()
72 }
73
74 pub fn to_readable_type(&self, registry: &DimensionRegistry) -> m::Markup {
75 if self.is_scalar() {
76 return m::type_identifier("Scalar");
77 }
78
79 let mut names = vec![];
80
81 if self.factors.len() == 1 && self.factors[0].1 == Exponent::from_integer(1) {
82 names.push(self.factors[0].0.name().to_compact_string());
83 }
84
85 let base_representation = self.to_base_representation();
86 names.extend(registry.get_derived_entry_names_for(&base_representation));
87 match &names[..] {
88 [] => self.pretty_print(),
89 [single] => m::type_identifier(single.to_compact_string()),
90 multiple => Itertools::intersperse(
91 multiple.iter().cloned().map(m::type_identifier),
92 m::dimmed(" or "),
93 )
94 .sum(),
95 }
96 }
97
98 pub fn is_time_dimension(&self) -> bool {
102 *self == DType::base_dimension("Time")
103 }
104
105 pub fn from_type_variable(v: TypeVariable) -> DType {
106 DType::from_factors(Arc::new(vec![(
107 DTypeFactor::TVar(v),
108 Exponent::from_integer(1),
109 )]))
110 }
111
112 pub fn from_type_parameter(name: CompactString) -> DType {
113 DType::from_factors(Arc::new(vec![(
114 DTypeFactor::TPar(name),
115 Exponent::from_integer(1),
116 )]))
117 }
118
119 pub fn deconstruct_as_single_type_variable(&self) -> Option<TypeVariable> {
120 match &self.factors[..] {
121 [(DTypeFactor::TVar(v), exponent)] if exponent == &Exponent::from_integer(1) => {
122 Some(v.clone())
123 }
124 _ => None,
125 }
126 }
127
128 pub fn from_tgen(i: usize) -> DType {
129 DType::from_factors(Arc::new(vec![(
130 DTypeFactor::TVar(TypeVariable::Quantified(i)),
131 Exponent::from_integer(1),
132 )]))
133 }
134
135 pub fn base_dimension(name: &str) -> DType {
136 DType::from_factors(Arc::new(vec![(
137 DTypeFactor::BaseDimension(name.into()),
138 Exponent::from_integer(1),
139 )]))
140 }
141
142 fn canonicalize(&mut self) {
143 self.try_canonicalize()
144 .expect("overflow in dimension type exponent computation");
145 }
146
147 fn try_canonicalize(&mut self) -> Option<()> {
149 Arc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) {
151 (DTypeFactor::TVar(v1), DTypeFactor::TVar(v2)) => v1.cmp(v2),
152 (DTypeFactor::TVar(_), _) => std::cmp::Ordering::Less,
153
154 (DTypeFactor::BaseDimension(d1), DTypeFactor::BaseDimension(d2)) => d1.cmp(d2),
155 (DTypeFactor::BaseDimension(_), DTypeFactor::TVar(_)) => std::cmp::Ordering::Greater,
156 (DTypeFactor::BaseDimension(_), DTypeFactor::TPar(_)) => std::cmp::Ordering::Less,
157
158 (DTypeFactor::TPar(p1), DTypeFactor::TPar(p2)) => p1.cmp(p2),
159 (DTypeFactor::TPar(_), _) => std::cmp::Ordering::Greater,
160 });
161
162 let mut new_factors: Vec<DtypeFactorPower> = Vec::new();
164 for (f, n) in self.factors.iter() {
165 if let Some((last_f, last_n)) = new_factors.last_mut()
166 && f == last_f
167 {
168 *last_n = last_n.checked_add(n)?;
169 continue;
170 }
171 new_factors.push((f.clone(), *n));
172 }
173
174 new_factors.retain(|(_, n)| *n != Exponent::from_integer(0));
176
177 self.factors = Arc::new(new_factors);
178 Some(())
179 }
180
181 pub fn try_from_factors(factors: Arc<Vec<DtypeFactorPower>>) -> Option<DType> {
183 let mut dtype = DType { factors };
184 dtype.try_canonicalize()?;
185 Some(dtype)
186 }
187
188 pub fn multiply(&self, other: &DType) -> DType {
189 let mut factors = self.factors.clone();
190 Arc::make_mut(&mut factors).extend(other.factors.iter().cloned());
191 DType::from_factors(factors)
192 }
193
194 pub fn try_multiply(&self, other: &DType) -> Option<DType> {
196 let mut factors = self.factors.clone();
197 Arc::make_mut(&mut factors).extend(other.factors.iter().cloned());
198 DType::try_from_factors(factors)
199 }
200
201 pub fn power(&self, n: Exponent) -> DType {
202 let factors = self
203 .factors
204 .iter()
205 .map(|(f, m)| (f.clone(), n * m))
206 .collect();
207 DType::from_factors(Arc::new(factors))
208 }
209
210 pub fn try_power(&self, n: Exponent) -> Option<DType> {
212 let factors: Option<Vec<_>> = self
213 .factors
214 .iter()
215 .map(|(f, m)| n.checked_mul(m).map(|exp| (f.clone(), exp)))
216 .collect();
217 factors.and_then(|f| DType::try_from_factors(Arc::new(f)))
218 }
219
220 pub fn inverse(&self) -> DType {
221 self.power(-Exponent::from_integer(1))
222 }
223
224 pub fn divide(&self, other: &DType) -> DType {
225 self.multiply(&other.inverse())
226 }
227
228 pub fn try_divide(&self, other: &DType) -> Option<DType> {
230 self.try_multiply(&other.inverse())
231 }
232
233 pub fn type_variables(&self, including_type_parameters: bool) -> Vec<TypeVariable> {
234 let mut vars: Vec<_> = self
235 .factors
236 .iter()
237 .filter_map(|(f, _)| match f {
238 DTypeFactor::TVar(v) => Some(v.clone()),
239 DTypeFactor::TPar(v) => {
240 if including_type_parameters {
241 Some(TypeVariable::new(v))
242 } else {
243 None
244 }
245 }
246 DTypeFactor::BaseDimension(_) => None,
247 })
248 .collect();
249 vars.sort();
250 vars.dedup();
251 vars
252 }
253
254 pub fn contains(&self, name: &TypeVariable, including_type_parameters: bool) -> bool {
255 self.type_variables(including_type_parameters)
256 .contains(name)
257 }
258
259 pub fn split_first_factor(&self) -> Option<(&DtypeFactorPower, &[DtypeFactorPower])> {
260 self.factors.split_first()
261 }
262
263 fn instantiate(&self, type_variables: &[TypeVariable]) -> DType {
264 let mut factors = Vec::new();
265
266 for (f, n) in self.factors.iter() {
267 match f {
268 DTypeFactor::TVar(TypeVariable::Quantified(i)) => {
269 factors.push((DTypeFactor::TVar(type_variables[*i].clone()), *n));
270 }
271 _ => {
272 factors.push((f.clone(), *n));
273 }
274 }
275 }
276 Self::from_factors(Arc::new(factors))
277 }
278
279 pub fn to_base_representation(&self) -> BaseRepresentation {
280 let mut factors = vec![];
281 for (f, n) in self.factors.iter() {
282 match f {
283 DTypeFactor::BaseDimension(name) => {
284 factors.push(BaseRepresentationFactor(name.clone(), *n));
285 }
286 DTypeFactor::TVar(TypeVariable::Named(name)) => {
287 factors.push(BaseRepresentationFactor(name.clone(), *n));
288 }
289 DTypeFactor::TVar(TypeVariable::Quantified(id)) => {
290 factors.push(BaseRepresentationFactor(format!("?{id}").into(), *n));
293 }
294 DTypeFactor::TPar(name) => {
295 factors.push(BaseRepresentationFactor(name.clone(), *n));
296 }
297 }
298 }
299 BaseRepresentation::from_factors(factors)
300 }
301}
302
303impl PrettyPrint for DType {
304 fn pretty_print(&self) -> Markup {
305 self.to_base_representation().pretty_print()
306 }
307}
308
309impl std::fmt::Display for DType {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 write!(f, "{}", self.pretty_print())
312 }
313}
314
315impl From<BaseRepresentation> for DType {
316 fn from(base_representation: BaseRepresentation) -> Self {
317 let factors = base_representation
318 .into_iter()
319 .map(|BaseRepresentationFactor(name, exp)| (DTypeFactor::BaseDimension(name), exp))
320 .collect();
321 DType::from_factors(Arc::new(factors))
322 }
323}
324
/// Whether a struct type refers to the generic definition or to a concrete instance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructKind {
    /// The definition, with its type parameters (presumably span, name and optional bound
    /// of each parameter — confirm against the typechecker).
    Definition(Vec<(Span, CompactString, Option<TypeParameterBound>)>),
    /// An instance, with the type arguments applied to the definition's parameters.
    Instance(Vec<Type>),
}
333
/// Type information about a struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructInfo {
    pub definition_span: Span,
    pub name: CompactString,
    pub kind: StructKind,
    /// Fields in declaration order: name -> (span, field type).
    pub fields: IndexMap<CompactString, (Span, Type)>,
}
341
/// The type of a (typed) expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    /// A type variable (named or quantified).
    TVar(TypeVariable),
    /// A named type parameter.
    TPar(CompactString),
    /// A dimension type (physical quantity).
    Dimension(DType),
    Boolean,
    String,
    DateTime,
    /// A function type: parameter types and return type.
    Fn(Vec<Type>, Box<Type>),
    Struct(Box<StructInfo>),
    /// A list with the given element type.
    List(Box<Type>),
}
362
363impl std::fmt::Display for Type {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 match self {
366 Type::TVar(TypeVariable::Named(name)) => write!(f, "{name}"),
367 Type::TVar(TypeVariable::Quantified(_)) => {
368 unreachable!("Quantified types should not be printed")
369 }
370 Type::TPar(name) => write!(f, "{name}"),
371 Type::Dimension(d) => d.fmt(f),
372 Type::Boolean => write!(f, "Bool"),
373 Type::String => write!(f, "String"),
374 Type::DateTime => write!(f, "DateTime"),
375 Type::Fn(param_types, return_type) => {
376 write!(
377 f,
378 "Fn[({ps}) -> {return_type}]",
379 ps = param_types.iter().map(|p| p.to_string()).join(", ")
380 )
381 }
382 Type::Struct(info) => {
383 write!(f, "{}", info.name)?;
384 if let StructKind::Instance(type_args) = &info.kind
385 && !type_args.is_empty()
386 {
387 write!(
388 f,
389 "<{}>",
390 type_args.iter().map(|t| t.to_string()).join(", ")
391 )?;
392 }
393 write!(
394 f,
395 " {{{}}}",
396 info.fields
397 .iter()
398 .map(|(n, (_, t))| n.to_string() + ": " + &t.to_string())
399 .join(", ")
400 )
401 }
402 Type::List(element_type) => write!(f, "List<{element_type}>"),
403 }
404 }
405}
406
407impl PrettyPrint for Type {
408 fn pretty_print(&self) -> Markup {
409 match self {
410 Type::TVar(TypeVariable::Named(name)) => m::type_identifier(name.to_compact_string()),
411 Type::TVar(TypeVariable::Quantified(_)) => {
412 unreachable!("Quantified types should not be printed")
413 }
414 Type::TPar(name) => m::type_identifier(name.clone()),
415 Type::Dimension(d) => d.pretty_print(),
416 Type::Boolean => m::type_identifier("Bool"),
417 Type::String => m::type_identifier("String"),
418 Type::DateTime => m::type_identifier("DateTime"),
419 Type::Fn(param_types, return_type) => {
420 m::type_identifier("Fn")
421 + m::operator("[(")
422 + Itertools::intersperse(
423 param_types.iter().map(|t| t.pretty_print()),
424 m::operator(",") + m::space(),
425 )
426 .sum()
427 + m::operator(")")
428 + m::space()
429 + m::operator("->")
430 + m::space()
431 + return_type.pretty_print()
432 + m::operator("]")
433 }
434 Type::Struct(info) => {
435 let mut markup = m::type_identifier(info.name.clone());
436 if let StructKind::Instance(type_args) = &info.kind
437 && !type_args.is_empty()
438 {
439 markup += m::operator("<");
440 markup += Itertools::intersperse(
441 type_args.iter().map(|t| t.pretty_print()),
442 m::operator(",") + m::space(),
443 )
444 .sum();
445 markup += m::operator(">");
446 }
447 markup
448 }
449 Type::List(element_type) => {
450 m::type_identifier("List")
451 + m::operator("<")
452 + element_type.pretty_print()
453 + m::operator(">")
454 }
455 }
456 }
457}
458
459impl Type {
460 pub fn to_readable_type(&self, registry: &DimensionRegistry) -> Markup {
461 match self {
462 Type::Dimension(d) => d.to_readable_type(registry),
463 Type::Struct(info) => {
464 let mut markup = m::type_identifier(info.name.clone());
465 if let StructKind::Instance(type_args) = &info.kind
466 && !type_args.is_empty()
467 {
468 markup += m::operator("<");
469 markup += Itertools::intersperse(
470 type_args.iter().map(|t| t.to_readable_type(registry)),
471 m::operator(",") + m::space(),
472 )
473 .sum();
474 markup += m::operator(">");
475 }
476 markup
477 }
478 Type::List(element_type) => {
479 m::type_identifier("List")
480 + m::operator("<")
481 + element_type.to_readable_type(registry)
482 + m::operator(">")
483 }
484 _ => self.pretty_print(),
485 }
486 }
487
488 pub fn scalar() -> Type {
489 Type::Dimension(DType::scalar())
490 }
491
492 pub fn is_dtype(&self) -> bool {
493 matches!(self, Type::Dimension(..))
494 }
495
496 pub fn is_fn_type(&self) -> bool {
497 matches!(self, Type::Fn(..))
498 }
499
500 pub(crate) fn has_incompatible_constructor(&self, other: &Type) -> bool {
504 use Type::*;
505
506 let is_dimension_with_only_tvar = |t: &Type| matches!(t, Dimension(d) if d.deconstruct_as_single_type_variable().is_some());
508
509 match (self, other) {
510 (TVar(_), _) | (_, TVar(_)) | (TPar(_), _) | (_, TPar(_)) => false,
512
513 (t1, t2) if is_dimension_with_only_tvar(t1) || is_dimension_with_only_tvar(t2) => false,
515
516 (Dimension(_), Dimension(_))
518 | (Boolean, Boolean)
519 | (String, String)
520 | (DateTime, DateTime)
521 | (Fn(_, _), Fn(_, _))
522 | (Struct(_), Struct(_))
523 | (List(_), List(_)) => false,
524
525 _ => true,
527 }
528 }
529
530 pub(crate) fn type_variables(&self, including_type_parameters: bool) -> Vec<TypeVariable> {
531 match self {
532 Type::TVar(v) => vec![v.clone()],
533 Type::TPar(n) => {
534 if including_type_parameters {
535 vec![TypeVariable::new(n)]
536 } else {
537 vec![]
538 }
539 }
540 Type::Dimension(d) => d.type_variables(including_type_parameters),
541 Type::Boolean | Type::String | Type::DateTime => vec![],
542 Type::Fn(param_types, return_type) => {
543 let mut vars = return_type.type_variables(including_type_parameters);
544 for param_type in param_types {
545 vars.extend(param_type.type_variables(including_type_parameters));
546 }
547 vars.sort();
548 vars.dedup();
549 vars
550 }
551 Type::Struct(info) => {
552 let mut vars = vec![];
553 for (_, (_, t)) in &info.fields {
554 vars.extend(t.type_variables(including_type_parameters));
555 }
556 vars
557 }
558 Type::List(element_type) => element_type.type_variables(including_type_parameters),
559 }
560 }
561
562 pub(crate) fn contains(&self, x: &TypeVariable, including_type_parameters: bool) -> bool {
563 self.type_variables(including_type_parameters).contains(x)
564 }
565
566 pub(crate) fn is_closed(&self) -> bool {
568 self.type_variables(false).is_empty()
569 }
570
571 pub(crate) fn instantiate(&self, type_variables: &[TypeVariable]) -> Type {
572 match self {
573 Type::TVar(TypeVariable::Quantified(i)) => Type::TVar(type_variables[*i].clone()),
574 Type::TVar(v) => Type::TVar(v.clone()),
575 Type::TPar(n) => Type::TPar(n.clone()),
576 Type::Dimension(d) => Type::Dimension(d.instantiate(type_variables)),
577 Type::Boolean | Type::String | Type::DateTime => self.clone(),
578 Type::Fn(param_types, return_type) => Type::Fn(
579 param_types
580 .iter()
581 .map(|t| t.instantiate(type_variables))
582 .collect(),
583 Box::new(return_type.instantiate(type_variables)),
584 ),
585 Type::Struct(info) => {
586 let instantiated_fields = info
587 .fields
588 .iter()
589 .map(|(name, (span, field_type))| {
590 (
591 name.clone(),
592 (*span, field_type.instantiate(type_variables)),
593 )
594 })
595 .collect();
596 let instantiated_kind = match &info.kind {
597 StructKind::Definition(params) => StructKind::Definition(params.clone()),
598 StructKind::Instance(type_args) => StructKind::Instance(
599 type_args
600 .iter()
601 .map(|t| t.instantiate(type_variables))
602 .collect(),
603 ),
604 };
605 Type::Struct(Box::new(StructInfo {
606 definition_span: info.definition_span,
607 name: info.name.clone(),
608 kind: instantiated_kind,
609 fields: instantiated_fields,
610 }))
611 }
612 Type::List(element_type) => {
613 Type::List(Box::new(element_type.instantiate(type_variables)))
614 }
615 }
616 }
617
618 pub(crate) fn is_scalar(&self) -> bool {
619 match self {
620 Type::Dimension(d) => d.is_scalar(),
621 _ => false,
622 }
623 }
624}
625
/// One segment of an (interpolated) string literal.
#[derive(Debug, Clone, PartialEq)]
pub enum StringPart<'a> {
    /// Literal text (escaped again when pretty-printed).
    Fixed(CompactString),
    /// An embedded `{expr}` with optional format specifiers.
    Interpolation {
        span: Span,
        expr: Box<Expression<'a>>,
        format_specifiers: Option<&'a str>,
    },
}
635
636impl PrettyPrint for StringPart<'_> {
637 fn pretty_print(&self) -> Markup {
638 match self {
639 StringPart::Fixed(s) => m::string(escape_numbat_string(s)),
640 StringPart::Interpolation {
641 span: _,
642 expr,
643 format_specifiers,
644 } => {
645 let mut markup = m::operator("{") + expr.pretty_print();
646
647 if let Some(format_specifiers) = format_specifiers {
648 markup += m::text(format_specifiers.to_compact_string());
649 }
650
651 markup += m::operator("}");
652
653 markup
654 }
655 }
656 }
657}
658
659impl PrettyPrint for &Vec<StringPart<'_>> {
660 fn pretty_print(&self) -> Markup {
661 m::operator("\"") + self.iter().map(|p| p.pretty_print()).sum() + m::operator("\"")
662 }
663}
664
/// A type-annotated expression produced by the typechecker.
#[derive(Debug, Clone, PartialEq)]
pub enum Expression<'a> {
    /// A numeric literal.
    Scalar {
        span: Span,
        value: Number,
        type_scheme: TypeScheme,
    },
    Identifier {
        span: Span,
        name: &'a str,
        type_scheme: TypeScheme,
    },
    /// A (possibly prefixed) unit; `full_name` is kept alongside the written `name`.
    UnitIdentifier {
        span: Span,
        prefix: Prefix,
        name: CompactString,
        full_name: CompactString,
        type_scheme: TypeScheme,
    },
    UnaryOperator {
        span: Span,
        op: UnaryOperator,
        expr: Box<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// `op_span` is optional; when present it is included in `full_span`.
    BinaryOperator {
        op_span: Option<Span>,
        op: BinaryOperator,
        lhs: Box<Expression<'a>>,
        rhs: Box<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// Same shape as `BinaryOperator`; presumably used when the operands involve
    /// `DateTime` values — confirm in the typechecker.
    BinaryOperatorForDate {
        op_span: Option<Span>,
        op: BinaryOperator,
        lhs: Box<Expression<'a>>,
        rhs: Box<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A call of a function by name.
    FunctionCall {
        full_span: Span,
        ident_span: Span,
        name: &'a str,
        args: Vec<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A call of an arbitrary expression that evaluates to a callable.
    CallableCall {
        full_span: Span,
        callable: Box<Expression<'a>>,
        args: Vec<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    Boolean(Span, bool),
    /// `if ... then ... else ...`; the type is taken from `then_expr`.
    Condition {
        span: Span,
        condition: Box<Expression<'a>>,
        then_expr: Box<Expression<'a>>,
        else_expr: Box<Expression<'a>>,
    },
    String(Span, Vec<StringPart<'a>>),
    InstantiateStruct {
        span: Span,
        fields: Vec<(&'a str, Expression<'a>)>,
        struct_info: StructInfo,
    },
    AccessField {
        full_span: Span,
        ident_span: Span,
        expr: Box<Expression<'a>>,
        field_name: &'a str,
        struct_type: TypeScheme,
        field_type: TypeScheme,
    },
    /// A list literal; `type_scheme` holds the *element* type (see `get_type`).
    List {
        span: Span,
        elements: Vec<Expression<'a>>,
        type_scheme: TypeScheme,
    },
    /// A typed hole with its inferred type.
    TypedHole(Span, TypeScheme),
}
750
751impl Expression<'_> {
752 pub fn full_span(&self) -> Span {
753 match self {
754 Expression::Scalar { span, .. } => *span,
755 Expression::Identifier { span, .. } => *span,
756 Expression::UnitIdentifier { span, .. } => *span,
757 Expression::UnaryOperator { span, expr, .. } => span.extend(&expr.full_span()),
758 Expression::BinaryOperator {
759 op_span, lhs, rhs, ..
760 } => {
761 let mut span = lhs.full_span().extend(&rhs.full_span());
762 if let Some(op_span) = op_span {
763 span = span.extend(op_span);
764 }
765 span
766 }
767 Expression::BinaryOperatorForDate {
768 op_span, lhs, rhs, ..
769 } => {
770 let mut span = lhs.full_span().extend(&rhs.full_span());
771 if let Some(op_span) = op_span {
772 span = span.extend(op_span);
773 }
774 span
775 }
776 Expression::FunctionCall { full_span, .. } => *full_span,
777 Expression::CallableCall { full_span, .. } => *full_span,
778 Expression::Boolean(span, _) => *span,
779 Expression::Condition {
780 span, else_expr, ..
781 } => span.extend(&else_expr.full_span()),
782 Expression::String(span, _) => *span,
783 Expression::InstantiateStruct { span, .. } => *span,
784 Expression::AccessField { full_span, .. } => *full_span,
785 Expression::List { span, .. } => *span,
786 Expression::TypedHole(span, _) => *span,
787 }
788 }
789}
790
/// A typed `let` binding (also used for function-local variables).
#[derive(Debug, Clone, PartialEq)]
pub struct DefineVariable<'a> {
    pub name: &'a str,
    pub decorators: Vec<Decorator<'a>>,
    pub expr: Expression<'a>,
    pub type_annotation: Option<TypeAnnotation>,
    pub type_scheme: TypeScheme,
    /// Human-readable type, recomputed by `Statement::update_readable_types`.
    pub readable_type: Markup,
}
800
/// A type-checked top-level statement.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement<'a> {
    Expression(Expression<'a>),
    DefineVariable(DefineVariable<'a>),
    DefineFunction {
        function_name: &'a str,
        decorators: Vec<Decorator<'a>>,
        /// Type parameter names with their optional bounds.
        type_parameters: Vec<(&'a str, Option<TypeParameterBound>)>,
        parameters: Vec<(
            Span,
            &'a str,
            Option<TypeAnnotation>,
            Markup, // readable parameter type, recomputed by `update_readable_types`
        )>,
        // NOTE(review): `None` presumably means the function has no body (e.g. it is
        // implemented natively) — confirm.
        body: Option<Expression<'a>>,
        local_variables: Vec<DefineVariable<'a>>,
        /// Expected to wrap a `Type::Fn` (see `update_readable_types`).
        fn_type: TypeScheme,
        return_type_annotation: Option<TypeAnnotation>,
        readable_return_type: Markup,
    },
    /// Dimension definition: its name and the type expressions it is defined by (if any).
    DefineDimension(&'a str, Vec<TypeExpression>),
    DefineBaseUnit {
        name: &'a str,
        identifier_span: Span,
        decorators: Vec<Decorator<'a>>,
        type_annotation: Option<TypeAnnotation>,
        type_scheme: TypeScheme,
    },
    DefineDerivedUnit {
        name: &'a str,
        identifier_span: Span,
        expr: Expression<'a>,
        decorators: Vec<Decorator<'a>>,
        type_annotation: Option<TypeAnnotation>,
        type_scheme: TypeScheme,
        readable_type: Markup,
    },
    /// Invocation of a built-in procedure.
    ProcedureCall {
        kind: ProcedureKind,
        span: Span,
        args: Vec<Expression<'a>>,
    },
    DefineStruct(StructInfo),
}
845
846impl Statement<'_> {
847 pub fn as_expression(&self) -> Option<&Expression<'_>> {
848 if let Self::Expression(v) = self {
849 Some(v)
850 } else {
851 None
852 }
853 }
854
855 pub(crate) fn generalize_types(&mut self, dtype_variables: &[TypeVariable]) {
856 self.for_all_type_schemes(&mut |type_: &mut TypeScheme| type_.generalize(dtype_variables));
857 }
858
859 fn create_readable_type(
860 registry: &DimensionRegistry,
861 type_: &TypeScheme,
862 annotation: &Option<TypeAnnotation>,
863 with_quantifiers: bool,
864 ) -> Markup {
865 if let Some(annotation) = annotation {
866 annotation.pretty_print()
867 } else {
868 type_.to_readable_type(registry, with_quantifiers)
869 }
870 }
871
872 pub(crate) fn update_readable_types(&mut self, registry: &DimensionRegistry) {
873 match self {
874 Statement::Expression(_) => {}
875 Statement::DefineVariable(DefineVariable {
876 type_annotation,
877 type_scheme,
878 readable_type,
879 ..
880 }) => {
881 *readable_type =
882 Self::create_readable_type(registry, type_scheme, type_annotation, true);
883 }
884 Statement::DefineFunction {
885 type_parameters,
886 parameters,
887 local_variables,
888 fn_type,
889 return_type_annotation,
890 readable_return_type,
891 ..
892 } => {
893 let (fn_type, _) =
894 fn_type.instantiate_for_printing(Some(type_parameters.iter().map(|(n, _)| *n)));
895
896 for DefineVariable {
897 type_annotation,
898 type_scheme,
899 readable_type,
900 ..
901 } in local_variables
902 {
903 *readable_type =
904 Self::create_readable_type(registry, type_scheme, type_annotation, false);
905 }
906
907 let Type::Fn(parameter_types, return_type) = fn_type.inner else {
908 unreachable!("Expected a function type")
909 };
910
911 *readable_return_type = Self::create_readable_type(
912 registry,
913 &TypeScheme::concrete(*return_type),
914 return_type_annotation,
915 false,
916 );
917
918 for ((_, _, type_annotation, readable_parameter_type), parameter_type) in
919 parameters.iter_mut().zip(parameter_types.iter())
920 {
921 *readable_parameter_type = Self::create_readable_type(
922 registry,
923 &TypeScheme::concrete(parameter_type.clone()),
924 type_annotation,
925 false,
926 );
927 }
928 }
929 Statement::DefineDimension(_, _) => {}
930 Statement::DefineBaseUnit { .. } => {}
931 Statement::DefineDerivedUnit {
932 type_annotation,
933 type_scheme,
934 readable_type,
935 ..
936 } => {
937 *readable_type =
938 Self::create_readable_type(registry, type_scheme, type_annotation, false);
939 }
940 Statement::ProcedureCall { .. } => {}
941 Statement::DefineStruct(_) => {}
942 }
943 }
944
945 pub(crate) fn exponents_for(&mut self, tv: &TypeVariable) -> Vec<Exponent> {
946 let mut exponents = vec![];
948 self.for_all_type_schemes(&mut |type_: &mut TypeScheme| {
949 if let Type::Dimension(dtype) = type_.unsafe_as_concrete() {
950 for (factor, exp) in dtype.factors.iter() {
951 if factor == &DTypeFactor::TVar(tv.clone()) {
952 exponents.push(*exp)
953 }
954 }
955 }
956 });
957 exponents
958 }
959
960 pub(crate) fn find_typed_hole(
961 &self,
962 ) -> Result<Option<(Span, TypeScheme)>, Box<TypeCheckError>> {
963 let mut hole = None;
964 let mut found_multiple_holes = false;
965 self.for_all_expressions(&mut |expr| {
966 if let Expression::TypedHole(span, type_) = expr {
967 if hole.is_some() {
968 found_multiple_holes = true;
969 }
970 hole = Some((*span, type_.clone()))
971 }
972 });
973
974 if found_multiple_holes {
975 Err(Box::new(TypeCheckError::MultipleTypedHoles(
976 hole.unwrap().0,
977 )))
978 } else {
979 Ok(hole)
980 }
981 }
982
983 pub(crate) fn local_bindings(&self) -> Vec<(&str, TypeScheme)> {
986 match self {
987 Statement::DefineFunction {
988 parameters,
989 local_variables,
990 fn_type,
991 ..
992 } => {
993 let mut bindings = Vec::new();
994
995 if let TypeScheme::Concrete(Type::Fn(param_types, _))
996 | TypeScheme::Quantified(
997 _,
998 crate::typechecker::qualified_type::QualifiedType {
999 inner: Type::Fn(param_types, _),
1000 ..
1001 },
1002 ) = fn_type
1003 {
1004 for ((_, param_name, _, _), param_type) in
1005 parameters.iter().zip(param_types.iter())
1006 {
1007 bindings
1008 .push((*param_name, TypeScheme::make_quantified(param_type.clone())));
1009 }
1010 }
1011
1012 for DefineVariable {
1013 name, type_scheme, ..
1014 } in local_variables
1015 {
1016 bindings.push((*name, type_scheme.clone()));
1017 }
1018
1019 bindings
1020 }
1021 _ => Vec::new(),
1022 }
1023 }
1024}
1025
1026impl Expression<'_> {
1027 pub fn get_type(&self) -> Type {
1028 match self {
1029 Expression::Scalar { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1030 Expression::Identifier { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1031 Expression::UnitIdentifier { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1032 Expression::UnaryOperator { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1033 Expression::BinaryOperator { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1034 Expression::BinaryOperatorForDate { type_scheme, .. } => {
1035 type_scheme.unsafe_as_concrete()
1036 }
1037 Expression::FunctionCall { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1038 Expression::CallableCall { type_scheme, .. } => type_scheme.unsafe_as_concrete(),
1039 Expression::Boolean(_, _) => Type::Boolean,
1040 Expression::Condition { then_expr, .. } => then_expr.get_type(),
1041 Expression::String(_, _) => Type::String,
1042 Expression::InstantiateStruct { struct_info, .. } => {
1043 Type::Struct(Box::new(struct_info.clone()))
1044 }
1045 Expression::AccessField { field_type, .. } => field_type.unsafe_as_concrete(),
1046 Expression::List { type_scheme, .. } => {
1047 Type::List(Box::new(type_scheme.unsafe_as_concrete()))
1048 }
1049 Expression::TypedHole(_, type_) => type_.unsafe_as_concrete(),
1050 }
1051 }
1052
1053 pub fn get_type_scheme(&self) -> TypeScheme {
1054 match self {
1055 Expression::Scalar { type_scheme, .. } => type_scheme.clone(),
1056 Expression::Identifier { type_scheme, .. } => type_scheme.clone(),
1057 Expression::UnitIdentifier { type_scheme, .. } => type_scheme.clone(),
1058 Expression::UnaryOperator { type_scheme, .. } => type_scheme.clone(),
1059 Expression::BinaryOperator { type_scheme, .. } => type_scheme.clone(),
1060 Expression::BinaryOperatorForDate { type_scheme, .. } => type_scheme.clone(),
1061 Expression::FunctionCall { type_scheme, .. } => type_scheme.clone(),
1062 Expression::CallableCall { type_scheme, .. } => type_scheme.clone(),
1063 Expression::Boolean(_, _) => TypeScheme::make_quantified(Type::Boolean),
1064 Expression::Condition { then_expr, .. } => then_expr.get_type_scheme(),
1065 Expression::String(_, _) => TypeScheme::make_quantified(Type::String),
1066 Expression::InstantiateStruct { struct_info, .. } => {
1067 TypeScheme::make_quantified(Type::Struct(Box::new(struct_info.clone())))
1068 }
1069 Expression::AccessField { field_type, .. } => field_type.clone(),
1070 Expression::List { type_scheme, .. } => match type_scheme {
1071 TypeScheme::Concrete(t) => TypeScheme::Concrete(Type::List(Box::new(t.clone()))),
1072 TypeScheme::Quantified(ngen, qt) => TypeScheme::Quantified(
1073 *ngen,
1074 crate::typechecker::qualified_type::QualifiedType {
1075 inner: Type::List(Box::new(qt.inner.clone())),
1076 bounds: qt.bounds.clone(),
1077 },
1078 ),
1079 },
1080 Expression::TypedHole(_, type_) => type_.clone(),
1081 }
1082 }
1083}
1084
1085fn accepts_prefix_markup(accepts_prefix: &Option<AcceptsPrefix>) -> Markup {
1086 if let Some(accepts_prefix) = accepts_prefix {
1087 m::operator(":")
1088 + m::space()
1089 + match accepts_prefix {
1090 AcceptsPrefix {
1091 short: true,
1092 long: true,
1093 } => m::keyword("both"),
1094 AcceptsPrefix {
1095 short: true,
1096 long: false,
1097 } => m::keyword("short"),
1098 AcceptsPrefix {
1099 short: false,
1100 long: true,
1101 } => m::keyword("long"),
1102 AcceptsPrefix {
1103 short: false,
1104 long: false,
1105 } => m::keyword("none"),
1106 }
1107 } else {
1108 m::empty()
1109 }
1110}
1111
1112fn decorator_markup(decorators: &Vec<Decorator>) -> Markup {
1113 let mut markup_decorators = m::empty();
1114 for decorator in decorators {
1115 markup_decorators = markup_decorators
1116 + match decorator {
1117 Decorator::MetricPrefixes => m::decorator("@metric_prefixes"),
1118 Decorator::BinaryPrefixes => m::decorator("@binary_prefixes"),
1119 Decorator::Abbreviation => m::decorator("@abbreviation"),
1120 Decorator::Aliases(names) => {
1121 m::decorator("@aliases")
1122 + m::operator("(")
1123 + Itertools::intersperse(
1124 names.iter().map(|(name, accepts_prefix, _)| {
1125 m::unit(name.to_compact_string())
1126 + accepts_prefix_markup(accepts_prefix)
1127 }),
1128 m::operator(", "),
1129 )
1130 .sum()
1131 + m::operator(")")
1132 }
1133 Decorator::Url(url) => {
1134 m::decorator("@url")
1135 + m::operator("(")
1136 + m::string(url.clone())
1137 + m::operator(")")
1138 }
1139 Decorator::Name(name) => {
1140 m::decorator("@name")
1141 + m::operator("(")
1142 + m::string(name.clone())
1143 + m::operator(")")
1144 }
1145 Decorator::Description(description) => {
1146 m::decorator("@description")
1147 + m::operator("(")
1148 + m::string(description.clone())
1149 + m::operator(")")
1150 }
1151 Decorator::Example(example_code, example_description) => {
1152 m::decorator("@example")
1153 + m::operator("(")
1154 + m::string(example_code.clone())
1155 + if let Some(example_description) = example_description {
1156 m::operator(", ") + m::string(example_description.clone())
1157 } else {
1158 m::empty()
1159 }
1160 + m::operator(")")
1161 }
1162 }
1163 + m::nl();
1164 }
1165 markup_decorators
1166}
1167
1168pub fn pretty_print_function_signature<'a>(
1169 function_name: &str,
1170 fn_type: &QualifiedType,
1171 type_parameters: &[TypeVariable],
1172 parameters: impl Iterator<
1173 Item = (
1174 &'a str, Markup, ),
1177 >,
1178 readable_return_type: &Markup,
1179) -> Markup {
1180 let markup_type_parameters = if type_parameters.is_empty() {
1181 m::empty()
1182 } else {
1183 m::operator("<")
1184 + Itertools::intersperse(
1185 type_parameters.iter().map(|tv| {
1186 m::type_identifier(tv.unsafe_name().to_compact_string())
1187 + if fn_type.bounds.is_dtype_bound(tv) {
1188 m::operator(":") + m::space() + m::type_identifier("Dim")
1189 } else {
1190 m::empty()
1191 }
1192 }),
1193 m::operator(", "),
1194 )
1195 .sum()
1196 + m::operator(">")
1197 };
1198
1199 let markup_parameters = Itertools::intersperse(
1200 parameters.map(|(name, parameter_type)| {
1201 m::identifier(name.to_compact_string()) + m::operator(":") + m::space() + parameter_type
1202 }),
1203 m::operator(", "),
1204 )
1205 .sum();
1206
1207 let markup_return_type =
1208 m::space() + m::operator("->") + m::space() + readable_return_type.clone();
1209
1210 m::keyword("fn")
1211 + m::space()
1212 + m::identifier(function_name.to_compact_string())
1213 + markup_type_parameters
1214 + m::operator("(")
1215 + markup_parameters
1216 + m::operator(")")
1217 + markup_return_type
1218}
1219
1220impl PrettyPrint for Statement<'_> {
1221 fn pretty_print(&self) -> Markup {
1222 match self {
1223 Statement::DefineVariable(DefineVariable {
1224 name,
1225 expr,
1226 readable_type,
1227 ..
1228 }) => {
1229 m::keyword("let")
1230 + m::space()
1231 + m::identifier(name.to_compact_string())
1232 + m::operator(":")
1233 + m::space()
1234 + readable_type.clone()
1235 + m::space()
1236 + m::operator("=")
1237 + m::space()
1238 + expr.pretty_print()
1239 }
1240 Statement::DefineFunction {
1241 function_name,
1242 type_parameters,
1243 parameters,
1244 body,
1245 local_variables,
1246 fn_type,
1247 readable_return_type,
1248 ..
1249 } => {
1250 let (fn_type, type_parameters) =
1251 fn_type.instantiate_for_printing(Some(type_parameters.iter().map(|(n, _)| *n)));
1252
1253 let mut pretty_local_variables = None;
1254 let mut first = true;
1255 if !local_variables.is_empty() {
1256 let mut plv = m::empty();
1257 for DefineVariable {
1258 name,
1259 expr,
1260 readable_type,
1261 ..
1262 } in local_variables
1263 {
1264 let introducer_keyword = if first {
1265 first = false;
1266 m::space() + m::space() + m::keyword("where")
1267 } else {
1268 m::space() + m::space() + m::space() + m::space() + m::keyword("and")
1269 };
1270
1271 plv += m::nl()
1272 + introducer_keyword
1273 + m::space()
1274 + m::identifier(name.to_compact_string())
1275 + m::operator(":")
1276 + m::space()
1277 + readable_type.clone()
1278 + m::space()
1279 + m::operator("=")
1280 + m::space()
1281 + expr.pretty_print();
1282 }
1283 pretty_local_variables = Some(plv);
1284 }
1285
1286 pretty_print_function_signature(
1287 function_name,
1288 &fn_type,
1289 &type_parameters,
1290 parameters
1291 .iter()
1292 .map(|(_, name, _, type_)| (*name, type_.clone())),
1293 readable_return_type,
1294 ) + body
1295 .as_ref()
1296 .map(|e| m::space() + m::operator("=") + m::space() + e.pretty_print())
1297 .unwrap_or_default()
1298 + pretty_local_variables.unwrap_or_default()
1299 }
1300 Statement::Expression(expr) => expr.pretty_print(),
1301 Statement::DefineDimension(identifier, dexprs) if dexprs.is_empty() => {
1302 m::keyword("dimension")
1303 + m::space()
1304 + m::type_identifier(identifier.to_compact_string())
1305 }
1306 Statement::DefineDimension(identifier, dexprs) => {
1307 m::keyword("dimension")
1308 + m::space()
1309 + m::type_identifier(identifier.to_compact_string())
1310 + m::space()
1311 + m::operator("=")
1312 + m::space()
1313 + Itertools::intersperse(
1314 dexprs.iter().map(|d| d.pretty_print()),
1315 m::space() + m::operator("=") + m::space(),
1316 )
1317 .sum()
1318 }
1319 Statement::DefineBaseUnit {
1320 name,
1321 decorators,
1322 type_annotation,
1323 type_scheme,
1324 ..
1325 } => {
1326 decorator_markup(decorators)
1327 + m::keyword("unit")
1328 + m::space()
1329 + m::unit(name.to_compact_string())
1330 + m::operator(":")
1331 + m::space()
1332 + type_annotation
1333 .as_ref()
1334 .map(|a: &TypeAnnotation| a.pretty_print())
1335 .unwrap_or(type_scheme.pretty_print())
1336 }
1337 Statement::DefineDerivedUnit {
1338 name,
1339 expr,
1340 decorators,
1341 readable_type,
1342 ..
1343 } => {
1344 decorator_markup(decorators)
1345 + m::keyword("unit")
1346 + m::space()
1347 + m::unit(name.to_compact_string())
1348 + m::operator(":")
1349 + m::space()
1350 + readable_type.clone()
1351 + m::space()
1352 + m::operator("=")
1353 + m::space()
1354 + expr.pretty_print()
1355 }
1356 Statement::ProcedureCall { kind, args, .. } => {
1357 let identifier = match kind {
1358 ProcedureKind::Print => "print",
1359 ProcedureKind::Assert => "assert",
1360 ProcedureKind::AssertEq => "assert_eq",
1361 ProcedureKind::Type => "type",
1362 };
1363 m::identifier(identifier)
1364 + m::operator("(")
1365 + Itertools::intersperse(
1366 args.iter().map(|a| a.pretty_print()),
1367 m::operator(",") + m::space(),
1368 )
1369 .sum()
1370 + m::operator(")")
1371 }
1372 Statement::DefineStruct(StructInfo { name, fields, .. }) => {
1373 m::keyword("struct")
1374 + m::space()
1375 + m::type_identifier(name.clone())
1376 + m::space()
1377 + m::operator("{")
1378 + if fields.is_empty() {
1379 m::empty()
1380 } else {
1381 m::space()
1382 + Itertools::intersperse(
1383 fields.iter().map(|(n, (_, t))| {
1384 m::identifier(n.clone())
1385 + m::operator(":")
1386 + m::space()
1387 + t.pretty_print()
1388 }),
1389 m::operator(",") + m::space(),
1390 )
1391 .sum()
1392 + m::space()
1393 }
1394 + m::operator("}")
1395 }
1396 }
1397 }
1398}
1399
1400fn pretty_scalar(n: Number) -> Markup {
1401 m::value(n.pretty_print())
1402}
1403
1404fn with_parens(expr: &Expression) -> Markup {
1405 match expr {
1406 Expression::Scalar { .. }
1407 | Expression::Identifier { .. }
1408 | Expression::UnitIdentifier { .. }
1409 | Expression::FunctionCall { .. }
1410 | Expression::CallableCall { .. }
1411 | Expression::Boolean(..)
1412 | Expression::String(..)
1413 | Expression::InstantiateStruct { .. }
1414 | Expression::AccessField { .. }
1415 | Expression::List { .. }
1416 | Expression::TypedHole(_, _) => expr.pretty_print(),
1417 Expression::UnaryOperator { .. }
1418 | Expression::BinaryOperator { .. }
1419 | Expression::BinaryOperatorForDate { .. }
1420 | Expression::Condition { .. } => m::operator("(") + expr.pretty_print() + m::operator(")"),
1421 }
1422}
1423
1424fn with_parens_liberal(expr: &Expression) -> Markup {
1426 match expr {
1427 Expression::BinaryOperator {
1428 op: BinaryOperator::Mul,
1429 lhs,
1430 rhs,
1431 ..
1432 } if matches!(**lhs, Expression::Scalar { .. })
1433 && matches!(**rhs, Expression::UnitIdentifier { .. }) =>
1434 {
1435 expr.pretty_print()
1436 }
1437 _ => with_parens(expr),
1438 }
1439}
1440
1441fn pretty_print_binop(op: &BinaryOperator, lhs: &Expression, rhs: &Expression) -> Markup {
1442 match op {
1443 BinaryOperator::ConvertTo => {
1444 lhs.pretty_print() + op.pretty_print() + rhs.pretty_print()
1446 }
1447 BinaryOperator::Mul => match (lhs, rhs) {
1448 (
1449 Expression::Scalar { value: s, .. },
1450 Expression::UnitIdentifier {
1451 prefix, full_name, ..
1452 },
1453 ) => {
1454 pretty_scalar(*s)
1456 + m::space()
1457 + m::unit(format_compact!("{}{}", prefix.as_string_long(), full_name))
1458 }
1459 (Expression::Scalar { value: s, .. }, Expression::Identifier { name, .. }) => {
1460 pretty_scalar(*s) + m::space() + m::identifier(name.to_compact_string())
1462 }
1463 _ => {
1464 let add_parens_if_needed = |expr: &Expression| {
1465 if matches!(
1466 expr,
1467 Expression::BinaryOperator {
1468 op: BinaryOperator::Power,
1469 ..
1470 } | Expression::BinaryOperator {
1471 op: BinaryOperator::Mul,
1472 ..
1473 }
1474 ) {
1475 expr.pretty_print()
1476 } else {
1477 with_parens_liberal(expr)
1478 }
1479 };
1480
1481 add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
1482 }
1483 },
1484 BinaryOperator::Div => {
1485 let lhs_add_parens_if_needed = |expr: &Expression| {
1486 if matches!(
1487 expr,
1488 Expression::BinaryOperator {
1489 op: BinaryOperator::Power,
1490 ..
1491 } | Expression::BinaryOperator {
1492 op: BinaryOperator::Mul,
1493 ..
1494 }
1495 ) {
1496 expr.pretty_print()
1497 } else {
1498 with_parens_liberal(expr)
1499 }
1500 };
1501 let rhs_add_parens_if_needed = |expr: &Expression| {
1502 if matches!(
1503 expr,
1504 Expression::BinaryOperator {
1505 op: BinaryOperator::Power,
1506 ..
1507 }
1508 ) {
1509 expr.pretty_print()
1510 } else {
1511 with_parens_liberal(expr)
1512 }
1513 };
1514
1515 lhs_add_parens_if_needed(lhs) + op.pretty_print() + rhs_add_parens_if_needed(rhs)
1516 }
1517 BinaryOperator::Add => {
1518 let add_parens_if_needed = |expr: &Expression| {
1519 if matches!(
1520 expr,
1521 Expression::BinaryOperator {
1522 op: BinaryOperator::Power,
1523 ..
1524 } | Expression::BinaryOperator {
1525 op: BinaryOperator::Mul,
1526 ..
1527 } | Expression::BinaryOperator {
1528 op: BinaryOperator::Add,
1529 ..
1530 }
1531 ) {
1532 expr.pretty_print()
1533 } else {
1534 with_parens_liberal(expr)
1535 }
1536 };
1537
1538 add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
1539 }
1540 BinaryOperator::Sub => {
1541 let add_parens_if_needed = |expr: &Expression| {
1542 if matches!(
1543 expr,
1544 Expression::BinaryOperator {
1545 op: BinaryOperator::Power,
1546 ..
1547 } | Expression::BinaryOperator {
1548 op: BinaryOperator::Mul,
1549 ..
1550 }
1551 ) {
1552 expr.pretty_print()
1553 } else {
1554 with_parens_liberal(expr)
1555 }
1556 };
1557
1558 add_parens_if_needed(lhs) + op.pretty_print() + add_parens_if_needed(rhs)
1559 }
1560 BinaryOperator::Power if matches!(rhs, Expression::Scalar { value, .. } if value.to_f64() == 2.0) => {
1561 with_parens(lhs) + m::operator("²")
1562 }
1563 BinaryOperator::Power if matches!(rhs, Expression::Scalar { value, .. } if value.to_f64() == 3.0) => {
1564 with_parens(lhs) + m::operator("³")
1565 }
1566 _ => with_parens(lhs) + op.pretty_print() + with_parens(rhs),
1567 }
1568}
1569
1570impl PrettyPrint for Expression<'_> {
1571 fn pretty_print(&self) -> Markup {
1572 use Expression::*;
1573
1574 match self {
1575 Scalar { value, .. } => pretty_scalar(*value),
1576 Identifier { name, .. } => m::identifier(name.to_compact_string()),
1577 UnitIdentifier {
1578 prefix, full_name, ..
1579 } => m::unit(format_compact!("{}{}", prefix.as_string_long(), full_name)),
1580 UnaryOperator {
1581 op: self::UnaryOperator::Negate,
1582 expr,
1583 ..
1584 } => m::operator("-") + with_parens(expr),
1585 UnaryOperator {
1586 op: self::UnaryOperator::Factorial(order),
1587 expr,
1588 ..
1589 } => with_parens(expr) + (0..order.get()).map(|_| m::operator("!")).sum(),
1590 UnaryOperator {
1591 op: self::UnaryOperator::LogicalNeg,
1592 expr,
1593 ..
1594 } => m::operator("!") + with_parens(expr),
1595 BinaryOperator { op, lhs, rhs, .. } => pretty_print_binop(op, lhs, rhs),
1596 BinaryOperatorForDate { op, lhs, rhs, .. } => pretty_print_binop(op, lhs, rhs),
1597 FunctionCall { name, args, .. } => {
1598 if args.len() == 1 {
1600 if *name == "from_celsius" {
1602 return with_parens_liberal(&args[0]) + m::space() + m::unit("°C");
1603 } else if *name == "from_fahrenheit" {
1604 return with_parens_liberal(&args[0]) + m::space() + m::unit("°F");
1605 }
1606 if *name == "°C" || *name == "celsius" || *name == "degree_celsius" {
1609 return with_parens_liberal(&args[0])
1610 + m::space()
1611 + m::operator("->")
1612 + m::space()
1613 + m::unit("°C");
1614 } else if *name == "°F" || *name == "fahrenheit" || *name == "degree_fahrenheit"
1615 {
1616 return with_parens_liberal(&args[0])
1617 + m::space()
1618 + m::operator("->")
1619 + m::space()
1620 + m::unit("°F");
1621 }
1622 }
1623
1624 m::identifier(name.to_compact_string())
1625 + m::operator("(")
1626 + itertools::Itertools::intersperse(
1627 args.iter().map(|e: &Expression| e.pretty_print()),
1628 m::operator(",") + m::space(),
1629 )
1630 .sum()
1631 + m::operator(")")
1632 }
1633 CallableCall {
1634 callable: expr,
1635 args,
1636 ..
1637 } => {
1638 if args.len() == 1
1640 && let Expression::Identifier { name, .. } = expr.as_ref()
1641 {
1642 if *name == "°C" || *name == "celsius" || *name == "degree_celsius" {
1643 return with_parens_liberal(&args[0])
1644 + m::space()
1645 + m::operator("->")
1646 + m::space()
1647 + m::unit("°C");
1648 } else if *name == "°F" || *name == "fahrenheit" || *name == "degree_fahrenheit"
1649 {
1650 return with_parens_liberal(&args[0])
1651 + m::space()
1652 + m::operator("->")
1653 + m::space()
1654 + m::unit("°F");
1655 }
1656 }
1657
1658 expr.pretty_print()
1659 + m::operator("(")
1660 + itertools::Itertools::intersperse(
1661 args.iter().map(|e: &Expression| e.pretty_print()),
1662 m::operator(",") + m::space(),
1663 )
1664 .sum()
1665 + m::operator(")")
1666 }
1667 Boolean(_, val) => val.pretty_print(),
1668 String(_, parts) => parts.pretty_print(),
1669 Condition {
1670 condition,
1671 then_expr,
1672 else_expr,
1673 ..
1674 } => {
1675 m::keyword("if")
1676 + m::space()
1677 + with_parens(condition)
1678 + m::space()
1679 + m::keyword("then")
1680 + m::space()
1681 + with_parens(then_expr)
1682 + m::space()
1683 + m::keyword("else")
1684 + m::space()
1685 + with_parens(else_expr)
1686 }
1687 InstantiateStruct {
1688 fields,
1689 struct_info,
1690 ..
1691 } => {
1692 m::type_identifier(struct_info.name.clone())
1693 + m::space()
1694 + m::operator("{")
1695 + if fields.is_empty() {
1696 m::empty()
1697 } else {
1698 m::space()
1699 + itertools::Itertools::intersperse(
1700 fields.iter().map(|(n, e)| {
1701 m::identifier(n.to_compact_string())
1702 + m::operator(":")
1703 + m::space()
1704 + e.pretty_print()
1705 }),
1706 m::operator(",") + m::space(),
1707 )
1708 .sum()
1709 + m::space()
1710 }
1711 + m::operator("}")
1712 }
1713 AccessField {
1714 expr, field_name, ..
1715 } => {
1716 expr.pretty_print()
1717 + m::operator(".")
1718 + m::identifier(field_name.to_compact_string())
1719 }
1720 List { elements, .. } => {
1721 m::operator("[")
1722 + itertools::Itertools::intersperse(
1723 elements.iter().map(|e| e.pretty_print()),
1724 m::operator(",") + m::space(),
1725 )
1726 .sum()
1727 + m::operator("]")
1728 }
1729 TypedHole(_, _) => m::operator("?"),
1730 }
1731 }
1732}
1733
// Tests for the typed-AST pretty printer: exact-output checks plus round-trip
// checks (parse → pretty-print → parse again → compare ASTs).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ReplaceSpans;
    use crate::markup::{Formatter, PlainTextFormatter};
    use crate::prefix_transformer::Transformer;

    /// Parses `code` after a fixed prelude of dimensions, units, functions and
    /// variables, runs the prefix transformer and the type checker, and returns
    /// the last (i.e. `code`'s final) typed statement.
    ///
    /// Panics if parsing, transformation or type checking fails.
    fn parse(code: &str) -> Statement<'_> {
        // The prelude string below is a fixture: every identifier used by the
        // tests (units, functions, `Foo`, the single-letter variables, …) is
        // defined here.
        let statements = crate::parser::parse(
            "dimension Scalar = 1
             dimension Length
             dimension Time
             dimension Mass

             fn sin(x: Scalar) -> Scalar
             fn cos(x: Scalar) -> Scalar
             fn asin(x: Scalar) -> Scalar
             fn atan(x: Scalar) -> Scalar
             fn atan2<T>(x: T, y: T) -> Scalar
             fn sqrt(x) = x^(1/2)
             let pi = 2 asin(1)

             @aliases(m: short)
             @metric_prefixes
             unit meter: Length

             @aliases(s: short)
             @metric_prefixes
             unit second: Time

             @aliases(g: short)
             @metric_prefixes
             unit gram: Mass

             @aliases(rad: short)
             @metric_prefixes
             unit radian: Scalar = meter / meter

             @aliases(°: none)
             unit degree = 180/pi × radian

             @aliases(in: short)
             unit inch = 0.0254 m

             @metric_prefixes
             unit points

             struct Foo {foo: Length, bar: Time}

             let a = 1
             let b = 1
             let c = 1
             let d = 1
             let e = 1
             let f = 1
             let x = 1
             let r = 2 m
             let vol = 3 m^3
             let density = 1000 kg / m^3
             let länge = 1
             let x_2 = 1
             let µ = 1
             let _prefixed = 1",
            0,
        )
        .unwrap()
        .into_iter()
        .chain(crate::parser::parse(code, 0).unwrap());

        let mut transformer = Transformer::new();
        // NOTE(review): `replace_spans()` presumably normalizes source spans so
        // that ASTs parsed from different texts can be compared with `==` in
        // `roundtrip_check` — confirm against `ReplaceSpans`.
        let transformed_statements = transformer.transform(statements).unwrap().replace_spans();

        crate::typechecker::TypeChecker::default()
            .check(&transformed_statements)
            .unwrap()
            .last()
            .unwrap()
            .clone()
    }

    /// Pretty-prints a statement and formats it as plain text.
    fn pretty_print(stmt: &Statement) -> CompactString {
        let markup = stmt.pretty_print();

        (PlainTextFormatter {}).format(&markup, false)
    }

    /// Asserts that `input`, once parsed and type-checked, pretty-prints as
    /// exactly `expected`. Both strings are printed to help diagnose failures.
    fn equal_pretty(input: &str, expected: &str) {
        println!();
        println!("expected: '{expected}'");
        let actual = pretty_print(&parse(input));
        println!("actual: '{actual}'");
        assert_eq!(actual, expected);
    }

    #[test]
    fn pretty_print_basic() {
        equal_pretty("2+3", "2 + 3");
        equal_pretty("2*3", "2 × 3");
        equal_pretty("2^3", "2³");
        equal_pretty("2km", "2 kilometer");
        equal_pretty("2kilometer", "2 kilometer");
        equal_pretty("sin(30°)", "sin(30 degree)");
        equal_pretty("2*3*4", "2 × 3 × 4");
        equal_pretty("2*(3*4)", "2 × 3 × 4");
        equal_pretty("2+3+4", "2 + 3 + 4");
        equal_pretty("2+(3+4)", "2 + 3 + 4");
        equal_pretty("atan(30cm / 2m)", "atan(30 centimeter / 2 meter)");
        equal_pretty("1mrad -> °", "1 milliradian ➞ degree");
        equal_pretty("2km+2cm -> in", "2 kilometer + 2 centimeter ➞ inch");
        equal_pretty("2^3 + 4^5", "2³ + 4^5");
        equal_pretty("2^3 - 4^5", "2³ - 4^5");
        equal_pretty("2^3 * 4^5", "2³ × 4^5");
        equal_pretty("2 * 3 + 4 * 5", "2 × 3 + 4 × 5");
        equal_pretty("2 * 3 / 4", "2 × 3 / 4");
        equal_pretty("123.123 km² / s²", "123.123 × kilometer² / second²");
    }

    /// Asserts that pretty-printing `code` and re-parsing the output yields the
    /// same typed AST as parsing `code` directly.
    fn roundtrip_check(code: &str) {
        println!("Roundtrip check for code = '{code}'");
        let ast1 = parse(code);
        let code_pretty = pretty_print(&ast1);
        println!(" pretty printed code = '{code_pretty}'");
        let ast2 = parse(&code_pretty);
        assert_eq!(ast1, ast2);
    }

    #[test]
    fn pretty_print_roundtrip_check() {
        roundtrip_check("1.0");
        roundtrip_check("2");
        roundtrip_check("1 + 2");

        roundtrip_check("-2.3e-12387");
        roundtrip_check("2.3e-12387");
        roundtrip_check("18379173");
        roundtrip_check("2+3");
        roundtrip_check("2+3*5");
        roundtrip_check("-3^4+2/(4+2*3)");
        roundtrip_check("1-2-3-4-(5-6-7)");
        roundtrip_check("1/2/3/4/(5/6/7)");
        roundtrip_check("kilogram");
        roundtrip_check("2meter/second");
        roundtrip_check("a+b*c^d-e*f");
        roundtrip_check("sin(x)^3");
        roundtrip_check("sin(cos(atan(x)+2))^3");
        roundtrip_check("2^3^4^5");
        roundtrip_check("(2^3)^(4^5)");
        roundtrip_check("sqrt(1.4^2 + 1.5^2) * cos(pi/3)^2");
        roundtrip_check("40 kilometer * 9.8meter/second^2 * 150centimeter");
        roundtrip_check("4/3 * pi * r³");
        roundtrip_check("vol * density -> kilogram");
        roundtrip_check("atan(30 centimeter / 2 meter)");
        roundtrip_check("500kilometer/second -> centimeter/second");
        roundtrip_check("länge * x_2 * µ * _prefixed");
        roundtrip_check("2meter^3");
        roundtrip_check("(2meter)^3");
        roundtrip_check("-sqrt(-30meter^3)");
        roundtrip_check("-3^4");
        roundtrip_check("(-3)^4");
        roundtrip_check("atan2(2,3)");
        roundtrip_check("2^3!");
        roundtrip_check("-3!");
        roundtrip_check("(-3)!");
        roundtrip_check("megapoints");
        roundtrip_check("Foo { foo: 1 meter, bar: 1 second }");
        roundtrip_check("\"foo\"");
        roundtrip_check("\"newline: \\n\"");
    }

    // Round-trips of dimension expressions in base-unit type annotations.
    #[test]
    fn pretty_print_dexpr() {
        roundtrip_check("unit z: Length");
        roundtrip_check("unit z: Length * Time");
        roundtrip_check("unit z: Length * Time^2");
        roundtrip_check("unit z: Length^-3 * Time^2");
        roundtrip_check("unit z: Length / Time");
        roundtrip_check("unit z: Length / Time^2");
        roundtrip_check("unit z: Length / Time^(-2)");
        roundtrip_check("unit z: Length / (Time * Mass)");
        roundtrip_check("unit z: Length^5 * Time^4 / (Time^2 * Mass^3)");
    }
}