1use crate::{
16 floating_point_eq, imag,
17 instruction::MemoryReference,
18 parser::{lex, parse_expression, ParseError},
19 program::{disallow_leftover, ParseProgramError},
20 quil::Quil,
21 real,
22};
23use internment::ArcIntern;
24use lexical::{format, to_string_with_options, WriteFloatOptions};
25use nom_locate::LocatedSpan;
26use num_complex::Complex64;
27use once_cell::sync::Lazy;
28use std::{
29 borrow::Borrow,
30 collections::HashMap,
31 f64::consts::PI,
32 fmt,
33 hash::{Hash, Hasher},
34 num::NonZeroI32,
35 ops::{
36 Add, AddAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
37 },
38 str::FromStr,
39};
40
41#[cfg(test)]
42use proptest_derive::Arbitrary;
43
44#[cfg(not(feature = "python"))]
45use optipy::strip_pyo3;
46#[cfg(feature = "stubs")]
47use pyo3_stub_gen::derive::{
48 gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pyclass_enum, gen_stub_pymethods,
49};
50#[cfg(feature = "python")]
51pub(crate) mod quilpy;
52
53mod simplification;
54
/// Errors produced when an [`Expression`] cannot be reduced to a numeric value
/// (see [`Expression::evaluate`] and [`Expression::to_real`]).
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum EvaluationError {
    /// A variable or memory reference needed for evaluation had no value available
    /// (missing key, or a memory index outside the supplied values).
    #[error("There wasn't enough information to completely evaluate the expression.")]
    Incomplete,
    /// A real number was required but the number's imaginary part is not negligible.
    #[error("The operation expected a real number but received a complex one.")]
    NumberNotReal,
    /// A number was required but the expression is neither a numeric literal nor `pi`.
    #[error("The operation expected a number but received a different type of expression.")]
    NotANumber,
}
65
/// A symbolic Quil expression tree.
///
/// Sub-expressions are stored as [`ArcIntern<Expression>`] inside the payload structs.
/// `PartialEq`, `Eq` and `Hash` are implemented by hand rather than derived so that
/// `Number` payloads are compared and hashed through `floating_point_eq`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_complex_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub enum Expression {
    /// A reference into classical memory, e.g. `theta[1]`.
    Address(MemoryReference),
    /// A built-in function applied to one argument, e.g. `sin(%x)`.
    FunctionCall(FunctionCallExpression),
    /// A binary operation, e.g. `%a*%b`.
    Infix(InfixExpression),
    /// A complex-valued numeric literal.
    Number(Complex64),
    /// The constant π, written `pi` in Quil and exposed to Python as `Pi`.
    /// NOTE(review): declared as an empty tuple variant rather than a unit variant,
    /// presumably for pyo3 complex-enum compatibility — confirm before changing.
    #[pyo3(name = "Pi")]
    PiConstant(),
    /// A unary operation, e.g. `-%x`.
    Prefix(PrefixExpression),
    /// A named variable, written `%name` in Quil; the stored string omits the `%`.
    Variable(String),
}
108
#[cfg(test)]
impl proptest::prelude::Arbitrary for Expression {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Random expression trees whose leaves are arbitrary memory references, variable names
    /// from `proptest_helpers::arb_name`, complex numbers from
    /// `proptest_helpers::arb_complex64`, and `pi`.
    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        self::proptest_helpers::arb_expr_custom_leaves(
            any::<MemoryReference>,
            self::proptest_helpers::arb_name,
            self::proptest_helpers::arb_complex64,
        )
        .boxed()
    }
}
123
/// A built-in [`ExpressionFunction`] applied to a single argument, written in Quil as
/// `function(expression)`, e.g. `cos(%theta)`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash, subclass)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub struct FunctionCallExpression {
    /// The function being applied.
    #[pyo3(get)]
    pub function: ExpressionFunction,
    /// The interned argument of the call.
    pub expression: ArcIntern<Expression>,
}
143
144impl FunctionCallExpression {
145 pub fn new(function: ExpressionFunction, expression: ArcIntern<Expression>) -> Self {
146 Self {
147 function,
148 expression,
149 }
150 }
151}
152
/// A binary operation `left operator right`, e.g. `%a*%b`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash, subclass)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub struct InfixExpression {
    /// The interned left operand.
    pub left: ArcIntern<Expression>,
    /// The binary operator joining the two operands.
    #[pyo3(get)]
    pub operator: InfixOperator,
    /// The interned right operand.
    pub right: ArcIntern<Expression>,
}
174
175impl InfixExpression {
176 pub fn new(
177 left: ArcIntern<Expression>,
178 operator: InfixOperator,
179 right: ArcIntern<Expression>,
180 ) -> Self {
181 Self {
182 left,
183 operator,
184 right,
185 }
186 }
187}
188
/// A unary operation `operator expression`, e.g. `-%x`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash, subclass)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub struct PrefixExpression {
    /// The unary operator.
    #[pyo3(get)]
    pub operator: PrefixOperator,
    /// The interned operand.
    pub expression: ArcIntern<Expression>,
}
208
209impl PrefixExpression {
210 pub fn new(operator: PrefixOperator, expression: ArcIntern<Expression>) -> Self {
211 Self {
212 operator,
213 expression,
214 }
215 }
216}
217
218impl PartialEq for Expression {
220 fn eq(&self, other: &Self) -> bool {
222 match (self, other) {
223 (Self::Address(left), Self::Address(right)) => left == right,
224 (Self::Infix(left), Self::Infix(right)) => left == right,
225 (Self::Number(left), Self::Number(right)) => {
226 floating_point_eq::complex64::eq(*left, *right)
227 }
228 (Self::Prefix(left), Self::Prefix(right)) => left == right,
229 (Self::FunctionCall(left), Self::FunctionCall(right)) => left == right,
230 (Self::Variable(left), Self::Variable(right)) => left == right,
231 (Self::PiConstant(), Self::PiConstant()) => true,
232
233 (
236 Self::Address(_)
237 | Self::Infix(_)
238 | Self::Number(_)
239 | Self::Prefix(_)
240 | Self::FunctionCall(_)
241 | Self::Variable(_)
242 | Self::PiConstant(),
243 _,
244 ) => false,
245 }
246 }
247}
248
// `Eq` cannot be derived because `Complex64` (two `f64`s) is not `Eq`. The hand-written
// `PartialEq` compares numbers through `floating_point_eq::complex64::eq`, which is what
// this marker relies on.
// NOTE(review): presumably that comparison is reflexive even for NaN — confirm in
// `floating_point_eq`.
impl Eq for Expression {}
251
252impl Hash for Expression {
253 fn hash<H: Hasher>(&self, state: &mut H) {
255 match self {
256 Self::Address(m) => {
257 "Address".hash(state);
258 m.hash(state);
259 }
260 Self::FunctionCall(FunctionCallExpression {
261 function,
262 expression,
263 }) => {
264 "FunctionCall".hash(state);
265 function.hash(state);
266 expression.hash(state);
267 }
268 Self::Infix(InfixExpression {
269 left,
270 operator,
271 right,
272 }) => {
273 "Infix".hash(state);
274 operator.hash(state);
275 left.hash(state);
276 right.hash(state);
277 }
278 Self::Number(n) => {
279 "Number".hash(state);
280 floating_point_eq::complex64::hash(*n, state);
281 }
282 Self::PiConstant() => {
283 "PiConstant()".hash(state);
284 }
285 Self::Prefix(p) => {
286 "Prefix".hash(state);
287 p.operator.hash(state);
288 p.expression.hash(state);
289 }
290 Self::Variable(v) => {
291 "Variable".hash(state);
292 v.hash(state);
293 }
294 }
295 }
296}
297
298macro_rules! impl_expr_op {
299 ($name:ident, $name_assign:ident, $function:ident, $function_assign:ident, $operator:ident) => {
300 impl $name for Expression {
301 type Output = Self;
302 fn $function(self, other: Self) -> Self {
303 Self::Infix(InfixExpression {
304 left: ArcIntern::new(self),
305 operator: InfixOperator::$operator,
306 right: ArcIntern::new(other),
307 })
308 }
309 }
310
311 impl $name<ArcIntern<Expression>> for Expression {
312 type Output = Self;
313 fn $function(self, other: ArcIntern<Expression>) -> Self {
314 Self::Infix(InfixExpression {
315 left: ArcIntern::new(self),
316 operator: InfixOperator::$operator,
317 right: other,
318 })
319 }
320 }
321
322 impl $name<Expression> for ArcIntern<Expression> {
323 type Output = Expression;
324 fn $function(self, other: Expression) -> Expression {
325 Expression::Infix(InfixExpression {
326 left: self,
327 operator: InfixOperator::$operator,
328 right: ArcIntern::new(other),
329 })
330 }
331 }
332
333 impl $name_assign for Expression {
334 fn $function_assign(&mut self, other: Self) {
335 let temp = ::std::mem::replace(self, Self::PiConstant());
337 *self = temp.$function(other);
338 }
339 }
340
341 impl $name_assign<ArcIntern<Expression>> for Expression {
342 fn $function_assign(&mut self, other: ArcIntern<Expression>) {
343 let temp = ::std::mem::replace(self, Self::PiConstant());
345 *self = temp.$function(other);
346 }
347 }
348 };
349}
350
351impl_expr_op!(BitXor, BitXorAssign, bitxor, bitxor_assign, Caret);
352impl_expr_op!(Add, AddAssign, add, add_assign, Plus);
353impl_expr_op!(Sub, SubAssign, sub, sub_assign, Minus);
354impl_expr_op!(Mul, MulAssign, mul, mul_assign, Star);
355impl_expr_op!(Div, DivAssign, div, div_assign, Slash);
356
357impl Neg for Expression {
358 type Output = Self;
359
360 fn neg(self) -> Self {
361 Expression::Prefix(PrefixExpression {
362 operator: PrefixOperator::Minus,
363 expression: ArcIntern::new(self),
364 })
365 }
366}
367
368#[inline]
370pub(crate) fn calculate_infix(
371 left: Complex64,
372 operator: InfixOperator,
373 right: Complex64,
374) -> Complex64 {
375 use InfixOperator::*;
376 match operator {
377 Caret => left.powc(right),
378 Plus => left + right,
379 Minus => left - right,
380 Slash => left / right,
381 Star => left * right,
382 }
383}
384
385#[inline]
387pub(crate) fn calculate_function(function: ExpressionFunction, argument: Complex64) -> Complex64 {
388 use ExpressionFunction::*;
389 match function {
390 Sine => argument.sin(),
391 Cis => argument.cos() + imag!(1f64) * argument.sin(),
392 Cosine => argument.cos(),
393 Exponent => argument.exp(),
394 SquareRoot => argument.sqrt(),
395 }
396}
397
/// Returns `true` when `x` lies strictly within `1e-16` of zero.
///
/// Used by `Expression::to_real` to decide whether an imaginary part is negligible.
/// NaN and the infinities are never small.
#[inline(always)]
fn is_small(x: f64) -> bool {
    const THRESHOLD: f64 = 1e-16;
    -THRESHOLD < x && x < THRESHOLD
}
403
404impl Expression {
405 pub fn simplify(&mut self) {
420 match self {
421 Expression::Address(_) | Expression::Number(_) | Expression::Variable(_) => {}
422 Expression::PiConstant() => {
423 *self = Expression::Number(Complex64::from(PI));
424 }
425 _ => *self = simplification::run(self),
426 }
427 }
428
429 pub fn into_simplified(mut self) -> Self {
443 self.simplify();
444 self
445 }
446
447 pub fn evaluate<K1, K2>(
471 &self,
472 variables: &HashMap<K1, Complex64>,
473 memory_references: &HashMap<K2, Vec<f64>>,
474 ) -> Result<Complex64, EvaluationError>
475 where
476 K1: Borrow<str> + Hash + Eq,
477 K2: Borrow<str> + Hash + Eq,
478 {
479 use Expression::*;
480
481 match self {
482 FunctionCall(FunctionCallExpression {
483 function,
484 expression,
485 }) => {
486 let evaluated = expression.evaluate(variables, memory_references)?;
487 Ok(calculate_function(*function, evaluated))
488 }
489 Infix(InfixExpression {
490 left,
491 operator,
492 right,
493 }) => {
494 let left_evaluated = left.evaluate(variables, memory_references)?;
495 let right_evaluated = right.evaluate(variables, memory_references)?;
496 Ok(calculate_infix(left_evaluated, *operator, right_evaluated))
497 }
498 Prefix(PrefixExpression {
499 operator,
500 expression,
501 }) => {
502 use PrefixOperator::*;
503 let value = expression.evaluate(variables, memory_references)?;
504 if matches!(operator, Minus) {
505 Ok(-value)
506 } else {
507 Ok(value)
508 }
509 }
510 Variable(identifier) => match variables.get(identifier) {
511 Some(&value) => Ok(value),
512 None => Err(EvaluationError::Incomplete),
513 },
514 Address(memory_reference) => memory_references
515 .get(memory_reference.name.as_str())
516 .and_then(|values| {
517 let value = values.get(memory_reference.index as usize)?;
518 Some(real!(*value))
519 })
520 .ok_or(EvaluationError::Incomplete),
521 PiConstant() => Ok(real!(PI)),
522 Number(number) => Ok(*number),
523 }
524 }
525
526 #[must_use]
546 pub fn substitute_variables<K>(&self, variable_values: &HashMap<K, Expression>) -> Self
547 where
548 K: Borrow<str> + Hash + Eq,
549 {
550 use Expression::*;
551
552 match self {
553 FunctionCall(FunctionCallExpression {
554 function,
555 expression,
556 }) => Expression::FunctionCall(FunctionCallExpression {
557 function: *function,
558 expression: expression.substitute_variables(variable_values).into(),
559 }),
560 Infix(InfixExpression {
561 left,
562 operator,
563 right,
564 }) => {
565 let left = left.substitute_variables(variable_values).into();
566 let right = right.substitute_variables(variable_values).into();
567 Infix(InfixExpression {
568 left,
569 operator: *operator,
570 right,
571 })
572 }
573 Prefix(PrefixExpression {
574 operator,
575 expression,
576 }) => Prefix(PrefixExpression {
577 operator: *operator,
578 expression: expression.substitute_variables(variable_values).into(),
579 }),
580 Variable(identifier) => match variable_values.get(identifier) {
581 Some(value) => value.clone(),
582 None => Variable(identifier.clone()),
583 },
584 other => other.clone(),
585 }
586 }
587}
588
589#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
590#[cfg_attr(feature = "python", pyo3::pymethods)]
591impl Expression {
592 pub fn to_real(&self) -> Result<f64, EvaluationError> {
595 match self {
596 Expression::PiConstant() => Ok(PI),
597 Expression::Number(x) if is_small(x.im) => Ok(x.re),
598 Expression::Number(_) => Err(EvaluationError::NumberNotReal),
599 _ => Err(EvaluationError::NotANumber),
600 }
601 }
602}
603
604impl FromStr for Expression {
605 type Err = ParseProgramError<Self>;
606
607 fn from_str(s: &str) -> Result<Self, Self::Err> {
608 let input = LocatedSpan::new(s);
609 let tokens = lex(input)?;
610 disallow_leftover(parse_expression(&tokens).map_err(ParseError::from_nom_internal_err))
611 }
612}
613
614static FORMAT_REAL_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
615 WriteFloatOptions::builder()
616 .negative_exponent_break(NonZeroI32::new(-5))
617 .positive_exponent_break(NonZeroI32::new(15))
618 .trim_floats(true)
619 .build()
620 .expect("options are valid")
621});
622
623static FORMAT_IMAGINARY_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
624 WriteFloatOptions::builder()
625 .negative_exponent_break(NonZeroI32::new(-5))
626 .positive_exponent_break(NonZeroI32::new(15))
627 .trim_floats(false) .build()
629 .expect("options are valid")
630});
631
632#[inline(always)]
639pub(crate) fn format_complex(value: &Complex64) -> String {
640 const FORMAT: u128 = format::STANDARD;
641 if value.re == 0f64 && value.im == 0f64 {
642 "0".to_owned()
643 } else if value.im == 0f64 {
644 to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS)
645 } else if value.re == 0f64 {
646 to_string_with_options::<_, FORMAT>(value.im, &FORMAT_IMAGINARY_OPTIONS) + "i"
647 } else {
648 let mut out = to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS);
649 if value.im > 0f64 {
650 out.push('+')
651 }
652 out.push_str(&to_string_with_options::<_, FORMAT>(
653 value.im,
654 &FORMAT_IMAGINARY_OPTIONS,
655 ));
656 out.push('i');
657 out
658 }
659}
660
661impl Quil for Expression {
662 fn write(
663 &self,
664 f: &mut impl std::fmt::Write,
665 fall_back_to_debug: bool,
666 ) -> Result<(), crate::quil::ToQuilError> {
667 use Expression::*;
668 match self {
669 Address(memory_reference) => memory_reference.write(f, fall_back_to_debug),
670 FunctionCall(FunctionCallExpression {
671 function,
672 expression,
673 }) => {
674 write!(f, "{function}(")?;
675 expression.write(f, fall_back_to_debug)?;
676 write!(f, ")")?;
677 Ok(())
678 }
679 Infix(InfixExpression {
680 left,
681 operator,
682 right,
683 }) => {
684 format_inner_expression(f, fall_back_to_debug, left)?;
685 write!(f, "{operator}")?;
686 format_inner_expression(f, fall_back_to_debug, right)
687 }
688 Number(value) => write!(f, "{}", format_complex(value)).map_err(Into::into),
689 PiConstant() => write!(f, "pi").map_err(Into::into),
690 Prefix(PrefixExpression {
691 operator,
692 expression,
693 }) => {
694 write!(f, "{operator}")?;
695 format_inner_expression(f, fall_back_to_debug, expression)
696 }
697 Variable(identifier) => write!(f, "%{identifier}").map_err(Into::into),
698 }
699 }
700}
701
702fn format_inner_expression(
705 f: &mut impl std::fmt::Write,
706 fall_back_to_debug: bool,
707 expression: &Expression,
708) -> crate::quil::ToQuilResult<()> {
709 match expression {
710 Expression::Infix(InfixExpression {
711 left,
712 operator,
713 right,
714 }) => {
715 write!(f, "(")?;
716 format_inner_expression(f, fall_back_to_debug, left)?;
717 write!(f, "{operator}")?;
718 format_inner_expression(f, fall_back_to_debug, right)?;
719 write!(f, ")")?;
720 Ok(())
721 }
722 _ => expression.write(f, fall_back_to_debug),
723 }
724}
725
#[cfg(test)]
mod test {
    use crate::{expression::Expression, quil::Quil, real};

    /// `(-3) * (pi / 2)` built with the operator impls: the top-level product is printed
    /// bare, the nested quotient is parenthesized, and the negation sits on its operand.
    #[test]
    fn formats_nested_expression() {
        let minus_three = -Expression::Number(real!(3f64));
        let half_pi = Expression::PiConstant() / Expression::Number(real!(2f64));
        let expression = minus_three * half_pi;

        assert_eq!(expression.to_quil_or_debug(), "-3*(pi/2)");
    }
}
756
/// A built-in function usable in a Quil expression (written `cis`, `cos`, `exp`, `sin`
/// and `sqrt` in Quil).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        module = "quil.expression",
        eq,
        frozen,
        hash,
        str,
        rename_all = "SCREAMING_SNAKE_CASE"
    )
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum ExpressionFunction {
    /// `cis(x) = cos(x) + i·sin(x)`.
    Cis,
    /// Cosine.
    Cosine,
    /// The exponential function `e^x`.
    Exponent,
    /// Sine.
    Sine,
    /// Square root, as computed by `Complex64::sqrt`.
    SquareRoot,
}
779
780impl fmt::Display for ExpressionFunction {
781 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
782 use ExpressionFunction::*;
783 write!(
784 f,
785 "{}",
786 match self {
787 Cis => "cis",
788 Cosine => "cos",
789 Exponent => "exp",
790 Sine => "sin",
791 SquareRoot => "sqrt",
792 }
793 )
794 }
795}
796
/// A unary operator of a [`PrefixExpression`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        module = "quil.expression",
        eq,
        frozen,
        hash,
        str,
        rename_all = "SCREAMING_SNAKE_CASE"
    )
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum PrefixOperator {
    /// Unary plus; evaluates to its operand unchanged.
    Plus,
    /// Negation.
    Minus,
}
815
816impl fmt::Display for PrefixOperator {
817 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
818 use PrefixOperator::*;
819 write!(
820 f,
821 "{}",
822 match self {
823 Plus => "",
825 Minus => "-",
826 }
827 )
828 }
829}
830
/// A binary operator of an [`InfixExpression`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        module = "quil.expression",
        eq,
        frozen,
        hash,
        str,
        rename_all = "SCREAMING_SNAKE_CASE"
    )
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum InfixOperator {
    /// Exponentiation, `^`.
    Caret,
    /// Addition, `+`.
    Plus,
    /// Subtraction, `-`.
    Minus,
    /// Division, `/`.
    Slash,
    /// Multiplication, `*`.
    Star,
}
852
853impl fmt::Display for InfixOperator {
854 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
855 use InfixOperator::*;
856 write!(
857 f,
858 "{}",
859 match self {
860 Caret => "^",
861 Plus => "+",
862 Minus => " - ",
864 Slash => "/",
865 Star => "*",
866 }
867 )
868 }
869}
870
/// Convenience constructors that return [`Expression`]s already wrapped in
/// [`ArcIntern<Expression>`], so nested trees can be built without repeated
/// `ArcIntern::new(...)` calls.
pub mod interned {
    use super::*;

    // `atoms!` generates, for each `name: Variant(Type)` entry, a function
    // `name(value: Type) -> ArcIntern<Expression>` that interns `Expression::Variant(value)`.
    // An entry with no type, e.g. `pi: PiConstant()`, yields a zero-argument function.
    macro_rules! atoms {
        ($($func:ident: $ctor:ident($($typ:ty)?)),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "A wrapper around [`Expression::",
                    stringify!($ctor),
                    "`] that returns an [`ArcIntern<Expression>`]."
                )]
                #[inline(always)]
                pub fn $func($(value: $typ)?) -> ArcIntern<Expression> {
                    // `identity::<$typ>` exists only so that `$typ` appears inside this optional
                    // repetition. `macro_rules!` needs a repeating metavariable there to decide
                    // whether to emit `value` as the constructor argument.
                    ArcIntern::new(Expression::$ctor($(std::convert::identity::<$typ>(value))?))
                }
            )+
        };
    }

    // `expression_wrappers!` generates, for each entry, a function that takes the fields of
    // the payload struct as separate arguments. It builds the struct and forwards it to the
    // named `atoms!`-generated function, e.g. `infix(left, operator, right)` calls
    // `infix_expr(InfixExpression { left, operator, right })`.
    macro_rules! expression_wrappers {
        ($($func:ident: $atom:ident($ctor:ident { $($field:ident: $field_ty:ty),*$(,)? })),+ $(,)?) => {
            paste::paste! { $(
                #[doc = concat!(
                    "A wrapper around [`Expression::", stringify!([<$func:camel>]), "`] ",
                    "that takes the contents of the inner expression type as arguments directly ",
                    "and returns an [`ArcIntern<Expression>`].",
                    "\n\n",
                    "See also [`", stringify!($atom), "`].",
                )]
                #[inline(always)]
                pub fn $func($($field: $field_ty),*) -> ArcIntern<Expression> {
                    $atom($ctor { $($field),* })
                }
            )+ }
        };
    }

    // `function_wrappers!` generates one single-argument function per `ExpressionFunction`
    // variant, delegating to `function_call`.
    macro_rules! function_wrappers {
        ($($func:ident: $ctor:ident),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "Create an <code>[ArcIntern]<[Expression]></code> representing ",
                    "`", stringify!($func), "(expression)`.",
                    "\n\n",
                    "A wrapper around [`Expression::FunctionCall`] with ",
                    "[`ExpressionFunction::", stringify!($ctor), "`].",
                )]
                #[inline(always)]
                pub fn $func(expression: ArcIntern<Expression>) -> ArcIntern<Expression> {
                    function_call(ExpressionFunction::$ctor, expression)
                }
            )+
        };
    }

    // `infix_wrappers!` generates one two-argument function per `InfixOperator` variant,
    // delegating to `infix`. The `($op)` token is used only in the generated doc text.
    macro_rules! infix_wrappers {
        ($($func:ident: $ctor:ident ($op:tt)),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "Create an <code>[ArcIntern]<[Expression]></code> representing ",
                    "`left ", stringify!($op), " right`.",
                    "\n\n",
                    "A wrapper around [`Expression::Infix`] with ",
                    "[`InfixOperator::", stringify!($ctor), "`].",
                )]
                #[inline(always)]
                pub fn $func(
                    left: ArcIntern<Expression>,
                    right: ArcIntern<Expression>,
                ) -> ArcIntern<Expression> {
                    infix(left, InfixOperator::$ctor, right)
                }
            )+
        };
    }

    // `prefix_wrappers!` generates one single-argument function per `PrefixOperator` variant,
    // delegating to `prefix`. The `($op)` token is used only in the generated doc text.
    macro_rules! prefix_wrappers {
        ($($func:ident: $ctor:ident ($op:tt)),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "Create an <code>[ArcIntern]<[Expression]></code> representing ",
                    "`", stringify!($op), "expression`.",
                    "\n\n",
                    "A wrapper around [`Expression::Prefix`] with ",
                    "[`PrefixOperator::", stringify!($ctor), "`].",
                )]
                #[inline(always)]
                pub fn $func(expression: ArcIntern<Expression>) -> ArcIntern<Expression> {
                    prefix(PrefixOperator::$ctor, expression)
                }
            )+
        };
    }

    // One interning constructor per `Expression` variant.
    atoms! {
        address: Address(MemoryReference),
        function_call_expr: FunctionCall(FunctionCallExpression),
        infix_expr: Infix(InfixExpression),
        pi: PiConstant(),
        number: Number(Complex64),
        prefix_expr: Prefix(PrefixExpression),
        variable: Variable(String),
    }

    // Field-wise constructors for the compound variants.
    expression_wrappers! {
        function_call: function_call_expr(FunctionCallExpression {
            function: ExpressionFunction,
            expression: ArcIntern<Expression>,
        }),

        infix: infix_expr(InfixExpression {
            left: ArcIntern<Expression>,
            operator: InfixOperator,
            right: ArcIntern<Expression>,
        }),

        prefix: prefix_expr(PrefixExpression {
            operator: PrefixOperator,
            expression: ArcIntern<Expression>,
        }),
    }

    function_wrappers! {
        cis: Cis,
        cos: Cosine,
        exp: Exponent,
        sin: Sine,
        sqrt: SquareRoot,
    }

    infix_wrappers! {
        add: Plus (+),
        sub: Minus (-),
        mul: Star (*),
        div: Slash (/),
        pow: Caret (^),
    }

    prefix_wrappers! {
        unary_plus: Plus (+),
        neg: Minus (-),
    }
}
1018
/// `proptest` strategies for generating names, complex numbers and random expression trees.
#[cfg(test)]
pub mod proptest_helpers {
    use super::*;

    use std::f64::consts::TAU;

    use proptest::prelude::*;

    use crate::reserved::ReservedToken;

    /// Generates names for `Expression::Variable`.
    ///
    /// A name starts with an ASCII letter or `_`, may continue with letters, digits, `_` or
    /// `-`, and never ends in `-`. Strings that parse as a [`ReservedToken`] are rejected.
    pub fn arb_name() -> impl Strategy<Value = String> {
        r"[A-Za-z_]([A-Za-z0-9_-]*[A-Za-z0-9_])?".prop_filter("Exclude reserved tokens", |t| {
            ReservedToken::from_str(t).is_err()
        })
    }

    /// Generates complex numbers whose real and imaginary parts are each drawn from the
    /// half-open range `-TAU..TAU`.
    pub fn arb_complex64() -> impl Strategy<Value = Complex64> {
        ((-TAU..TAU), (-TAU..TAU)).prop_map(|(re, im)| Complex64 { re, im })
    }

    /// Restricts `strat` to expressions whose simplified form is not equal (under
    /// `Expression`'s `PartialEq`) to the number `0`, so they can be used as divisors.
    pub fn arb_expr_nonzero(
        strat: impl Strategy<Value = Expression>,
    ) -> impl Strategy<Value = Expression> {
        strat.prop_filter("Exclude constantly-zero expressions", |expr| {
            expr.clone().into_simplified() != Expression::Number(Complex64::new(0.0, 0.0))
        })
    }

    /// Generates random expression trees.
    ///
    /// Leaves come from the supplied strategy factories (memory references, variable names,
    /// complex numbers), plus `pi`. Leaves are combined through function calls, infix
    /// operators and prefix operators. The right operand of `/` is filtered with
    /// [`arb_expr_nonzero`].
    pub fn arb_expr_custom_leaves<
        MemRefStrat: Strategy<Value = MemoryReference> + 'static,
        VariableStrat: Strategy<Value = String> + 'static,
        ComplexStrat: Strategy<Value = Complex64> + 'static,
    >(
        mut arb_memory_reference: impl FnMut() -> MemRefStrat,
        mut arb_variable: impl FnMut() -> VariableStrat,
        mut arb_complex64: impl FnMut() -> ComplexStrat,
    ) -> impl Strategy<Value = Expression> {
        use Expression::*;
        let leaf = prop_oneof![
            arb_memory_reference().prop_map(Address),
            arb_complex64().prop_map(Number),
            Just(PiConstant()),
            arb_variable().prop_map(Variable),
        ];
        leaf.prop_recursive(
            // Arguments: depth, desired_size, expected_branch_size (see `prop_recursive`).
            4, 64, 16, |expr| {
                // A second handle on the recursive strategy, moved into the `prop_flat_map`
                // closure below to generate right-hand operands.
                let inner = expr.clone();
                prop_oneof![
                    (any::<ExpressionFunction>(), expr.clone()).prop_map(|(function, e)| {
                        Expression::FunctionCall(FunctionCallExpression {
                            function,
                            expression: ArcIntern::new(e),
                        })
                    }),
                    (expr.clone(), any::<InfixOperator>())
                        .prop_flat_map(move |(left, operator)| {
                            (
                                Just(left),
                                Just(operator),
                                if let InfixOperator::Slash = operator {
                                    // Avoid divisors that simplify to a constant zero.
                                    arb_expr_nonzero(inner.clone()).boxed()
                                } else {
                                    inner.clone().boxed()
                                },
                            )
                        })
                        .prop_map(|(l, operator, r)| {
                            Infix(InfixExpression {
                                left: ArcIntern::new(l),
                                operator,
                                right: ArcIntern::new(r),
                            })
                        }),
                    (any::<PrefixOperator>(), expr).prop_map(|(operator, e)| {
                        Prefix(PrefixExpression {
                            operator,
                            expression: ArcIntern::new(e),
                        })
                    }),
                ]
            },
        )
    }
}
1110
1111#[cfg(test)]
1112mod tests {
1113 use super::*;
1114
1115 use std::collections::{hash_map::DefaultHasher, HashSet};
1116
1117 use proptest::prelude::*;
1118
1119 use super::proptest_helpers::*;
1120
/// Hashes `t` with a freshly created `DefaultHasher` and returns the digest.
#[inline]
fn hash_to_u64<T: Hash>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(t, &mut hasher);
    hasher.finish()
}
1128
#[test]
fn simplify_and_evaluate() {
    use Expression::*;

    // Each case is first evaluated against its bindings, then simplified in place and
    // compared with the expected simplified form.
    struct TestCase<'a> {
        expression: Expression,
        variables: &'a HashMap<String, Complex64>,
        memory_references: &'a HashMap<&'a str, Vec<f64>>,
        simplified: Expression,
        evaluated: Result<Complex64, EvaluationError>,
    }

    let one = real!(1.0);

    let no_variables = HashMap::new();
    let variables = HashMap::from([
        ("foo".to_owned(), real!(10f64)),
        ("bar".to_owned(), real!(100f64)),
    ]);

    let no_memory = HashMap::new();
    let memory_references = HashMap::from([("theta", vec![1.0, 2.0]), ("beta", vec![3.0, 4.0])]);

    let cases: Vec<TestCase> = vec![
        TestCase {
            expression: Number(one),
            variables: &no_variables,
            memory_references: &no_memory,
            simplified: Number(one),
            evaluated: Ok(one),
        },
        TestCase {
            expression: Expression::Prefix(PrefixExpression {
                operator: PrefixOperator::Minus,
                expression: ArcIntern::new(Number(real!(1f64))),
            }),
            variables: &no_variables,
            memory_references: &no_memory,
            simplified: Number(real!(-1f64)),
            evaluated: Ok(real!(-1f64)),
        },
        TestCase {
            expression: Expression::Variable("foo".to_owned()),
            variables: &variables,
            memory_references: &no_memory,
            simplified: Expression::Variable("foo".to_owned()),
            evaluated: Ok(real!(10f64)),
        },
        TestCase {
            expression: Expression::from_str("%foo + %bar").unwrap(),
            variables: &variables,
            memory_references: &no_memory,
            simplified: Expression::from_str("%foo + %bar").unwrap(),
            evaluated: Ok(real!(110f64)),
        },
        TestCase {
            expression: Expression::FunctionCall(FunctionCallExpression {
                function: ExpressionFunction::Sine,
                expression: ArcIntern::new(Expression::Number(real!(PI / 2f64))),
            }),
            variables: &variables,
            memory_references: &no_memory,
            simplified: Number(real!(1f64)),
            evaluated: Ok(real!(1f64)),
        },
        TestCase {
            expression: Expression::from_str("theta[1] * beta[0]").unwrap(),
            variables: &no_variables,
            memory_references: &memory_references,
            simplified: Expression::from_str("theta[1] * beta[0]").unwrap(),
            evaluated: Ok(real!(6.0)),
        },
    ];

    for TestCase {
        mut expression,
        variables: bindings,
        memory_references: memory,
        simplified,
        evaluated,
    } in cases
    {
        assert_eq!(expression.evaluate(bindings, memory), evaluated);

        expression.simplify();
        assert_eq!(expression, simplified);
    }
}
1215
1216 fn parenthesized(expression: &Expression) -> String {
1218 use Expression::*;
1219 match expression {
1220 Address(memory_reference) => memory_reference.to_quil_or_debug(),
1221 FunctionCall(FunctionCallExpression {
1222 function,
1223 expression,
1224 }) => format!("({function}({}))", parenthesized(expression)),
1225 Infix(InfixExpression {
1226 left,
1227 operator,
1228 right,
1229 }) => format!(
1230 "({}{}{})",
1231 parenthesized(left),
1232 operator,
1233 parenthesized(right)
1234 ),
1235 Number(value) => format!("({})", format_complex(value)),
1236 PiConstant() => "pi".to_string(),
1237 Prefix(PrefixExpression {
1238 operator,
1239 expression,
1240 }) => format!("({}{})", operator, parenthesized(expression)),
1241 Variable(identifier) => format!("(%{identifier})"),
1242 }
1243 }
1244
1245 proptest! {
1246 #[test]
1247 fn eq(a: f64, b: f64) {
1248 let first = Expression::Infix (InfixExpression {
1249 left: ArcIntern::new(Expression::Number(real!(a))),
1250 operator: InfixOperator::Plus,
1251 right: ArcIntern::new(Expression::Number(real!(b))),
1252 } );
1253 let differing = Expression::Number(real!(a + b));
1254 prop_assert_eq!(&first, &first);
1255 prop_assert_ne!(&first, &differing);
1256 }
1257
1258 #[test]
1259 fn hash(a: f64, b: f64) {
1260 let first = Expression::Infix (InfixExpression {
1261 left: ArcIntern::new(Expression::Number(real!(a))),
1262 operator: InfixOperator::Plus,
1263 right: ArcIntern::new(Expression::Number(real!(b))),
1264 });
1265 let matching = first.clone();
1266 let differing = Expression::Number(real!(a + b));
1267 let mut set = HashSet::new();
1268 set.insert(first);
1269 assert!(set.contains(&matching));
1270 assert!(!set.contains(&differing))
1271 }
1272
1273 #[test]
1274 fn eq_iff_hash_eq(x: Expression, y: Expression) {
1275 prop_assert_eq!(x == y, hash_to_u64(&x) == hash_to_u64(&y));
1276 }
1277
1278 #[test]
1279 fn reals_are_real(x: f64) {
1280 prop_assert_eq!(Expression::Number(real!(x)).to_real(), Ok(x))
1281 }
1282
1283 #[test]
1284 fn some_nums_are_real(re: f64, im: f64) {
1285 let result = Expression::Number(Complex64{re, im}).to_real();
1286 if is_small(im) {
1287 prop_assert_eq!(result, Ok(re))
1288 } else {
1289 prop_assert_eq!(result, Err(EvaluationError::NumberNotReal))
1290 }
1291 }
1292
1293 #[test]
1294 fn no_other_exps_are_real(expr in any::<Expression>().prop_filter("Not numbers", |e| !matches!(e, Expression::Number(_) | Expression::PiConstant()))) {
1295 prop_assert_eq!(expr.to_real(), Err(EvaluationError::NotANumber))
1296 }
1297
1298 #[test]
1299 fn complexes_are_parseable_as_expressions(value in arb_complex64()) {
1300 let parsed = Expression::from_str(&format_complex(&value));
1301 assert!(parsed.is_ok());
1302 let simple = parsed.unwrap().into_simplified();
1303 assert_eq!(Expression::Number(value), simple);
1304 }
1305
1306 #[test]
1307 fn exponentiation_works_as_expected(left: Expression, right: Expression) {
1308 let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Caret, right: ArcIntern::new(right.clone()) } );
1309 prop_assert_eq!(left ^ right, expected);
1310 }
1311
1312 #[test]
1313 fn in_place_exponentiation_works_as_expected(left: Expression, right: Expression) {
1314 let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Caret, right: ArcIntern::new(right.clone()) } );
1315 let mut x = left;
1316 x ^= right;
1317 prop_assert_eq!(x, expected);
1318 }
1319
1320 #[test]
1321 fn addition_works_as_expected(left: Expression, right: Expression) {
1322 let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Plus, right: ArcIntern::new(right.clone()) } );
1323 prop_assert_eq!(left + right, expected);
1324 }
1325
1326 #[test]
1327 fn in_place_addition_works_as_expected(left: Expression, right: Expression) {
1328 let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Plus, right: ArcIntern::new(right.clone()) } );
1329 let mut x = left;
1330 x += right;
1331 prop_assert_eq!(x, expected);
1332 }
1333
// `-` on two expressions should build an unsimplified `Minus` infix node whose operands are
// exactly the two inputs.
#[test]
fn subtraction_works_as_expected(left: Expression, right: Expression) {
    let actual = left.clone() - right.clone();
    let expected = Expression::Infix(InfixExpression {
        left: ArcIntern::new(left),
        operator: InfixOperator::Minus,
        right: ArcIntern::new(right),
    });
    prop_assert_eq!(actual, expected);
}
1339
// `-=` should replace the target with an unsimplified `Minus` infix node built from the
// original target and the right-hand operand.
#[test]
fn in_place_subtraction_works_as_expected(left: Expression, right: Expression) {
    let mut accumulated = left.clone();
    accumulated -= right.clone();
    let expected = Expression::Infix(InfixExpression {
        left: ArcIntern::new(left),
        operator: InfixOperator::Minus,
        right: ArcIntern::new(right),
    });
    prop_assert_eq!(accumulated, expected);
}
1347
// `*` on two expressions should build an unsimplified `Star` infix node whose operands are
// exactly the two inputs.
#[test]
fn multiplication_works_as_expected(left: Expression, right: Expression) {
    let actual = left.clone() * right.clone();
    let expected = Expression::Infix(InfixExpression {
        left: ArcIntern::new(left),
        operator: InfixOperator::Star,
        right: ArcIntern::new(right),
    });
    prop_assert_eq!(actual, expected);
}
1353
// `*=` should replace the target with an unsimplified `Star` infix node built from the
// original target and the right-hand operand.
#[test]
fn in_place_multiplication_works_as_expected(left: Expression, right: Expression) {
    let mut accumulated = left.clone();
    accumulated *= right.clone();
    let expected = Expression::Infix(InfixExpression {
        left: ArcIntern::new(left),
        operator: InfixOperator::Star,
        right: ArcIntern::new(right),
    });
    prop_assert_eq!(accumulated, expected);
}
1361
1362
1363 #[test]
1365 fn division_works_as_expected(left: Expression, right in arb_expr_nonzero(any::<Expression>())) {
1366 let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Slash, right: ArcIntern::new(right.clone()) } );
1367 prop_assert_eq!(left / right, expected);
1368 }
1369
1370 #[test]
1372 fn in_place_division_works_as_expected(left: Expression, right in arb_expr_nonzero(any::<Expression>())) {
1373 let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Slash, right: ArcIntern::new(right.clone()) } );
1374 let mut x = left;
1375 x /= right;
1376 prop_assert_eq!(x, expected);
1377 }
1378
// NOTE(review): both clones below are needed because `e` and `p` are used again in the
// failure messages; the allow presumably silences a clippy false positive on them. Confirm.
#[allow(clippy::redundant_clone)]
#[test]
fn round_trip(e: Expression) {
    // Rendering an expression with `parenthesized` and parsing the text back must give an
    // expression that simplifies to the same result as the original.
    let simple_e = e.clone().into_simplified();
    let s = parenthesized(&e);
    let p = Expression::from_str(&s);
    // Report the rendered text and the parser error, not just "assertion failed".
    prop_assert!(
        p.is_ok(),
        "failed to parse {:?} (parenthesized from {:?}): {:?}",
        s,
        e,
        p
    );
    let p = p.unwrap();
    let simple_p = p.clone().into_simplified();

    prop_assert_eq!(
        &simple_p,
        &simple_e,
        "Simplified expressions should be equal:\nparenthesized {p} ({p:?}) extracted from {s} simplified to {simple_p}\nvs original {e} ({e:?}) simplified to {simple_e}",
        p=p.to_quil_or_debug(),
        s=s,
        e=e.to_quil_or_debug(),
        simple_p=simple_p.to_quil_or_debug(),
        simple_e=simple_e.to_quil_or_debug()
    );
}
1401
1402 }
1403
// Each input is already in canonical printed form, so parsing it and printing it again
// with `to_quil_or_debug` must reproduce the exact same text.
#[test]
fn specific_round_trip_tests() {
    for input in &[
        "-1*(phases[0]+phases[1])",
        "(-1*(phases[0]+phases[1]))+(-1*(phases[0]+phases[1]))",
    ] {
        // Name the failing input and the parser error instead of a bare `unwrap` panic.
        let parsed = Expression::from_str(input)
            .unwrap_or_else(|err| panic!("failed to parse {:?}: {:?}", input, err));
        let restring = parsed.to_quil_or_debug();
        assert_eq!(input, &restring, "round trip of {:?} changed its text", input);
    }
}
1418
// `Expression` equality must treat a NaN number as equal to a copy of itself, unlike
// plain `f64` comparison, where NaN != NaN.
#[test]
fn test_nan_is_equal() {
    let nan = Expression::Number(Complex64::from(f64::NAN));
    let twin = nan.clone();
    assert_eq!(nan, twin);
}
1425
// Parses each input and checks that `into_simplified` folds it to the expected expression.
// This covers `pi` constant folding, a NaN-producing division, and cancellation of a
// `2*pi / 6.283185307179586` factor around a memory reference.
#[test]
fn specific_simplification_tests() {
    for (input, expected) in [
        ("pi", Expression::Number(PI.into())),
        ("pi/2", Expression::Number((PI / 2.0).into())),
        ("pi * pi", Expression::Number((PI.powi(2)).into())),
        ("1.0/(1.0-1.0)", Expression::Number(f64::NAN.into())),
        (
            "(a[0]*2*pi)/6.283185307179586",
            Expression::Address(MemoryReference {
                name: String::from("a"),
                index: 0,
            }),
        ),
    ] {
        // Identify the failing case; otherwise a failure only shows the two values.
        let simplified = Expression::from_str(input)
            .unwrap_or_else(|err| panic!("failed to parse {:?}: {:?}", input, err))
            .into_simplified();
        assert_eq!(simplified, expected, "simplifying {:?}", input);
    }
}
1447
// `to_real` behaves as follows:
// - `pi` and purely real numbers convert successfully.
// - A number with a nonzero imaginary part yields `NumberNotReal`.
// - A non-numeric expression yields `NotANumber`.
#[test]
fn specific_to_real_tests() {
    let cases = [
        (Expression::PiConstant(), Ok(PI)),
        (Expression::Number(Complex64 { re: 1.0, im: 0.0 }), Ok(1.0)),
        (
            Expression::Number(Complex64 { re: 1.0, im: 1.0 }),
            Err(EvaluationError::NumberNotReal),
        ),
        (
            Expression::Variable("Not a number".into()),
            Err(EvaluationError::NotANumber),
        ),
    ];
    for (expression, want) in cases {
        let got = expression.to_real();
        assert_eq!(got, want);
    }
}
1465
// Pins the exact text `format_complex` emits for representative values:
// - every signed zero collapses to "0";
// - a purely real or purely imaginary value omits the other part;
// - the imaginary part carries an explicit sign and an `i` suffix;
// - very large and very small magnitudes use exponent notation.
#[test]
fn specific_format_complex_tests() {
    let cases = [
        (Complex64::new(0.0, 0.0), "0"),
        (Complex64::new(-0.0, 0.0), "0"),
        (Complex64::new(-0.0, -0.0), "0"),
        (Complex64::new(0.0, 1.0), "1.0i"),
        (Complex64::new(1.0, -1.0), "1-1.0i"),
        (Complex64::new(1.234, 0.0), "1.234"),
        (Complex64::new(0.0, 1.234), "1.234i"),
        (Complex64::new(-1.234, 0.0), "-1.234"),
        (Complex64::new(0.0, -1.234), "-1.234i"),
        (Complex64::new(1.234, 5.678), "1.234+5.678i"),
        (Complex64::new(-1.234, 5.678), "-1.234+5.678i"),
        (Complex64::new(1.234, -5.678), "1.234-5.678i"),
        (Complex64::new(-1.234, -5.678), "-1.234-5.678i"),
        (Complex64::new(1e100, 2e-100), "1e100+2.0e-100i"),
    ];
    for (value, expected) in cases {
        assert_eq!(format_complex(&value), expected);
    }
}
1487}