1#[cfg(feature = "tolerant-ast")]
18use {
19 super::expr_allows_errors::AstExprErrorKind, crate::parser::err::ToASTError,
20 crate::parser::err::ToASTErrorKind,
21};
22
23use crate::{
24 ast::*,
25 expr_builder::{self, ExprBuilder as _},
26 extensions::Extensions,
27 parser::{err::ParseErrors, Loc},
28};
29use educe::Educe;
30use miette::Diagnostic;
31use serde::{Deserialize, Serialize};
32use smol_str::SmolStr;
33use std::{
34 borrow::Cow,
35 collections::{btree_map, BTreeMap, HashMap},
36 hash::{Hash, Hasher},
37 mem,
38 sync::Arc,
39};
40use thiserror::Error;
41
42#[cfg(feature = "wasm")]
43extern crate tsify;
44
/// Internal AST representation of an expression.
///
/// `T` is an arbitrary per-node payload (`()` by default). This module only
/// stores, clones and hands it back; it never interprets it.
///
/// `PartialEq`, `Eq` and `Hash` are generated by `educe` so that they consider
/// `expr_kind` and `data` but deliberately ignore `source_loc`: two
/// expressions that differ only in where they came from compare equal.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash)]
pub struct Expr<T = ()> {
    /// What kind of expression this node is, together with its children.
    expr_kind: ExprKind<T>,
    /// Where this node came from, if known. Ignored by equality and hashing.
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    source_loc: Option<Loc>,
    /// Arbitrary payload attached to this node.
    data: T,
}
60
/// The different kinds of expression node. Children are held behind `Arc`,
/// so cloning an expression shares its subtrees.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub enum ExprKind<T = ()> {
    /// A literal value.
    Lit(Literal),
    /// A reference to a built-in variable.
    Var(Var),
    /// A slot placeholder, identified by its `SlotId`.
    Slot(SlotId),
    /// A named unknown; `Expr::substitute` may replace it with a value.
    Unknown(Unknown),
    /// `if test_expr then then_expr else else_expr`.
    If {
        /// The condition.
        test_expr: Arc<Expr<T>>,
        /// Result when the condition holds.
        then_expr: Arc<Expr<T>>,
        /// Result when the condition does not hold.
        else_expr: Arc<Expr<T>>,
    },
    /// Boolean conjunction.
    And {
        /// Left operand.
        left: Arc<Expr<T>>,
        /// Right operand.
        right: Arc<Expr<T>>,
    },
    /// Boolean disjunction.
    Or {
        /// Left operand.
        left: Arc<Expr<T>>,
        /// Right operand.
        right: Arc<Expr<T>>,
    },
    /// Application of a unary operator.
    UnaryApp {
        /// The operator.
        op: UnaryOp,
        /// The operand.
        arg: Arc<Expr<T>>,
    },
    /// Application of a binary operator.
    BinaryApp {
        /// The operator.
        op: BinaryOp,
        /// First operand.
        arg1: Arc<Expr<T>>,
        /// Second operand.
        arg2: Arc<Expr<T>>,
    },
    /// A call to the extension function named `fn_name`.
    ExtensionFunctionApp {
        /// Name of the extension function (looked up in `Extensions`).
        fn_name: Name,
        /// Arguments, in call order.
        args: Arc<Vec<Expr<T>>>,
    },
    /// Reads attribute `attr` of `expr`.
    GetAttr {
        /// Expression whose attribute is read.
        expr: Arc<Expr<T>>,
        /// Attribute name.
        attr: SmolStr,
    },
    /// Tests whether `expr` has attribute `attr`.
    HasAttr {
        /// Expression that is tested.
        expr: Arc<Expr<T>>,
        /// Attribute name.
        attr: SmolStr,
    },
    /// Matches `expr` against `pattern`.
    Like {
        /// Expression that is matched.
        expr: Arc<Expr<T>>,
        /// The pattern to match against.
        pattern: Pattern,
    },
    /// Tests whether `expr` is of entity type `entity_type`.
    Is {
        /// Expression that is tested.
        expr: Arc<Expr<T>>,
        /// The entity type to test for.
        entity_type: EntityType,
    },
    /// A set literal; elements are kept in the order they were given.
    Set(Arc<Vec<Expr<T>>>),
    /// A record literal; the `BTreeMap` keeps keys unique and sorted.
    Record(Arc<BTreeMap<SmolStr, Expr<T>>>),
    /// Placeholder node standing in for an expression that failed to parse.
    #[cfg(feature = "tolerant-ast")]
    Error {
        /// What kind of error this node represents.
        error_kind: AstExprErrorKind,
    },
}
171
172impl<T> ExprKind<T> {
173 fn variant_order(&self) -> u8 {
175 match self {
176 ExprKind::Lit(_) => 0,
177 ExprKind::Var(_) => 1,
178 ExprKind::Slot(_) => 2,
179 ExprKind::Unknown(_) => 3,
180 ExprKind::If { .. } => 4,
181 ExprKind::And { .. } => 5,
182 ExprKind::Or { .. } => 6,
183 ExprKind::UnaryApp { .. } => 7,
184 ExprKind::BinaryApp { .. } => 8,
185 ExprKind::ExtensionFunctionApp { .. } => 9,
186 ExprKind::GetAttr { .. } => 10,
187 ExprKind::HasAttr { .. } => 11,
188 ExprKind::Like { .. } => 12,
189 ExprKind::Set(_) => 13,
190 ExprKind::Record(_) => 14,
191 ExprKind::Is { .. } => 15,
192 #[cfg(feature = "tolerant-ast")]
193 ExprKind::Error { .. } => 16,
194 }
195 }
196}
197
198impl From<Value> for Expr {
199 fn from(v: Value) -> Self {
200 Expr::from(v.value).with_maybe_source_loc(v.loc)
201 }
202}
203
204impl From<ValueKind> for Expr {
205 fn from(v: ValueKind) -> Self {
206 match v {
207 ValueKind::Lit(lit) => Expr::val(lit),
208 ValueKind::Set(set) => Expr::set(set.iter().map(|v| Expr::from(v.clone()))),
209 #[allow(clippy::expect_used)]
211 ValueKind::Record(record) => Expr::record(
212 Arc::unwrap_or_clone(record)
213 .into_iter()
214 .map(|(k, v)| (k, Expr::from(v))),
215 )
216 .expect("cannot have duplicate key because the input was already a BTreeMap"),
217 ValueKind::ExtensionValue(ev) => RestrictedExpr::from(ev.as_ref().clone()).into(),
218 }
219 }
220}
221
222impl From<PartialValue> for Expr {
223 fn from(pv: PartialValue) -> Self {
224 match pv {
225 PartialValue::Value(v) => Expr::from(v),
226 PartialValue::Residual(expr) => expr,
227 }
228 }
229}
230
231impl<T> Expr<T> {
232 pub(crate) fn new(expr_kind: ExprKind<T>, source_loc: Option<Loc>, data: T) -> Self {
233 Self {
234 expr_kind,
235 source_loc,
236 data,
237 }
238 }
239
240 pub fn expr_kind(&self) -> &ExprKind<T> {
244 &self.expr_kind
245 }
246
247 pub fn into_expr_kind(self) -> ExprKind<T> {
249 self.expr_kind
250 }
251
252 pub fn data(&self) -> &T {
254 &self.data
255 }
256
257 pub fn into_data(self) -> T {
260 self.data
261 }
262
263 pub fn into_parts(self) -> (ExprKind<T>, Option<Loc>, T) {
266 (self.expr_kind, self.source_loc, self.data)
267 }
268
269 pub fn source_loc(&self) -> Option<&Loc> {
271 self.source_loc.as_ref()
272 }
273
274 pub fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
276 Self { source_loc, ..self }
277 }
278
279 pub fn set_data(&mut self, data: T) {
282 self.data = data;
283 }
284
285 pub fn is_ref(&self) -> bool {
290 match &self.expr_kind {
291 ExprKind::Lit(lit) => lit.is_ref(),
292 _ => false,
293 }
294 }
295
296 pub fn is_slot(&self) -> bool {
298 matches!(&self.expr_kind, ExprKind::Slot(_))
299 }
300
301 pub fn is_ref_set(&self) -> bool {
306 match &self.expr_kind {
307 ExprKind::Set(exprs) => exprs.iter().all(|e| e.is_ref()),
308 _ => false,
309 }
310 }
311
312 pub fn subexpressions(&self) -> impl Iterator<Item = &Self> {
314 expr_iterator::ExprIterator::new(self)
315 }
316
317 pub fn slots(&self) -> impl Iterator<Item = Slot> + '_ {
319 self.subexpressions()
320 .filter_map(|exp| match &exp.expr_kind {
321 ExprKind::Slot(slotid) => Some(Slot {
322 id: *slotid,
323 loc: exp.source_loc().cloned(),
324 }),
325 _ => None,
326 })
327 }
328
329 pub fn is_projectable(&self) -> bool {
333 self.subexpressions().all(|e| {
334 matches!(
335 e.expr_kind(),
336 ExprKind::Lit(_)
337 | ExprKind::Unknown(_)
338 | ExprKind::Set(_)
339 | ExprKind::Var(_)
340 | ExprKind::Record(_)
341 )
342 })
343 }
344
345 pub fn try_type_of(&self, extensions: &Extensions<'_>) -> Option<Type> {
357 match &self.expr_kind {
358 ExprKind::Lit(l) => Some(l.type_of()),
359 ExprKind::Var(_) => None,
360 ExprKind::Slot(_) => None,
361 ExprKind::Unknown(u) => u.type_annotation.clone(),
362 ExprKind::If {
363 then_expr,
364 else_expr,
365 ..
366 } => {
367 let type_of_then = then_expr.try_type_of(extensions);
368 let type_of_else = else_expr.try_type_of(extensions);
369 if type_of_then == type_of_else {
370 type_of_then
371 } else {
372 None
373 }
374 }
375 ExprKind::And { .. } => Some(Type::Bool),
376 ExprKind::Or { .. } => Some(Type::Bool),
377 ExprKind::UnaryApp {
378 op: UnaryOp::Neg, ..
379 } => Some(Type::Long),
380 ExprKind::UnaryApp {
381 op: UnaryOp::Not, ..
382 } => Some(Type::Bool),
383 ExprKind::UnaryApp {
384 op: UnaryOp::IsEmpty,
385 ..
386 } => Some(Type::Bool),
387 ExprKind::BinaryApp {
388 op: BinaryOp::Add | BinaryOp::Mul | BinaryOp::Sub,
389 ..
390 } => Some(Type::Long),
391 ExprKind::BinaryApp {
392 op:
393 BinaryOp::Contains
394 | BinaryOp::ContainsAll
395 | BinaryOp::ContainsAny
396 | BinaryOp::Eq
397 | BinaryOp::In
398 | BinaryOp::Less
399 | BinaryOp::LessEq,
400 ..
401 } => Some(Type::Bool),
402 ExprKind::BinaryApp {
403 op: BinaryOp::HasTag,
404 ..
405 } => Some(Type::Bool),
406 ExprKind::ExtensionFunctionApp { fn_name, .. } => extensions
407 .func(fn_name)
408 .ok()?
409 .return_type()
410 .map(|rty| rty.clone().into()),
411 ExprKind::GetAttr { .. } => None,
416 ExprKind::BinaryApp {
418 op: BinaryOp::GetTag,
419 ..
420 } => None,
421 ExprKind::HasAttr { .. } => Some(Type::Bool),
422 ExprKind::Like { .. } => Some(Type::Bool),
423 ExprKind::Is { .. } => Some(Type::Bool),
424 ExprKind::Set(_) => Some(Type::Set),
425 ExprKind::Record(_) => Some(Type::Record),
426 #[cfg(feature = "tolerant-ast")]
427 ExprKind::Error { .. } => None,
428 }
429 }
430
431 pub fn into_expr<B: expr_builder::ExprBuilder>(self) -> B::Expr
436 where
437 T: Clone,
438 {
439 let builder = B::new().with_maybe_source_loc(self.source_loc());
440 match self.into_expr_kind() {
441 ExprKind::Lit(lit) => builder.val(lit),
442 ExprKind::Var(var) => builder.var(var),
443 ExprKind::Slot(slot) => builder.slot(slot),
444 ExprKind::Unknown(u) => builder.unknown(u),
445 ExprKind::If {
446 test_expr,
447 then_expr,
448 else_expr,
449 } => builder.ite(
450 Arc::unwrap_or_clone(test_expr).into_expr::<B>(),
451 Arc::unwrap_or_clone(then_expr).into_expr::<B>(),
452 Arc::unwrap_or_clone(else_expr).into_expr::<B>(),
453 ),
454 ExprKind::And { left, right } => builder.and(
455 Arc::unwrap_or_clone(left).into_expr::<B>(),
456 Arc::unwrap_or_clone(right).into_expr::<B>(),
457 ),
458 ExprKind::Or { left, right } => builder.or(
459 Arc::unwrap_or_clone(left).into_expr::<B>(),
460 Arc::unwrap_or_clone(right).into_expr::<B>(),
461 ),
462 ExprKind::UnaryApp { op, arg } => {
463 let arg = Arc::unwrap_or_clone(arg).into_expr::<B>();
464 builder.unary_app(op, arg)
465 }
466 ExprKind::BinaryApp { op, arg1, arg2 } => {
467 let arg1 = Arc::unwrap_or_clone(arg1).into_expr::<B>();
468 let arg2 = Arc::unwrap_or_clone(arg2).into_expr::<B>();
469 builder.binary_app(op, arg1, arg2)
470 }
471 ExprKind::ExtensionFunctionApp { fn_name, args } => {
472 let args = Arc::unwrap_or_clone(args)
473 .into_iter()
474 .map(|e| e.into_expr::<B>());
475 builder.call_extension_fn(fn_name, args)
476 }
477 ExprKind::GetAttr { expr, attr } => {
478 builder.get_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
479 }
480 ExprKind::HasAttr { expr, attr } => {
481 builder.has_attr(Arc::unwrap_or_clone(expr).into_expr::<B>(), attr)
482 }
483 ExprKind::Like { expr, pattern } => {
484 builder.like(Arc::unwrap_or_clone(expr).into_expr::<B>(), pattern)
485 }
486 ExprKind::Is { expr, entity_type } => {
487 builder.is_entity_type(Arc::unwrap_or_clone(expr).into_expr::<B>(), entity_type)
488 }
489 ExprKind::Set(set) => builder.set(
490 Arc::unwrap_or_clone(set)
491 .into_iter()
492 .map(|e| e.into_expr::<B>()),
493 ),
494 #[allow(clippy::unwrap_used)]
496 ExprKind::Record(map) => builder
497 .record(
498 Arc::unwrap_or_clone(map)
499 .into_iter()
500 .map(|(k, v)| (k, v.into_expr::<B>())),
501 )
502 .unwrap(),
503 #[cfg(feature = "tolerant-ast")]
504 #[allow(clippy::unwrap_used)]
506 ExprKind::Error { .. } => builder
507 .error(ParseErrors::singleton(ToASTError::new(
508 ToASTErrorKind::ASTErrorNode,
509 Some(Loc::new(0..1, "AST_ERROR_NODE".into())),
510 )))
511 .unwrap(),
512 }
513 }
514}
515
516#[allow(dead_code)] #[allow(clippy::should_implement_trait)] impl Expr {
519 pub fn val(v: impl Into<Literal>) -> Self {
523 ExprBuilder::new().val(v)
524 }
525
526 pub fn unknown(u: Unknown) -> Self {
528 ExprBuilder::new().unknown(u)
529 }
530
531 pub fn var(v: Var) -> Self {
533 ExprBuilder::new().var(v)
534 }
535
536 pub fn slot(s: SlotId) -> Self {
538 ExprBuilder::new().slot(s)
539 }
540
541 pub fn ite(test_expr: Expr, then_expr: Expr, else_expr: Expr) -> Self {
545 ExprBuilder::new().ite(test_expr, then_expr, else_expr)
546 }
547
548 pub fn ite_arc(test_expr: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Self {
552 ExprBuilder::new().ite_arc(test_expr, then_expr, else_expr)
553 }
554
555 pub fn not(e: Expr) -> Self {
557 ExprBuilder::new().not(e)
558 }
559
560 pub fn is_eq(e1: Expr, e2: Expr) -> Self {
562 ExprBuilder::new().is_eq(e1, e2)
563 }
564
565 pub fn noteq(e1: Expr, e2: Expr) -> Self {
567 ExprBuilder::new().noteq(e1, e2)
568 }
569
570 pub fn and(e1: Expr, e2: Expr) -> Self {
572 ExprBuilder::new().and(e1, e2)
573 }
574
575 pub fn or(e1: Expr, e2: Expr) -> Self {
577 ExprBuilder::new().or(e1, e2)
578 }
579
580 pub fn less(e1: Expr, e2: Expr) -> Self {
582 ExprBuilder::new().less(e1, e2)
583 }
584
585 pub fn lesseq(e1: Expr, e2: Expr) -> Self {
587 ExprBuilder::new().lesseq(e1, e2)
588 }
589
590 pub fn greater(e1: Expr, e2: Expr) -> Self {
592 ExprBuilder::new().greater(e1, e2)
593 }
594
595 pub fn greatereq(e1: Expr, e2: Expr) -> Self {
597 ExprBuilder::new().greatereq(e1, e2)
598 }
599
600 pub fn add(e1: Expr, e2: Expr) -> Self {
602 ExprBuilder::new().add(e1, e2)
603 }
604
605 pub fn sub(e1: Expr, e2: Expr) -> Self {
607 ExprBuilder::new().sub(e1, e2)
608 }
609
610 pub fn mul(e1: Expr, e2: Expr) -> Self {
612 ExprBuilder::new().mul(e1, e2)
613 }
614
615 pub fn neg(e: Expr) -> Self {
617 ExprBuilder::new().neg(e)
618 }
619
620 pub fn is_in(e1: Expr, e2: Expr) -> Self {
624 ExprBuilder::new().is_in(e1, e2)
625 }
626
627 pub fn contains(e1: Expr, e2: Expr) -> Self {
630 ExprBuilder::new().contains(e1, e2)
631 }
632
633 pub fn contains_all(e1: Expr, e2: Expr) -> Self {
635 ExprBuilder::new().contains_all(e1, e2)
636 }
637
638 pub fn contains_any(e1: Expr, e2: Expr) -> Self {
640 ExprBuilder::new().contains_any(e1, e2)
641 }
642
643 pub fn is_empty(e: Expr) -> Self {
645 ExprBuilder::new().is_empty(e)
646 }
647
648 pub fn get_tag(expr: Expr, tag: Expr) -> Self {
651 ExprBuilder::new().get_tag(expr, tag)
652 }
653
654 pub fn has_tag(expr: Expr, tag: Expr) -> Self {
657 ExprBuilder::new().has_tag(expr, tag)
658 }
659
660 pub fn set(exprs: impl IntoIterator<Item = Expr>) -> Self {
662 ExprBuilder::new().set(exprs)
663 }
664
665 pub fn record(
667 pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
668 ) -> Result<Self, ExpressionConstructionError> {
669 ExprBuilder::new().record(pairs)
670 }
671
672 pub fn record_arc(map: Arc<BTreeMap<SmolStr, Expr>>) -> Self {
680 ExprBuilder::new().record_arc(map)
681 }
682
683 pub fn call_extension_fn(fn_name: Name, args: Vec<Expr>) -> Self {
686 ExprBuilder::new().call_extension_fn(fn_name, args)
687 }
688
689 pub fn unary_app(op: impl Into<UnaryOp>, arg: Expr) -> Self {
692 ExprBuilder::new().unary_app(op, arg)
693 }
694
695 pub fn binary_app(op: impl Into<BinaryOp>, arg1: Expr, arg2: Expr) -> Self {
698 ExprBuilder::new().binary_app(op, arg1, arg2)
699 }
700
701 pub fn get_attr(expr: Expr, attr: SmolStr) -> Self {
705 ExprBuilder::new().get_attr(expr, attr)
706 }
707
708 pub fn has_attr(expr: Expr, attr: SmolStr) -> Self {
713 ExprBuilder::new().has_attr(expr, attr)
714 }
715
716 pub fn like(expr: Expr, pattern: Pattern) -> Self {
720 ExprBuilder::new().like(expr, pattern)
721 }
722
723 pub fn is_entity_type(expr: Expr, entity_type: EntityType) -> Self {
725 ExprBuilder::new().is_entity_type(expr, entity_type)
726 }
727
728 pub fn contains_unknown(&self) -> bool {
730 self.subexpressions()
731 .any(|e| matches!(e.expr_kind(), ExprKind::Unknown(_)))
732 }
733
734 pub fn unknowns(&self) -> impl Iterator<Item = &Unknown> {
736 self.subexpressions()
737 .filter_map(|subexpr| match subexpr.expr_kind() {
738 ExprKind::Unknown(u) => Some(u),
739 _ => None,
740 })
741 }
742
743 pub fn substitute(&self, definitions: &HashMap<SmolStr, Value>) -> Expr {
752 match self.substitute_general::<UntypedSubstitution>(definitions) {
753 Ok(e) => e,
754 Err(empty) => match empty {},
755 }
756 }
757
758 pub fn substitute_typed(
767 &self,
768 definitions: &HashMap<SmolStr, Value>,
769 ) -> Result<Expr, SubstitutionError> {
770 self.substitute_general::<TypedSubstitution>(definitions)
771 }
772
773 fn substitute_general<T: SubstitutionFunction>(
777 &self,
778 definitions: &HashMap<SmolStr, Value>,
779 ) -> Result<Expr, T::Err> {
780 match self.expr_kind() {
781 ExprKind::Lit(_) => Ok(self.clone()),
782 ExprKind::Unknown(u @ Unknown { name, .. }) => T::substitute(u, definitions.get(name)),
783 ExprKind::Var(_) => Ok(self.clone()),
784 ExprKind::Slot(_) => Ok(self.clone()),
785 ExprKind::If {
786 test_expr,
787 then_expr,
788 else_expr,
789 } => Ok(Expr::ite(
790 test_expr.substitute_general::<T>(definitions)?,
791 then_expr.substitute_general::<T>(definitions)?,
792 else_expr.substitute_general::<T>(definitions)?,
793 )),
794 ExprKind::And { left, right } => Ok(Expr::and(
795 left.substitute_general::<T>(definitions)?,
796 right.substitute_general::<T>(definitions)?,
797 )),
798 ExprKind::Or { left, right } => Ok(Expr::or(
799 left.substitute_general::<T>(definitions)?,
800 right.substitute_general::<T>(definitions)?,
801 )),
802 ExprKind::UnaryApp { op, arg } => Ok(Expr::unary_app(
803 *op,
804 arg.substitute_general::<T>(definitions)?,
805 )),
806 ExprKind::BinaryApp { op, arg1, arg2 } => Ok(Expr::binary_app(
807 *op,
808 arg1.substitute_general::<T>(definitions)?,
809 arg2.substitute_general::<T>(definitions)?,
810 )),
811 ExprKind::ExtensionFunctionApp { fn_name, args } => {
812 let args = args
813 .iter()
814 .map(|e| e.substitute_general::<T>(definitions))
815 .collect::<Result<Vec<Expr>, _>>()?;
816
817 Ok(Expr::call_extension_fn(fn_name.clone(), args))
818 }
819 ExprKind::GetAttr { expr, attr } => Ok(Expr::get_attr(
820 expr.substitute_general::<T>(definitions)?,
821 attr.clone(),
822 )),
823 ExprKind::HasAttr { expr, attr } => Ok(Expr::has_attr(
824 expr.substitute_general::<T>(definitions)?,
825 attr.clone(),
826 )),
827 ExprKind::Like { expr, pattern } => Ok(Expr::like(
828 expr.substitute_general::<T>(definitions)?,
829 pattern.clone(),
830 )),
831 ExprKind::Set(members) => {
832 let members = members
833 .iter()
834 .map(|e| e.substitute_general::<T>(definitions))
835 .collect::<Result<Vec<_>, _>>()?;
836 Ok(Expr::set(members))
837 }
838 ExprKind::Record(map) => {
839 let map = map
840 .iter()
841 .map(|(name, e)| Ok((name.clone(), e.substitute_general::<T>(definitions)?)))
842 .collect::<Result<BTreeMap<_, _>, _>>()?;
843 #[allow(clippy::expect_used)]
845 Ok(Expr::record(map)
846 .expect("cannot have a duplicate key because the input was already a BTreeMap"))
847 }
848 ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
849 expr.substitute_general::<T>(definitions)?,
850 entity_type.clone(),
851 )),
852 #[cfg(feature = "tolerant-ast")]
853 ExprKind::Error { .. } => Ok(self.clone()),
854 }
855 }
856}
857
858trait SubstitutionFunction {
860 type Err;
862 fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err>;
868}
869
870struct TypedSubstitution {}
871
872impl SubstitutionFunction for TypedSubstitution {
873 type Err = SubstitutionError;
874
875 fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
876 match (substitute, &value.type_annotation) {
877 (None, _) => Ok(Expr::unknown(value.clone())),
878 (Some(v), None) => Ok(v.clone().into()),
879 (Some(v), Some(t)) => {
880 if v.type_of() == *t {
881 Ok(v.clone().into())
882 } else {
883 Err(SubstitutionError::TypeError {
884 expected: t.clone(),
885 actual: v.type_of(),
886 })
887 }
888 }
889 }
890 }
891}
892
893struct UntypedSubstitution {}
894
895impl SubstitutionFunction for UntypedSubstitution {
896 type Err = std::convert::Infallible;
897
898 fn substitute(value: &Unknown, substitute: Option<&Value>) -> Result<Expr, Self::Err> {
899 Ok(substitute
900 .map(|v| v.clone().into())
901 .unwrap_or_else(|| Expr::unknown(value.clone())))
902 }
903}
904
905impl<T: Clone> std::fmt::Display for Expr<T> {
906 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
907 write!(f, "{}", &self.clone().into_expr::<crate::est::Builder>())
911 }
912}
913
914impl<T: Clone> BoundedDisplay for Expr<T> {
915 fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
916 BoundedDisplay::fmt(&self.clone().into_expr::<crate::est::Builder>(), f, n)
919 }
920}
921
922impl std::str::FromStr for Expr {
923 type Err = ParseErrors;
924
925 fn from_str(s: &str) -> Result<Expr, Self::Err> {
926 crate::parser::parse_expr(s)
927 }
928}
929
/// Errors that can occur when substituting values for unknowns.
#[derive(Debug, Clone, Diagnostic, Error)]
pub enum SubstitutionError {
    /// The supplied value's type differs from the unknown's type annotation.
    #[error("expected a value of type {expected}, got a value of type {actual}")]
    TypeError {
        /// The type the unknown was annotated with.
        expected: Type,
        /// The type of the value that was supplied.
        actual: Type,
    },
}
942
/// A named placeholder for a value that is not known yet, optionally
/// annotated with the type its eventual value must have.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
pub struct Unknown {
    /// Name used to look up a replacement value during substitution.
    pub name: SmolStr,
    /// Expected type of the replacement, if any. It is checked by
    /// `Expr::substitute_typed` and reported by `Expr::try_type_of`.
    pub type_annotation: Option<Type>,
}
953
954impl Unknown {
955 pub fn new_untyped(name: impl Into<SmolStr>) -> Self {
957 Self {
958 name: name.into(),
959 type_annotation: None,
960 }
961 }
962
963 pub fn new_with_type(name: impl Into<SmolStr>, ty: Type) -> Self {
966 Self {
967 name: name.into(),
968 type_annotation: Some(ty),
969 }
970 }
971}
972
973impl std::fmt::Display for Unknown {
974 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
975 write!(
978 f,
979 "{}",
980 Expr::unknown(self.clone()).into_expr::<crate::est::Builder>()
981 )
982 }
983}
984
/// Builder for [`Expr`] nodes. It holds the source location and payload that
/// the next node it builds will carry.
#[derive(Clone, Debug)]
pub struct ExprBuilder<T> {
    /// Source location given to the built node, if any.
    source_loc: Option<Loc>,
    /// Payload given to the built node.
    data: T,
}
992
993impl<T: Default + Clone> expr_builder::ExprBuilder for ExprBuilder<T> {
994 type Expr = Expr<T>;
995
996 type Data = T;
997
998 #[cfg(feature = "tolerant-ast")]
999 type ErrorType = ParseErrors;
1000
1001 fn loc(&self) -> Option<&Loc> {
1002 self.source_loc.as_ref()
1003 }
1004
1005 fn data(&self) -> &Self::Data {
1006 &self.data
1007 }
1008
1009 fn with_data(data: T) -> Self {
1010 Self {
1011 source_loc: None,
1012 data,
1013 }
1014 }
1015
1016 fn with_maybe_source_loc(mut self, maybe_source_loc: Option<&Loc>) -> Self {
1017 self.source_loc = maybe_source_loc.cloned();
1018 self
1019 }
1020
1021 fn val(self, v: impl Into<Literal>) -> Expr<T> {
1025 self.with_expr_kind(ExprKind::Lit(v.into()))
1026 }
1027
1028 fn unknown(self, u: Unknown) -> Expr<T> {
1030 self.with_expr_kind(ExprKind::Unknown(u))
1031 }
1032
1033 fn var(self, v: Var) -> Expr<T> {
1035 self.with_expr_kind(ExprKind::Var(v))
1036 }
1037
1038 fn slot(self, s: SlotId) -> Expr<T> {
1040 self.with_expr_kind(ExprKind::Slot(s))
1041 }
1042
1043 fn ite(self, test_expr: Expr<T>, then_expr: Expr<T>, else_expr: Expr<T>) -> Expr<T> {
1047 self.with_expr_kind(ExprKind::If {
1048 test_expr: Arc::new(test_expr),
1049 then_expr: Arc::new(then_expr),
1050 else_expr: Arc::new(else_expr),
1051 })
1052 }
1053
1054 fn not(self, e: Expr<T>) -> Expr<T> {
1056 self.with_expr_kind(ExprKind::UnaryApp {
1057 op: UnaryOp::Not,
1058 arg: Arc::new(e),
1059 })
1060 }
1061
1062 fn is_eq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1064 self.with_expr_kind(ExprKind::BinaryApp {
1065 op: BinaryOp::Eq,
1066 arg1: Arc::new(e1),
1067 arg2: Arc::new(e2),
1068 })
1069 }
1070
1071 fn and(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1073 self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1074 (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1075 ExprKind::Lit(Literal::Bool(*b1 && *b2))
1076 }
1077 _ => ExprKind::And {
1078 left: Arc::new(e1),
1079 right: Arc::new(e2),
1080 },
1081 })
1082 }
1083
1084 fn or(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1086 self.with_expr_kind(match (&e1.expr_kind, &e2.expr_kind) {
1087 (ExprKind::Lit(Literal::Bool(b1)), ExprKind::Lit(Literal::Bool(b2))) => {
1088 ExprKind::Lit(Literal::Bool(*b1 || *b2))
1089 }
1090
1091 _ => ExprKind::Or {
1092 left: Arc::new(e1),
1093 right: Arc::new(e2),
1094 },
1095 })
1096 }
1097
1098 fn less(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1100 self.with_expr_kind(ExprKind::BinaryApp {
1101 op: BinaryOp::Less,
1102 arg1: Arc::new(e1),
1103 arg2: Arc::new(e2),
1104 })
1105 }
1106
1107 fn lesseq(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1109 self.with_expr_kind(ExprKind::BinaryApp {
1110 op: BinaryOp::LessEq,
1111 arg1: Arc::new(e1),
1112 arg2: Arc::new(e2),
1113 })
1114 }
1115
1116 fn add(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1118 self.with_expr_kind(ExprKind::BinaryApp {
1119 op: BinaryOp::Add,
1120 arg1: Arc::new(e1),
1121 arg2: Arc::new(e2),
1122 })
1123 }
1124
1125 fn sub(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1127 self.with_expr_kind(ExprKind::BinaryApp {
1128 op: BinaryOp::Sub,
1129 arg1: Arc::new(e1),
1130 arg2: Arc::new(e2),
1131 })
1132 }
1133
1134 fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1136 self.with_expr_kind(ExprKind::BinaryApp {
1137 op: BinaryOp::Mul,
1138 arg1: Arc::new(e1),
1139 arg2: Arc::new(e2),
1140 })
1141 }
1142
1143 fn neg(self, e: Expr<T>) -> Expr<T> {
1145 self.with_expr_kind(ExprKind::UnaryApp {
1146 op: UnaryOp::Neg,
1147 arg: Arc::new(e),
1148 })
1149 }
1150
1151 fn is_in(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1155 self.with_expr_kind(ExprKind::BinaryApp {
1156 op: BinaryOp::In,
1157 arg1: Arc::new(e1),
1158 arg2: Arc::new(e2),
1159 })
1160 }
1161
1162 fn contains(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1165 self.with_expr_kind(ExprKind::BinaryApp {
1166 op: BinaryOp::Contains,
1167 arg1: Arc::new(e1),
1168 arg2: Arc::new(e2),
1169 })
1170 }
1171
1172 fn contains_all(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1174 self.with_expr_kind(ExprKind::BinaryApp {
1175 op: BinaryOp::ContainsAll,
1176 arg1: Arc::new(e1),
1177 arg2: Arc::new(e2),
1178 })
1179 }
1180
1181 fn contains_any(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
1183 self.with_expr_kind(ExprKind::BinaryApp {
1184 op: BinaryOp::ContainsAny,
1185 arg1: Arc::new(e1),
1186 arg2: Arc::new(e2),
1187 })
1188 }
1189
1190 fn is_empty(self, expr: Expr<T>) -> Expr<T> {
1192 self.with_expr_kind(ExprKind::UnaryApp {
1193 op: UnaryOp::IsEmpty,
1194 arg: Arc::new(expr),
1195 })
1196 }
1197
1198 fn get_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1201 self.with_expr_kind(ExprKind::BinaryApp {
1202 op: BinaryOp::GetTag,
1203 arg1: Arc::new(expr),
1204 arg2: Arc::new(tag),
1205 })
1206 }
1207
1208 fn has_tag(self, expr: Expr<T>, tag: Expr<T>) -> Expr<T> {
1211 self.with_expr_kind(ExprKind::BinaryApp {
1212 op: BinaryOp::HasTag,
1213 arg1: Arc::new(expr),
1214 arg2: Arc::new(tag),
1215 })
1216 }
1217
1218 fn set(self, exprs: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1220 self.with_expr_kind(ExprKind::Set(Arc::new(exprs.into_iter().collect())))
1221 }
1222
1223 fn record(
1225 self,
1226 pairs: impl IntoIterator<Item = (SmolStr, Expr<T>)>,
1227 ) -> Result<Expr<T>, ExpressionConstructionError> {
1228 let mut map = BTreeMap::new();
1229 for (k, v) in pairs {
1230 match map.entry(k) {
1231 btree_map::Entry::Occupied(oentry) => {
1232 return Err(expression_construction_errors::DuplicateKeyError {
1233 key: oentry.key().clone(),
1234 context: "in record literal",
1235 }
1236 .into());
1237 }
1238 btree_map::Entry::Vacant(ventry) => {
1239 ventry.insert(v);
1240 }
1241 }
1242 }
1243 Ok(self.with_expr_kind(ExprKind::Record(Arc::new(map))))
1244 }
1245
1246 fn call_extension_fn(self, fn_name: Name, args: impl IntoIterator<Item = Expr<T>>) -> Expr<T> {
1249 self.with_expr_kind(ExprKind::ExtensionFunctionApp {
1250 fn_name,
1251 args: Arc::new(args.into_iter().collect()),
1252 })
1253 }
1254
1255 fn unary_app(self, op: impl Into<UnaryOp>, arg: Expr<T>) -> Expr<T> {
1258 self.with_expr_kind(ExprKind::UnaryApp {
1259 op: op.into(),
1260 arg: Arc::new(arg),
1261 })
1262 }
1263
1264 fn binary_app(self, op: impl Into<BinaryOp>, arg1: Expr<T>, arg2: Expr<T>) -> Expr<T> {
1267 self.with_expr_kind(ExprKind::BinaryApp {
1268 op: op.into(),
1269 arg1: Arc::new(arg1),
1270 arg2: Arc::new(arg2),
1271 })
1272 }
1273
1274 fn get_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1278 self.with_expr_kind(ExprKind::GetAttr {
1279 expr: Arc::new(expr),
1280 attr,
1281 })
1282 }
1283
1284 fn has_attr(self, expr: Expr<T>, attr: SmolStr) -> Expr<T> {
1289 self.with_expr_kind(ExprKind::HasAttr {
1290 expr: Arc::new(expr),
1291 attr,
1292 })
1293 }
1294
1295 fn like(self, expr: Expr<T>, pattern: Pattern) -> Expr<T> {
1299 self.with_expr_kind(ExprKind::Like {
1300 expr: Arc::new(expr),
1301 pattern,
1302 })
1303 }
1304
1305 fn is_entity_type(self, expr: Expr<T>, entity_type: EntityType) -> Expr<T> {
1307 self.with_expr_kind(ExprKind::Is {
1308 expr: Arc::new(expr),
1309 entity_type,
1310 })
1311 }
1312
1313 #[cfg(feature = "tolerant-ast")]
1315 fn error(self, parse_errors: ParseErrors) -> Result<Self::Expr, Self::ErrorType> {
1316 Err(parse_errors)
1317 }
1318}
1319
1320impl<T> ExprBuilder<T> {
1321 pub fn with_expr_kind(self, expr_kind: ExprKind<T>) -> Expr<T> {
1324 Expr::new(expr_kind, self.source_loc, self.data)
1325 }
1326
1327 pub fn ite_arc(
1331 self,
1332 test_expr: Arc<Expr<T>>,
1333 then_expr: Arc<Expr<T>>,
1334 else_expr: Arc<Expr<T>>,
1335 ) -> Expr<T> {
1336 self.with_expr_kind(ExprKind::If {
1337 test_expr,
1338 then_expr,
1339 else_expr,
1340 })
1341 }
1342
1343 pub fn record_arc(self, map: Arc<BTreeMap<SmolStr, Expr<T>>>) -> Expr<T> {
1350 self.with_expr_kind(ExprKind::Record(map))
1351 }
1352}
1353
1354impl<T: Clone + Default> ExprBuilder<T> {
1355 pub fn with_same_source_loc<U>(self, expr: &Expr<U>) -> Self {
1359 self.with_maybe_source_loc(expr.source_loc.as_ref())
1360 }
1361}
1362
/// Errors that can occur while constructing an [`Expr`], for example via
/// [`Expr::record`].
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExpressionConstructionError {
    /// The same key was supplied more than once.
    #[error(transparent)]
    #[diagnostic(transparent)]
    DuplicateKey(#[from] expression_construction_errors::DuplicateKeyError),
}
1375
1376pub mod expression_construction_errors {
1378 use miette::Diagnostic;
1379 use smol_str::SmolStr;
1380 use thiserror::Error;
1381
1382 #[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
1388 #[error("duplicate key `{key}` {context}")]
1389 pub struct DuplicateKeyError {
1390 pub(crate) key: SmolStr,
1392 pub(crate) context: &'static str,
1394 }
1395
1396 impl DuplicateKeyError {
1397 pub fn key(&self) -> &str {
1399 &self.key
1400 }
1401
1402 pub(crate) fn with_context(self, context: &'static str) -> Self {
1404 Self { context, ..self }
1405 }
1406 }
1407}
1408
/// Wrapper around a borrowed or owned [`Expr`] whose equality, hashing and
/// ordering compare only the expression's shape, via `eq_shape`,
/// `hash_shape` and `cmp_shape`. Source locations and the `data` payload are
/// not compared.
#[derive(Debug, Clone)]
pub struct ExprShapeOnly<'a, T: Clone = ()>(Cow<'a, Expr<T>>);
1414
1415impl<'a, T: Clone> ExprShapeOnly<'a, T> {
1416 pub fn new_from_borrowed(e: &'a Expr<T>) -> ExprShapeOnly<'a, T> {
1420 ExprShapeOnly(Cow::Borrowed(e))
1421 }
1422
1423 pub fn new_from_owned(e: Expr<T>) -> ExprShapeOnly<'a, T> {
1427 ExprShapeOnly(Cow::Owned(e))
1428 }
1429}
1430
1431impl<T: Clone> PartialEq for ExprShapeOnly<'_, T> {
1432 fn eq(&self, other: &Self) -> bool {
1433 self.0.eq_shape(&other.0)
1434 }
1435}
1436
1437impl<T: Clone> Eq for ExprShapeOnly<'_, T> {}
1438
1439impl<T: Clone> Hash for ExprShapeOnly<'_, T> {
1440 fn hash<H: Hasher>(&self, state: &mut H) {
1441 self.0.hash_shape(state);
1442 }
1443}
1444
1445impl<T: Clone> PartialOrd for ExprShapeOnly<'_, T> {
1446 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1447 Some(self.cmp(other))
1448 }
1449}
1450
1451impl<T: Clone> Ord for ExprShapeOnly<'_, T> {
1452 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1453 self.0.cmp_shape(&other.0)
1454 }
1455}
1456
1457impl<T> Expr<T> {
1458 pub fn eq_shape<U>(&self, other: &Expr<U>) -> bool {
1465 use ExprKind::*;
1466 match (self.expr_kind(), other.expr_kind()) {
1467 (Lit(lit), Lit(lit1)) => lit == lit1,
1468 (Var(v), Var(v1)) => v == v1,
1469 (Slot(s), Slot(s1)) => s == s1,
1470 (
1471 Unknown(self::Unknown {
1472 name: name1,
1473 type_annotation: ta_1,
1474 }),
1475 Unknown(self::Unknown {
1476 name: name2,
1477 type_annotation: ta_2,
1478 }),
1479 ) => (name1 == name2) && (ta_1 == ta_2),
1480 (
1481 If {
1482 test_expr,
1483 then_expr,
1484 else_expr,
1485 },
1486 If {
1487 test_expr: test_expr1,
1488 then_expr: then_expr1,
1489 else_expr: else_expr1,
1490 },
1491 ) => {
1492 test_expr.eq_shape(test_expr1)
1493 && then_expr.eq_shape(then_expr1)
1494 && else_expr.eq_shape(else_expr1)
1495 }
1496 (
1497 And { left, right },
1498 And {
1499 left: left1,
1500 right: right1,
1501 },
1502 )
1503 | (
1504 Or { left, right },
1505 Or {
1506 left: left1,
1507 right: right1,
1508 },
1509 ) => left.eq_shape(left1) && right.eq_shape(right1),
1510 (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1511 op == op1 && arg.eq_shape(arg1)
1512 }
1513 (
1514 BinaryApp { op, arg1, arg2 },
1515 BinaryApp {
1516 op: op1,
1517 arg1: arg11,
1518 arg2: arg21,
1519 },
1520 ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
1521 (
1522 ExtensionFunctionApp { fn_name, args },
1523 ExtensionFunctionApp {
1524 fn_name: fn_name1,
1525 args: args1,
1526 },
1527 ) => {
1528 fn_name == fn_name1
1529 && args.len() == args1.len()
1530 && args.iter().zip(args1.iter()).all(|(a, a1)| a.eq_shape(a1))
1531 }
1532 (
1533 GetAttr { expr, attr },
1534 GetAttr {
1535 expr: expr1,
1536 attr: attr1,
1537 },
1538 )
1539 | (
1540 HasAttr { expr, attr },
1541 HasAttr {
1542 expr: expr1,
1543 attr: attr1,
1544 },
1545 ) => attr == attr1 && expr.eq_shape(expr1),
1546 (
1547 Like { expr, pattern },
1548 Like {
1549 expr: expr1,
1550 pattern: pattern1,
1551 },
1552 ) => pattern == pattern1 && expr.eq_shape(expr1),
1553 (Set(elems), Set(elems1)) => {
1554 elems.len() == elems1.len()
1555 && elems
1556 .iter()
1557 .zip(elems1.iter())
1558 .all(|(e, e1)| e.eq_shape(e1))
1559 }
1560 (Record(map), Record(map1)) => {
1561 map.len() == map1.len()
1562 && map
1563 .iter()
1564 .zip(map1.iter()) .all(|((a, e), (a1, e1))| a == a1 && e.eq_shape(e1))
1566 }
1567 (
1568 Is { expr, entity_type },
1569 Is {
1570 expr: expr1,
1571 entity_type: entity_type1,
1572 },
1573 ) => entity_type == entity_type1 && expr.eq_shape(expr1),
1574 _ => false,
1575 }
1576 }
1577
1578 pub fn hash_shape<H>(&self, state: &mut H)
1582 where
1583 H: Hasher,
1584 {
1585 mem::discriminant(self).hash(state);
1586 match self.expr_kind() {
1587 ExprKind::Lit(lit) => lit.hash(state),
1588 ExprKind::Var(v) => v.hash(state),
1589 ExprKind::Slot(s) => s.hash(state),
1590 ExprKind::Unknown(u) => u.hash(state),
1591 ExprKind::If {
1592 test_expr,
1593 then_expr,
1594 else_expr,
1595 } => {
1596 test_expr.hash_shape(state);
1597 then_expr.hash_shape(state);
1598 else_expr.hash_shape(state);
1599 }
1600 ExprKind::And { left, right } => {
1601 left.hash_shape(state);
1602 right.hash_shape(state);
1603 }
1604 ExprKind::Or { left, right } => {
1605 left.hash_shape(state);
1606 right.hash_shape(state);
1607 }
1608 ExprKind::UnaryApp { op, arg } => {
1609 op.hash(state);
1610 arg.hash_shape(state);
1611 }
1612 ExprKind::BinaryApp { op, arg1, arg2 } => {
1613 op.hash(state);
1614 arg1.hash_shape(state);
1615 arg2.hash_shape(state);
1616 }
1617 ExprKind::ExtensionFunctionApp { fn_name, args } => {
1618 fn_name.hash(state);
1619 state.write_usize(args.len());
1620 args.iter().for_each(|a| {
1621 a.hash_shape(state);
1622 });
1623 }
1624 ExprKind::GetAttr { expr, attr } => {
1625 expr.hash_shape(state);
1626 attr.hash(state);
1627 }
1628 ExprKind::HasAttr { expr, attr } => {
1629 expr.hash_shape(state);
1630 attr.hash(state);
1631 }
1632 ExprKind::Like { expr, pattern } => {
1633 expr.hash_shape(state);
1634 pattern.hash(state);
1635 }
1636 ExprKind::Set(elems) => {
1637 state.write_usize(elems.len());
1638 elems.iter().for_each(|e| {
1639 e.hash_shape(state);
1640 })
1641 }
1642 ExprKind::Record(map) => {
1643 state.write_usize(map.len());
1644 map.iter().for_each(|(s, a)| {
1645 s.hash(state);
1646 a.hash_shape(state);
1647 });
1648 }
1649 ExprKind::Is { expr, entity_type } => {
1650 expr.hash_shape(state);
1651 entity_type.hash(state);
1652 }
1653 #[cfg(feature = "tolerant-ast")]
1654 ExprKind::Error { error_kind, .. } => error_kind.hash(state),
1655 }
1656 }
1657
1658 pub fn cmp_shape(&self, other: &Expr<T>) -> std::cmp::Ordering {
1662 let self_kind = self.expr_kind();
1664 let other_kind = other.expr_kind();
1665 if std::mem::discriminant(self_kind) != std::mem::discriminant(other_kind) {
1666 return self_kind.variant_order().cmp(&other_kind.variant_order());
1667 }
1668
1669 use ExprKind::*;
1671 match (self_kind, other_kind) {
1672 (Lit(lit), Lit(lit1)) => lit.cmp(lit1),
1673 (Var(v), Var(v1)) => v.cmp(v1),
1674 (Slot(s), Slot(s1)) => s.cmp(s1),
1675 (
1676 Unknown(self::Unknown {
1677 name: name1,
1678 type_annotation: ta_1,
1679 }),
1680 Unknown(self::Unknown {
1681 name: name2,
1682 type_annotation: ta_2,
1683 }),
1684 ) => name1.cmp(name2).then_with(|| ta_1.cmp(ta_2)),
1685 (
1686 If {
1687 test_expr,
1688 then_expr,
1689 else_expr,
1690 },
1691 If {
1692 test_expr: test_expr1,
1693 then_expr: then_expr1,
1694 else_expr: else_expr1,
1695 },
1696 ) => test_expr
1697 .cmp_shape(test_expr1)
1698 .then_with(|| then_expr.cmp_shape(then_expr1))
1699 .then_with(|| else_expr.cmp_shape(else_expr1)),
1700 (
1701 And { left, right },
1702 And {
1703 left: left1,
1704 right: right1,
1705 },
1706 ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1707 (
1708 Or { left, right },
1709 Or {
1710 left: left1,
1711 right: right1,
1712 },
1713 ) => left.cmp_shape(left1).then_with(|| right.cmp_shape(right1)),
1714 (UnaryApp { op, arg }, UnaryApp { op: op1, arg: arg1 }) => {
1715 op.cmp(op1).then_with(|| arg.cmp_shape(arg1))
1716 }
1717 (
1718 BinaryApp { op, arg1, arg2 },
1719 BinaryApp {
1720 op: op1,
1721 arg1: arg11,
1722 arg2: arg21,
1723 },
1724 ) => op
1725 .cmp(op1)
1726 .then_with(|| arg1.cmp_shape(arg11))
1727 .then_with(|| arg2.cmp_shape(arg21)),
1728 (
1729 ExtensionFunctionApp { fn_name, args },
1730 ExtensionFunctionApp {
1731 fn_name: fn_name1,
1732 args: args1,
1733 },
1734 ) => fn_name.cmp(fn_name1).then_with(|| {
1735 args.len().cmp(&args1.len()).then_with(|| {
1736 for (a, a1) in args.iter().zip(args1.iter()) {
1737 match a.cmp_shape(a1) {
1738 std::cmp::Ordering::Equal => continue,
1739 other => return other,
1740 }
1741 }
1742 std::cmp::Ordering::Equal
1743 })
1744 }),
1745 (
1746 GetAttr { expr, attr },
1747 GetAttr {
1748 expr: expr1,
1749 attr: attr1,
1750 },
1751 ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1752 (
1753 HasAttr { expr, attr },
1754 HasAttr {
1755 expr: expr1,
1756 attr: attr1,
1757 },
1758 ) => attr.cmp(attr1).then_with(|| expr.cmp_shape(expr1)),
1759 (
1760 Like { expr, pattern },
1761 Like {
1762 expr: expr1,
1763 pattern: pattern1,
1764 },
1765 ) => pattern.cmp(pattern1).then_with(|| expr.cmp_shape(expr1)),
1766 (Set(elems), Set(elems1)) => elems.len().cmp(&elems1.len()).then_with(|| {
1767 for (e, e1) in elems.iter().zip(elems1.iter()) {
1768 match e.cmp_shape(e1) {
1769 std::cmp::Ordering::Equal => continue,
1770 other => return other,
1771 }
1772 }
1773 std::cmp::Ordering::Equal
1774 }),
1775 (Record(map), Record(map1)) => map.len().cmp(&map1.len()).then_with(|| {
1776 for ((a, e), (a1, e1)) in map.iter().zip(map1.iter()) {
1777 match a.cmp(a1).then_with(|| e.cmp_shape(e1)) {
1778 std::cmp::Ordering::Equal => continue,
1779 other => return other,
1780 }
1781 }
1782 std::cmp::Ordering::Equal
1783 }),
1784 (
1785 Is { expr, entity_type },
1786 Is {
1787 expr: expr1,
1788 entity_type: entity_type1,
1789 },
1790 ) => entity_type
1791 .cmp(entity_type1)
1792 .then_with(|| expr.cmp_shape(expr1)),
1793 #[cfg(feature = "tolerant-ast")]
1794 (
1795 Error { error_kind },
1796 Error {
1797 error_kind: error_kind1,
1798 },
1799 ) => error_kind.cmp(error_kind1),
1800 #[allow(clippy::unreachable)]
1802 _ => unreachable!(
1803 "Different variants should have been handled by variant_order comparison"
1804 ),
1805 }
1806 }
1807}
1808
/// One of the built-in variables an expression can refer to
/// (see `ExprKind::Var`): `principal`, `action`, `resource`, or `context`.
///
/// Serde uses camelCase variant names, so the variants (de)serialize as
/// `"principal"`, `"action"`, `"resource"` and `"context"`; the `Display`
/// impl uses the same lowercase spelling.
///
/// NOTE: `PartialOrd`/`Ord` are derived, so ordering follows the declaration
/// order below. Reordering the variants changes how `Var`s compare.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Var {
    /// The `principal` variable.
    Principal,
    /// The `action` variable.
    Action,
    /// The `resource` variable.
    Resource,
    /// The `context` variable.
    Context,
}
1825
1826impl From<PrincipalOrResource> for Var {
1827 fn from(v: PrincipalOrResource) -> Self {
1828 match v {
1829 PrincipalOrResource::Principal => Var::Principal,
1830 PrincipalOrResource::Resource => Var::Resource,
1831 }
1832 }
1833}
1834
1835#[allow(clippy::fallible_impl_from)]
1837impl From<Var> for Id {
1838 fn from(var: Var) -> Self {
1839 #[allow(clippy::unwrap_used)]
1841 format!("{var}").parse().unwrap()
1842 }
1843}
1844
1845#[allow(clippy::fallible_impl_from)]
1847impl From<Var> for UnreservedId {
1848 fn from(var: Var) -> Self {
1849 #[allow(clippy::unwrap_used)]
1851 Id::from(var).try_into().unwrap()
1852 }
1853}
1854
1855impl std::fmt::Display for Var {
1856 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1857 match self {
1858 Self::Principal => write!(f, "principal"),
1859 Self::Action => write!(f, "action"),
1860 Self::Resource => write!(f, "resource"),
1861 Self::Context => write!(f, "context"),
1862 }
1863 }
1864}
1865
/// Unit tests for `Expr` construction, display, shape comparison and
/// unknown substitution.
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;
    use itertools::Itertools;
    use smol_str::ToSmolStr;
    use std::collections::{hash_map::DefaultHasher, HashSet};

    use crate::expr_builder::ExprBuilder as _;

    use super::*;

    /// Every `Var` variant, for exhaustive checks.
    pub fn all_vars() -> impl Iterator<Item = Var> {
        [Var::Principal, Var::Action, Var::Resource, Var::Context].into_iter()
    }

    // Converting any `Var` into `Id` / `UnreservedId` must not panic.
    #[test]
    fn all_vars_are_ids() {
        for var in all_vars() {
            let _id: Id = var.into();
            let _id: UnreservedId = var.into();
        }
    }

    #[test]
    fn exprs() {
        assert_eq!(
            Expr::val(33),
            Expr::new(ExprKind::Lit(Literal::Long(33)), None, ())
        );
        assert_eq!(
            Expr::val("hello"),
            Expr::new(ExprKind::Lit(Literal::from("hello")), None, ())
        );
        assert_eq!(
            Expr::val(EntityUID::with_eid("foo")),
            Expr::new(
                ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                None,
                ()
            )
        );
        assert_eq!(
            Expr::var(Var::Principal),
            Expr::new(ExprKind::Var(Var::Principal), None, ())
        );
        assert_eq!(
            Expr::ite(Expr::val(true), Expr::val(88), Expr::val(-100)),
            Expr::new(
                ExprKind::If {
                    test_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(true)), None, ())),
                    then_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(88)), None, ())),
                    else_expr: Arc::new(Expr::new(ExprKind::Lit(Literal::Long(-100)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::not(Expr::val(false)),
            Expr::new(
                ExprKind::UnaryApp {
                    op: UnaryOp::Not,
                    arg: Arc::new(Expr::new(ExprKind::Lit(Literal::Bool(false)), None, ())),
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::get_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::GetAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::has_attr(Expr::val(EntityUID::with_eid("foo")), "some_attr".into()),
            Expr::new(
                ExprKind::HasAttr {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    attr: "some_attr".into()
                },
                None,
                ()
            )
        );
        assert_eq!(
            Expr::is_entity_type(
                Expr::val(EntityUID::with_eid("foo")),
                "Type".parse().unwrap()
            ),
            Expr::new(
                ExprKind::Is {
                    expr: Arc::new(Expr::new(
                        ExprKind::Lit(Literal::from(EntityUID::with_eid("foo"))),
                        None,
                        ()
                    )),
                    entity_type: "Type".parse().unwrap()
                },
                None,
                ()
            ),
        );
    }

    #[test]
    fn like_display() {
        // A NUL character in the pattern is rendered with the `\0` escape.
        let e = Expr::like(Expr::val("a"), Pattern::from(vec![PatternElem::Char('\0')]));
        assert_eq!(format!("{e}"), r#""a" like "\0""#);
        // A literal backslash is rendered as `\\`, so `\` + `0` is not confused with NUL.
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('0')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\0""#);
        // A backslash followed by a wildcard renders as `\\*`.
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Wildcard]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\*""#);
        // A backslash followed by a literal `*` renders as `\\\*` (the star is escaped).
        let e = Expr::like(
            Expr::val("a"),
            Pattern::from(vec![PatternElem::Char('\\'), PatternElem::Char('*')]),
        );
        assert_eq!(format!("{e}"), r#""a" like "\\\*""#);
    }

    #[test]
    fn has_display() {
        // A NUL character in the attribute name is rendered with the `\0` escape.
        let e = Expr::has_attr(Expr::val("a"), "\0".into());
        assert_eq!(format!("{e}"), r#""a" has "\0""#);
        // A lone backslash in the attribute name is rendered as `\\`.
        let e = Expr::has_attr(Expr::val("a"), r"\".into());
        assert_eq!(format!("{e}"), r#""a" has "\\""#);
    }

    #[test]
    fn slot_display() {
        let e = Expr::slot(SlotId::principal());
        assert_eq!(format!("{e}"), "?principal");
        let e = Expr::slot(SlotId::resource());
        assert_eq!(format!("{e}"), "?resource");
        let e = Expr::val(EntityUID::with_eid("eid"));
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn simple_slots() {
        let e = Expr::slot(SlotId::principal());
        let p = SlotId::principal();
        let r = SlotId::resource();
        let set: HashSet<SlotId> = HashSet::from_iter([p]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
        let e = Expr::or(
            Expr::slot(SlotId::principal()),
            Expr::ite(
                Expr::val(true),
                Expr::slot(SlotId::resource()),
                Expr::val(false),
            ),
        );
        let set: HashSet<SlotId> = HashSet::from_iter([p, r]);
        assert_eq!(set, e.slots().map(|slot| slot.id).collect::<HashSet<_>>());
    }

    #[test]
    fn unknowns() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        let unknowns = e.unknowns().collect_vec();
        assert_eq!(unknowns.len(), 3);
        assert!(unknowns.contains(&&Unknown::new_untyped("a")));
        assert!(unknowns.contains(&&Unknown::new_untyped("b")));
        assert!(unknowns.contains(&&Unknown::new_untyped("c")));
    }

    #[test]
    fn is_unknown() {
        let e = Expr::ite(
            Expr::not(Expr::unknown(Unknown::new_untyped("a"))),
            Expr::and(Expr::unknown(Unknown::new_untyped("b")), Expr::val(3)),
            Expr::unknown(Unknown::new_untyped("c")),
        );
        assert!(e.contains_unknown());
        let e = Expr::ite(
            Expr::not(Expr::val(true)),
            Expr::and(Expr::val(1), Expr::val(3)),
            Expr::val(1),
        );
        assert!(!e.contains_unknown());
    }

    #[test]
    fn expr_with_data() {
        let e = ExprBuilder::with_data("data").val(1);
        assert_eq!(e.into_data(), "data");
    }

    // Each pair has the same shape but different `data` payloads; shape
    // equality must hold in both directions and the shape hashes must agree.
    #[test]
    fn expr_shape_only_eq() {
        let temp = ExprBuilder::with_data(1).val(1);
        let exprs = &[
            (ExprBuilder::with_data(1).val(33), Expr::val(33)),
            (ExprBuilder::with_data(1).val(true), Expr::val(true)),
            (
                ExprBuilder::with_data(1).var(Var::Principal),
                Expr::var(Var::Principal),
            ),
            (
                ExprBuilder::with_data(1).slot(SlotId::principal()),
                Expr::slot(SlotId::principal()),
            ),
            (
                ExprBuilder::with_data(1).ite(temp.clone(), temp.clone(), temp.clone()),
                Expr::ite(Expr::val(1), Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).not(temp.clone()),
                Expr::not(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_eq(temp.clone(), temp.clone()),
                Expr::is_eq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).and(temp.clone(), temp.clone()),
                Expr::and(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).or(temp.clone(), temp.clone()),
                Expr::or(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).less(temp.clone(), temp.clone()),
                Expr::less(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).lesseq(temp.clone(), temp.clone()),
                Expr::lesseq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greater(temp.clone(), temp.clone()),
                Expr::greater(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).greatereq(temp.clone(), temp.clone()),
                Expr::greatereq(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).add(temp.clone(), temp.clone()),
                Expr::add(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).sub(temp.clone(), temp.clone()),
                Expr::sub(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
                Expr::mul(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).neg(temp.clone()),
                Expr::neg(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_in(temp.clone(), temp.clone()),
                Expr::is_in(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains(temp.clone(), temp.clone()),
                Expr::contains(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_all(temp.clone(), temp.clone()),
                Expr::contains_all(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).contains_any(temp.clone(), temp.clone()),
                Expr::contains_any(Expr::val(1), Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).is_empty(temp.clone()),
                Expr::is_empty(Expr::val(1)),
            ),
            (
                ExprBuilder::with_data(1).set([temp.clone()]),
                Expr::set([Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1)
                    .record([("foo".into(), temp.clone())])
                    .unwrap(),
                Expr::record([("foo".into(), Expr::val(1))]).unwrap(),
            ),
            (
                ExprBuilder::with_data(1)
                    .call_extension_fn("foo".parse().unwrap(), vec![temp.clone()]),
                Expr::call_extension_fn("foo".parse().unwrap(), vec![Expr::val(1)]),
            ),
            (
                ExprBuilder::with_data(1).get_attr(temp.clone(), "foo".into()),
                Expr::get_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1).has_attr(temp.clone(), "foo".into()),
                Expr::has_attr(Expr::val(1), "foo".into()),
            ),
            (
                ExprBuilder::with_data(1)
                    .like(temp.clone(), Pattern::from(vec![PatternElem::Wildcard])),
                Expr::like(Expr::val(1), Pattern::from(vec![PatternElem::Wildcard])),
            ),
            (
                ExprBuilder::with_data(1).is_entity_type(temp, "T".parse().unwrap()),
                Expr::is_entity_type(Expr::val(1), "T".parse().unwrap()),
            ),
        ];

        for (e0, e1) in exprs {
            assert!(e0.eq_shape(e0));
            assert!(e1.eq_shape(e1));
            assert!(e0.eq_shape(e1));
            assert!(e1.eq_shape(e0));

            let mut hasher0 = DefaultHasher::new();
            e0.hash_shape(&mut hasher0);
            let hash0 = hasher0.finish();

            let mut hasher1 = DefaultHasher::new();
            e1.hash_shape(&mut hasher1);
            let hash1 = hasher1.finish();

            assert_eq!(hash0, hash1);
        }
    }

    #[test]
    fn expr_shape_only_not_eq() {
        let expr1 = ExprBuilder::with_data(1).val(1);
        let expr2 = ExprBuilder::with_data(1).val(2);
        assert_ne!(
            ExprShapeOnly::new_from_borrowed(&expr1),
            ExprShapeOnly::new_from_borrowed(&expr2)
        );
    }

    // Sets where one is a prefix of another must still compare unequal, in
    // every direction.
    #[test]
    fn expr_shape_only_set_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::set([]));
        let e2 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1)]));
        let e3 = ExprShapeOnly::new_from_owned(Expr::set([Expr::val(1), Expr::val(2)]));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Extension calls whose argument lists are prefixes of one another must
    // still compare unequal, in every direction.
    #[test]
    fn expr_shape_only_ext_fn_arg_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![],
        ));
        let e2 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0")],
        ));
        let e3 = ExprShapeOnly::new_from_owned(Expr::call_extension_fn(
            "decimal".parse().unwrap(),
            vec![Expr::val("0.0"), Expr::val("0.0")],
        ));

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    // Records whose entries are prefixes of one another must still compare
    // unequal, in every direction.
    #[test]
    fn expr_shape_only_record_attr_prefix_ne() {
        let e1 = ExprShapeOnly::new_from_owned(Expr::record([]).unwrap());
        let e2 = ExprShapeOnly::new_from_owned(
            Expr::record([("a".to_smolstr(), Expr::val(1))]).unwrap(),
        );
        let e3 = ExprShapeOnly::new_from_owned(
            Expr::record([
                ("a".to_smolstr(), Expr::val(1)),
                ("b".to_smolstr(), Expr::val(2)),
            ])
            .unwrap(),
        );

        assert_ne!(e1, e2);
        assert_ne!(e1, e3);
        assert_ne!(e2, e1);
        assert_ne!(e2, e3);
        assert_ne!(e3, e1);
        assert_ne!(e3, e2);
    }

    #[test]
    fn untyped_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    // Untyped substitution ignores the type annotation, so a mismatched value
    // is still substituted.
    #[test]
    fn untyped_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, Some(&Value::new(1, None)));
        match r {
            Ok(e) => assert_eq!(e, Expr::val(1)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn untyped_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = UntypedSubstitution::substitute(&u, None);
        match r {
            Ok(n) => assert_eq!(n, Expr::unknown(u)),
            Err(empty) => match empty {},
        }
    }

    #[test]
    fn typed_subst_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    #[test]
    fn typed_subst_present_correct_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Long),
        };
        let e = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap();
        assert_eq!(e, Expr::val(1));
    }

    // Typed substitution rejects a value whose type contradicts the annotation.
    #[test]
    fn typed_subst_present_wrong_type() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: Some(Type::Bool),
        };
        let r = TypedSubstitution::substitute(&u, Some(&Value::new(1, None))).unwrap_err();
        assert_matches!(
            r,
            SubstitutionError::TypeError {
                expected: Type::Bool,
                actual: Type::Long,
            }
        );
    }

    #[test]
    fn typed_subst_not_present() {
        let u = Unknown {
            name: "foo".into(),
            type_annotation: None,
        };
        let r = TypedSubstitution::substitute(&u, None).unwrap();
        assert_eq!(r, Expr::unknown(u));
    }
}