1use crate::{
16 floating_point_eq, imag,
17 instruction::MemoryReference,
18 parser::{lex, parse_expression, ParseError},
19 program::{disallow_leftover, ParseProgramError},
20 quil::Quil,
21 real,
22};
23use internment::ArcIntern;
24use lexical::{format, to_string_with_options, WriteFloatOptions};
25use nom_locate::LocatedSpan;
26use num_complex::Complex64;
27use once_cell::sync::Lazy;
28use std::{
29 borrow::Borrow,
30 collections::HashMap,
31 f64::consts::PI,
32 fmt,
33 hash::{Hash, Hasher},
34 num::NonZeroI32,
35 ops::{
36 Add, AddAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
37 },
38 str::FromStr,
39};
40
41#[cfg(test)]
42use proptest_derive::Arbitrary;
43
44#[cfg(not(feature = "python"))]
45use optipy::strip_pyo3;
46#[cfg(feature = "stubs")]
47use pyo3_stub_gen::derive::{
48 gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pyclass_enum, gen_stub_pymethods,
49};
50#[cfg(feature = "python")]
51pub(crate) mod quilpy;
52
53mod simplification;
54
/// Errors produced when evaluating an [`Expression`] or converting one to a real number.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum EvaluationError {
    /// Returned by [`Expression::evaluate`] when a variable or memory reference used by the
    /// expression has no supplied value.
    #[error("There wasn't enough information to completely evaluate the expression.")]
    Incomplete,
    /// Returned by [`Expression::to_real`] when a number has a non-negligible imaginary part.
    #[error("The operation expected a real number but received a complex one.")]
    NumberNotReal,
    /// Returned by [`Expression::to_real`] when the expression is neither a number nor `pi`.
    #[error("The operation expected a number but received a different type of expression.")]
    NotANumber,
}
65
/// An arithmetic expression that can be parsed from (via `FromStr`) and written as Quil text.
///
/// Equality and hashing are implemented by hand rather than derived, so that `Number` payloads
/// are compared and hashed through the `floating_point_eq::complex64` helpers.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_complex_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub enum Expression {
    /// A reference into classical memory, e.g. `theta[1]`.
    Address(MemoryReference),
    /// A call of a built-in [`ExpressionFunction`] on one argument, e.g. `sin(%x)`.
    FunctionCall(FunctionCallExpression),
    /// A binary operation, e.g. `%a*%b`.
    Infix(InfixExpression),
    /// A complex numeric literal.
    Number(Complex64),
    /// The constant pi (written as `pi`), exposed to Python under the name `Pi`.
    // NOTE(review): declared as an empty tuple variant rather than a unit variant — presumably
    // for pyo3 complex-enum compatibility; confirm.
    #[pyo3(name = "Pi")]
    PiConstant(),
    /// A unary operation, e.g. `-%x`.
    Prefix(PrefixExpression),
    /// A named variable, written as `%name`.
    Variable(String),
}
108
/// The payload of [`Expression::FunctionCall`]: a built-in function applied to one argument.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash, subclass)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub struct FunctionCallExpression {
    /// The function being applied.
    #[pyo3(get)]
    pub function: ExpressionFunction,
    /// The function's single argument.
    pub expression: ArcIntern<Expression>,
}

impl FunctionCallExpression {
    /// Creates a call of `function` on `expression`.
    pub fn new(function: ExpressionFunction, expression: ArcIntern<Expression>) -> Self {
        Self {
            function,
            expression,
        }
    }
}
137
/// The payload of [`Expression::Infix`]: the binary operation `left operator right`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash, subclass)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub struct InfixExpression {
    /// The left-hand operand.
    pub left: ArcIntern<Expression>,
    /// The binary operator.
    #[pyo3(get)]
    pub operator: InfixOperator,
    /// The right-hand operand.
    pub right: ArcIntern<Expression>,
}

impl InfixExpression {
    /// Creates the binary operation `left operator right`.
    pub fn new(
        left: ArcIntern<Expression>,
        operator: InfixOperator,
        right: ArcIntern<Expression>,
    ) -> Self {
        Self {
            left,
            operator,
            right,
        }
    }
}
173
/// The payload of [`Expression::Prefix`]: a unary operator applied to one operand.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "quil.expression", eq, frozen, hash, subclass)
)]
#[cfg_attr(not(feature = "python"), strip_pyo3)]
pub struct PrefixExpression {
    /// The unary operator.
    #[pyo3(get)]
    pub operator: PrefixOperator,
    /// The operand.
    pub expression: ArcIntern<Expression>,
}

impl PrefixExpression {
    /// Creates the unary operation `operator expression`.
    pub fn new(operator: PrefixOperator, expression: ArcIntern<Expression>) -> Self {
        Self {
            operator,
            expression,
        }
    }
}
202
203impl PartialEq for Expression {
205 fn eq(&self, other: &Self) -> bool {
207 match (self, other) {
208 (Self::Address(left), Self::Address(right)) => left == right,
209 (Self::Infix(left), Self::Infix(right)) => left == right,
210 (Self::Number(left), Self::Number(right)) => {
211 floating_point_eq::complex64::eq(*left, *right)
212 }
213 (Self::Prefix(left), Self::Prefix(right)) => left == right,
214 (Self::FunctionCall(left), Self::FunctionCall(right)) => left == right,
215 (Self::Variable(left), Self::Variable(right)) => left == right,
216 (Self::PiConstant(), Self::PiConstant()) => true,
217
218 (
221 Self::Address(_)
222 | Self::Infix(_)
223 | Self::Number(_)
224 | Self::Prefix(_)
225 | Self::FunctionCall(_)
226 | Self::Variable(_)
227 | Self::PiConstant(),
228 _,
229 ) => false,
230 }
231 }
232}
233
234impl Eq for Expression {}
236
237impl Hash for Expression {
238 fn hash<H: Hasher>(&self, state: &mut H) {
240 match self {
241 Self::Address(m) => {
242 "Address".hash(state);
243 m.hash(state);
244 }
245 Self::FunctionCall(FunctionCallExpression {
246 function,
247 expression,
248 }) => {
249 "FunctionCall".hash(state);
250 function.hash(state);
251 expression.hash(state);
252 }
253 Self::Infix(InfixExpression {
254 left,
255 operator,
256 right,
257 }) => {
258 "Infix".hash(state);
259 operator.hash(state);
260 left.hash(state);
261 right.hash(state);
262 }
263 Self::Number(n) => {
264 "Number".hash(state);
265 floating_point_eq::complex64::hash(*n, state);
266 }
267 Self::PiConstant() => {
268 "PiConstant()".hash(state);
269 }
270 Self::Prefix(p) => {
271 "Prefix".hash(state);
272 p.operator.hash(state);
273 p.expression.hash(state);
274 }
275 Self::Variable(v) => {
276 "Variable".hash(state);
277 v.hash(state);
278 }
279 }
280 }
281}
282
/// Implements an arithmetic operator trait and its compound-assignment counterpart for
/// [`Expression`]. No arithmetic is performed: every operator just builds an unevaluated
/// [`Expression::Infix`] node.
///
/// Arguments: operator trait, assignment trait, their method names, and the [`InfixOperator`]
/// variant stored in the node. Impls are generated for `Expression op Expression`,
/// `Expression op ArcIntern<Expression>`, `ArcIntern<Expression> op Expression`, and `op=` with
/// either an `Expression` or an `ArcIntern<Expression>` right-hand side.
macro_rules! impl_expr_op {
    ($name:ident, $name_assign:ident, $function:ident, $function_assign:ident, $operator:ident) => {
        impl $name for Expression {
            type Output = Self;
            fn $function(self, other: Self) -> Self {
                Self::Infix(InfixExpression {
                    left: ArcIntern::new(self),
                    operator: InfixOperator::$operator,
                    right: ArcIntern::new(other),
                })
            }
        }

        impl $name<ArcIntern<Expression>> for Expression {
            type Output = Self;
            fn $function(self, other: ArcIntern<Expression>) -> Self {
                Self::Infix(InfixExpression {
                    left: ArcIntern::new(self),
                    operator: InfixOperator::$operator,
                    right: other,
                })
            }
        }

        impl $name<Expression> for ArcIntern<Expression> {
            type Output = Expression;
            fn $function(self, other: Expression) -> Expression {
                Expression::Infix(InfixExpression {
                    left: self,
                    operator: InfixOperator::$operator,
                    right: ArcIntern::new(other),
                })
            }
        }

        impl $name_assign for Expression {
            fn $function_assign(&mut self, other: Self) {
                // Move the current value out (leaving a cheap `PiConstant()` placeholder) so it
                // can become the left operand of the new node, then overwrite the placeholder.
                let temp = ::std::mem::replace(self, Self::PiConstant());
                *self = temp.$function(other);
            }
        }

        impl $name_assign<ArcIntern<Expression>> for Expression {
            fn $function_assign(&mut self, other: ArcIntern<Expression>) {
                // Same placeholder-swap technique as the `Expression` right-hand side above.
                let temp = ::std::mem::replace(self, Self::PiConstant());
                *self = temp.$function(other);
            }
        }
    };
}

// `^` (`BitXor`) is repurposed as exponentiation and maps to `InfixOperator::Caret`.
impl_expr_op!(BitXor, BitXorAssign, bitxor, bitxor_assign, Caret);
impl_expr_op!(Add, AddAssign, add, add_assign, Plus);
impl_expr_op!(Sub, SubAssign, sub, sub_assign, Minus);
impl_expr_op!(Mul, MulAssign, mul, mul_assign, Star);
impl_expr_op!(Div, DivAssign, div, div_assign, Slash);
341
342impl Neg for Expression {
343 type Output = Self;
344
345 fn neg(self) -> Self {
346 Expression::Prefix(PrefixExpression {
347 operator: PrefixOperator::Minus,
348 expression: ArcIntern::new(self),
349 })
350 }
351}
352
353#[inline]
355pub(crate) fn calculate_infix(
356 left: Complex64,
357 operator: InfixOperator,
358 right: Complex64,
359) -> Complex64 {
360 use InfixOperator::*;
361 match operator {
362 Caret => left.powc(right),
363 Plus => left + right,
364 Minus => left - right,
365 Slash => left / right,
366 Star => left * right,
367 }
368}
369
370#[inline]
372pub(crate) fn calculate_function(function: ExpressionFunction, argument: Complex64) -> Complex64 {
373 use ExpressionFunction::*;
374 match function {
375 Sine => argument.sin(),
376 Cis => argument.cos() + imag!(1f64) * argument.sin(),
377 Cosine => argument.cos(),
378 Exponent => argument.exp(),
379 SquareRoot => argument.sqrt(),
380 }
381}
382
/// Returns `true` when `x` lies strictly inside the open interval `(-1e-16, 1e-16)`.
///
/// Used as the tolerance for treating an imaginary part as zero. NaN is never small.
#[inline(always)]
fn is_small(x: f64) -> bool {
    -1e-16 < x && x < 1e-16
}
388
389impl Expression {
390 pub fn simplify(&mut self) {
405 match self {
406 Expression::Address(_) | Expression::Number(_) | Expression::Variable(_) => {}
407 Expression::PiConstant() => {
408 *self = Expression::Number(Complex64::from(PI));
409 }
410 _ => *self = simplification::run(self),
411 }
412 }
413
414 pub fn into_simplified(mut self) -> Self {
428 self.simplify();
429 self
430 }
431
432 pub fn evaluate<K1, K2>(
456 &self,
457 variables: &HashMap<K1, Complex64>,
458 memory_references: &HashMap<K2, Vec<f64>>,
459 ) -> Result<Complex64, EvaluationError>
460 where
461 K1: Borrow<str> + Hash + Eq,
462 K2: Borrow<str> + Hash + Eq,
463 {
464 use Expression::*;
465
466 match self {
467 FunctionCall(FunctionCallExpression {
468 function,
469 expression,
470 }) => {
471 let evaluated = expression.evaluate(variables, memory_references)?;
472 Ok(calculate_function(*function, evaluated))
473 }
474 Infix(InfixExpression {
475 left,
476 operator,
477 right,
478 }) => {
479 let left_evaluated = left.evaluate(variables, memory_references)?;
480 let right_evaluated = right.evaluate(variables, memory_references)?;
481 Ok(calculate_infix(left_evaluated, *operator, right_evaluated))
482 }
483 Prefix(PrefixExpression {
484 operator,
485 expression,
486 }) => {
487 use PrefixOperator::*;
488 let value = expression.evaluate(variables, memory_references)?;
489 if matches!(operator, Minus) {
490 Ok(-value)
491 } else {
492 Ok(value)
493 }
494 }
495 Variable(identifier) => match variables.get(identifier) {
496 Some(&value) => Ok(value),
497 None => Err(EvaluationError::Incomplete),
498 },
499 Address(memory_reference) => memory_references
500 .get(memory_reference.name.as_str())
501 .and_then(|values| {
502 let value = values.get(memory_reference.index as usize)?;
503 Some(real!(*value))
504 })
505 .ok_or(EvaluationError::Incomplete),
506 PiConstant() => Ok(real!(PI)),
507 Number(number) => Ok(*number),
508 }
509 }
510
511 #[must_use]
531 pub fn substitute_variables<K>(&self, variable_values: &HashMap<K, Expression>) -> Self
532 where
533 K: Borrow<str> + Hash + Eq,
534 {
535 use Expression::*;
536
537 match self {
538 FunctionCall(FunctionCallExpression {
539 function,
540 expression,
541 }) => Expression::FunctionCall(FunctionCallExpression {
542 function: *function,
543 expression: expression.substitute_variables(variable_values).into(),
544 }),
545 Infix(InfixExpression {
546 left,
547 operator,
548 right,
549 }) => {
550 let left = left.substitute_variables(variable_values).into();
551 let right = right.substitute_variables(variable_values).into();
552 Infix(InfixExpression {
553 left,
554 operator: *operator,
555 right,
556 })
557 }
558 Prefix(PrefixExpression {
559 operator,
560 expression,
561 }) => Prefix(PrefixExpression {
562 operator: *operator,
563 expression: expression.substitute_variables(variable_values).into(),
564 }),
565 Variable(identifier) => match variable_values.get(identifier) {
566 Some(value) => value.clone(),
567 None => Variable(identifier.clone()),
568 },
569 other => other.clone(),
570 }
571 }
572}
573
574#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
575#[cfg_attr(feature = "python", pyo3::pymethods)]
576impl Expression {
577 pub fn to_real(&self) -> Result<f64, EvaluationError> {
580 match self {
581 Expression::PiConstant() => Ok(PI),
582 Expression::Number(x) if is_small(x.im) => Ok(x.re),
583 Expression::Number(_) => Err(EvaluationError::NumberNotReal),
584 _ => Err(EvaluationError::NotANumber),
585 }
586 }
587}
588
589impl FromStr for Expression {
590 type Err = ParseProgramError<Self>;
591
592 fn from_str(s: &str) -> Result<Self, Self::Err> {
593 let input = LocatedSpan::new(s);
594 let tokens = lex(input)?;
595 disallow_leftover(parse_expression(&tokens).map_err(ParseError::from_nom_internal_err))
596 }
597}
598
599static FORMAT_REAL_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
600 WriteFloatOptions::builder()
601 .negative_exponent_break(NonZeroI32::new(-5))
602 .positive_exponent_break(NonZeroI32::new(15))
603 .trim_floats(true)
604 .build()
605 .expect("options are valid")
606});
607
608static FORMAT_IMAGINARY_OPTIONS: Lazy<WriteFloatOptions> = Lazy::new(|| {
609 WriteFloatOptions::builder()
610 .negative_exponent_break(NonZeroI32::new(-5))
611 .positive_exponent_break(NonZeroI32::new(15))
612 .trim_floats(false) .build()
614 .expect("options are valid")
615});
616
617#[inline(always)]
624pub(crate) fn format_complex(value: &Complex64) -> String {
625 const FORMAT: u128 = format::STANDARD;
626 if value.re == 0f64 && value.im == 0f64 {
627 "0".to_owned()
628 } else if value.im == 0f64 {
629 to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS)
630 } else if value.re == 0f64 {
631 to_string_with_options::<_, FORMAT>(value.im, &FORMAT_IMAGINARY_OPTIONS) + "i"
632 } else {
633 let mut out = to_string_with_options::<_, FORMAT>(value.re, &FORMAT_REAL_OPTIONS);
634 if value.im > 0f64 {
635 out.push('+')
636 }
637 out.push_str(&to_string_with_options::<_, FORMAT>(
638 value.im,
639 &FORMAT_IMAGINARY_OPTIONS,
640 ));
641 out.push('i');
642 out
643 }
644}
645
646impl Quil for Expression {
647 fn write(
648 &self,
649 f: &mut impl std::fmt::Write,
650 fall_back_to_debug: bool,
651 ) -> Result<(), crate::quil::ToQuilError> {
652 use Expression::*;
653 match self {
654 Address(memory_reference) => memory_reference.write(f, fall_back_to_debug),
655 FunctionCall(FunctionCallExpression {
656 function,
657 expression,
658 }) => {
659 write!(f, "{function}(")?;
660 expression.write(f, fall_back_to_debug)?;
661 write!(f, ")")?;
662 Ok(())
663 }
664 Infix(InfixExpression {
665 left,
666 operator,
667 right,
668 }) => {
669 format_inner_expression(f, fall_back_to_debug, left)?;
670 write!(f, "{operator}")?;
671 format_inner_expression(f, fall_back_to_debug, right)
672 }
673 Number(value) => write!(f, "{}", format_complex(value)).map_err(Into::into),
674 PiConstant() => write!(f, "pi").map_err(Into::into),
675 Prefix(PrefixExpression {
676 operator,
677 expression,
678 }) => {
679 write!(f, "{operator}")?;
680 format_inner_expression(f, fall_back_to_debug, expression)
681 }
682 Variable(identifier) => write!(f, "%{identifier}").map_err(Into::into),
683 }
684 }
685}
686
687fn format_inner_expression(
690 f: &mut impl std::fmt::Write,
691 fall_back_to_debug: bool,
692 expression: &Expression,
693) -> crate::quil::ToQuilResult<()> {
694 match expression {
695 Expression::Infix(InfixExpression {
696 left,
697 operator,
698 right,
699 }) => {
700 write!(f, "(")?;
701 format_inner_expression(f, fall_back_to_debug, left)?;
702 write!(f, "{operator}")?;
703 format_inner_expression(f, fall_back_to_debug, right)?;
704 write!(f, ")")?;
705 Ok(())
706 }
707 _ => expression.write(f, fall_back_to_debug),
708 }
709}
710
#[cfg(test)]
mod test {
    use crate::{
        expression::{
            Expression, InfixExpression, InfixOperator, PrefixExpression, PrefixOperator,
        },
        quil::Quil,
        real,
    };

    use internment::ArcIntern;

    /// A nested infix operand is parenthesized; the prefix minus and the outermost infix
    /// expression are not.
    #[test]
    fn formats_nested_expression() {
        let negative_three = Expression::Prefix(PrefixExpression::new(
            PrefixOperator::Minus,
            ArcIntern::new(Expression::Number(real!(3f64))),
        ));
        let half_pi = Expression::Infix(InfixExpression::new(
            ArcIntern::new(Expression::PiConstant()),
            InfixOperator::Slash,
            ArcIntern::new(Expression::Number(real!(2f64))),
        ));
        let product = Expression::Infix(InfixExpression::new(
            ArcIntern::new(negative_three),
            InfixOperator::Star,
            ArcIntern::new(half_pi),
        ));

        assert_eq!(product.to_quil_or_debug(), "-3*(pi/2)");
    }
}
741
/// A built-in function that can be called within an [`Expression`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        module = "quil.expression",
        eq,
        frozen,
        hash,
        str,
        rename_all = "SCREAMING_SNAKE_CASE"
    )
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum ExpressionFunction {
    /// `cis(x)`, evaluated as `cos(x) + i*sin(x)`.
    Cis,
    /// `cos(x)`.
    Cosine,
    /// `exp(x)`.
    Exponent,
    /// `sin(x)`.
    Sine,
    /// `sqrt(x)`.
    SquareRoot,
}
764
765impl fmt::Display for ExpressionFunction {
766 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
767 use ExpressionFunction::*;
768 write!(
769 f,
770 "{}",
771 match self {
772 Cis => "cis",
773 Cosine => "cos",
774 Exponent => "exp",
775 Sine => "sin",
776 SquareRoot => "sqrt",
777 }
778 )
779 }
780}
781
/// A unary operator that can prefix an [`Expression`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        module = "quil.expression",
        eq,
        frozen,
        hash,
        str,
        rename_all = "SCREAMING_SNAKE_CASE"
    )
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum PrefixOperator {
    /// Unary `+`; evaluation leaves the operand unchanged.
    Plus,
    /// Unary `-`; evaluation negates the operand.
    Minus,
}
800
801impl fmt::Display for PrefixOperator {
802 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
803 use PrefixOperator::*;
804 write!(
805 f,
806 "{}",
807 match self {
808 Plus => "",
810 Minus => "-",
811 }
812 )
813 }
814}
815
/// A binary operator joining two [`Expression`]s.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "stubs", gen_stub_pyclass_enum)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(
        module = "quil.expression",
        eq,
        frozen,
        hash,
        str,
        rename_all = "SCREAMING_SNAKE_CASE"
    )
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum InfixOperator {
    /// `^`: complex exponentiation.
    Caret,
    /// `+`: addition.
    Plus,
    /// `-`: subtraction.
    Minus,
    /// `/`: division.
    Slash,
    /// `*`: multiplication.
    Star,
}
837
838impl fmt::Display for InfixOperator {
839 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
840 use InfixOperator::*;
841 write!(
842 f,
843 "{}",
844 match self {
845 Caret => "^",
846 Plus => "+",
847 Minus => " - ",
849 Slash => "/",
850 Star => "*",
851 }
852 )
853 }
854}
855
pub mod interned {
    //! Free functions that build [`Expression`]s already wrapped in [`ArcIntern`].
    //!
    //! Every function here is a thin `#[inline(always)]` wrapper generated by one of the local
    //! macros below. They let composite expressions be written without spelling out
    //! `ArcIntern::new(Expression::...)` at every level.

    use super::*;

    // Generates one constructor per entry `name: Variant(PayloadType)` (or `name: Variant()`
    // for a variant without a payload). Each constructor wraps the new `Expression` in
    // `ArcIntern::new`.
    macro_rules! atoms {
        ($($func:ident: $ctor:ident($($typ:ty)?)),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "A wrapper around [`Expression::",
                    stringify!($ctor),
                    "`] that returns an [`ArcIntern<Expression>`]."
                )]
                #[inline(always)]
                pub fn $func($(value: $typ)?) -> ArcIntern<Expression> {
                    // `identity::<$typ>` only mentions `$typ`, so this optional repetition
                    // expands exactly when a payload type was given; the value is unchanged.
                    ArcIntern::new(Expression::$ctor($(std::convert::identity::<$typ>(value))?))
                }
            )+
        };
    }

    // Generates one constructor per entry `name: atom_fn(Struct { field: Type, ... })`. Each
    // constructor takes the struct's fields as separate arguments, assembles the struct, and
    // forwards it to the given `atoms!`-generated function.
    macro_rules! expression_wrappers {
        ($($func:ident: $atom:ident($ctor:ident { $($field:ident: $field_ty:ty),*$(,)? })),+ $(,)?) => {
            paste::paste! { $(
                #[doc = concat!(
                    "A wrapper around [`Expression::", stringify!([<$func:camel>]), "`] ",
                    "that takes the contents of the inner expression type as arguments directly ",
                    "and returns an [`ArcIntern<Expression>`].",
                    "\n\n",
                    "See also [`", stringify!($atom), "`].",
                )]
                #[inline(always)]
                pub fn $func($($field: $field_ty),*) -> ArcIntern<Expression> {
                    $atom($ctor { $($field),* })
                }
            )+ }
        };
    }

    // Generates one single-argument constructor per entry `name: ExpressionFunctionVariant`,
    // building a function call through `function_call`.
    macro_rules! function_wrappers {
        ($($func:ident: $ctor:ident),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "Create an <code>[ArcIntern]<[Expression]></code> representing ",
                    "`", stringify!($func), "(expression)`.",
                    "\n\n",
                    "A wrapper around [`Expression::FunctionCall`] with ",
                    "[`ExpressionFunction::", stringify!($ctor), "`].",
                )]
                #[inline(always)]
                pub fn $func(expression: ArcIntern<Expression>) -> ArcIntern<Expression> {
                    function_call(ExpressionFunction::$ctor, expression)
                }
            )+
        };
    }

    // Generates one two-argument constructor per entry `name: InfixOperatorVariant (symbol)`,
    // building the operation through `infix`. The symbol is only used in the generated docs.
    macro_rules! infix_wrappers {
        ($($func:ident: $ctor:ident ($op:tt)),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "Create an <code>[ArcIntern]<[Expression]></code> representing ",
                    "`left ", stringify!($op), " right`.",
                    "\n\n",
                    "A wrapper around [`Expression::Infix`] with ",
                    "[`InfixOperator::", stringify!($ctor), "`].",
                )]
                #[inline(always)]
                pub fn $func(
                    left: ArcIntern<Expression>,
                    right: ArcIntern<Expression>,
                ) -> ArcIntern<Expression> {
                    infix(left, InfixOperator::$ctor, right)
                }
            )+
        };
    }

    // Generates one single-argument constructor per entry `name: PrefixOperatorVariant (symbol)`,
    // building the operation through `prefix`. The symbol is only used in the generated docs.
    macro_rules! prefix_wrappers {
        ($($func:ident: $ctor:ident ($op:tt)),+ $(,)?) => {
            $(
                #[doc = concat!(
                    "Create an <code>[ArcIntern]<[Expression]></code> representing ",
                    "`", stringify!($op), "expression`.",
                    "\n\n",
                    "A wrapper around [`Expression::Prefix`] with ",
                    "[`PrefixOperator::", stringify!($ctor), "`].",
                )]
                #[inline(always)]
                pub fn $func(expression: ArcIntern<Expression>) -> ArcIntern<Expression> {
                    prefix(PrefixOperator::$ctor, expression)
                }
            )+
        };
    }

    // Direct wrappers around the `Expression` variants (no wrapper for `Address` is generated).
    atoms! {
        function_call_expr: FunctionCall(FunctionCallExpression),
        infix_expr: Infix(InfixExpression),
        pi: PiConstant(),
        number: Number(Complex64),
        prefix_expr: Prefix(PrefixExpression),
        variable: Variable(String),
    }

    // Field-wise constructors for the compound variants.
    expression_wrappers! {
        function_call: function_call_expr(FunctionCallExpression {
            function: ExpressionFunction,
            expression: ArcIntern<Expression>,
        }),

        infix: infix_expr(InfixExpression {
            left: ArcIntern<Expression>,
            operator: InfixOperator,
            right: ArcIntern<Expression>,
        }),

        prefix: prefix_expr(PrefixExpression {
            operator: PrefixOperator,
            expression: ArcIntern<Expression>,
        }),
    }

    // One constructor per built-in function.
    function_wrappers! {
        cis: Cis,
        cos: Cosine,
        exp: Exponent,
        sin: Sine,
        sqrt: SquareRoot,
    }

    // One constructor per binary operator.
    infix_wrappers! {
        add: Plus (+),
        sub: Minus (-),
        mul: Star (*),
        div: Slash (/),
        pow: Caret (^),
    }

    // One constructor per unary operator.
    prefix_wrappers! {
        unary_plus: Plus (+),
        neg: Minus (-),
    }
}
1002
1003#[cfg(test)]
1004#[allow(clippy::arc_with_non_send_sync)]
1007mod tests {
1008 use super::*;
1009 use crate::reserved::ReservedToken;
1010 use proptest::prelude::*;
1011 use std::collections::hash_map::DefaultHasher;
1012 use std::collections::HashSet;
1013
1014 #[inline]
1016 fn hash_to_u64<T: Hash>(t: &T) -> u64 {
1017 let mut s = DefaultHasher::new();
1018 t.hash(&mut s);
1019 s.finish()
1020 }
1021
    /// Table-driven check of both `evaluate` and `simplify`.
    ///
    /// Each case is first evaluated against its variables and memory, and the result is
    /// compared with `evaluated`. The case is then simplified in place and compared with
    /// `simplified`.
    #[test]
    fn simplify_and_evaluate() {
        use Expression::*;

        let one = real!(1.0);
        let empty_variables = HashMap::new();

        let mut variables = HashMap::new();
        variables.insert("foo".to_owned(), real!(10f64));
        variables.insert("bar".to_owned(), real!(100f64));

        let empty_memory = HashMap::new();

        // Each region is a list of real values addressed as `name[index]`.
        let mut memory_references = HashMap::new();
        memory_references.insert("theta", vec![1.0, 2.0]);
        memory_references.insert("beta", vec![3.0, 4.0]);

        // One expression, the inputs to evaluate it against, and both expected results.
        struct TestCase<'a> {
            expression: Expression,
            variables: &'a HashMap<String, Complex64>,
            memory_references: &'a HashMap<&'a str, Vec<f64>>,
            simplified: Expression,
            evaluated: Result<Complex64, EvaluationError>,
        }

        let cases: Vec<TestCase> = vec![
            TestCase {
                expression: Number(one),
                variables: &empty_variables,
                memory_references: &empty_memory,
                simplified: Number(one),
                evaluated: Ok(one),
            },
            // Negating a literal simplifies to the negative literal.
            TestCase {
                expression: Expression::Prefix(PrefixExpression {
                    operator: PrefixOperator::Minus,
                    expression: ArcIntern::new(Number(real!(1f64))),
                }),
                variables: &empty_variables,
                memory_references: &empty_memory,
                simplified: Number(real!(-1f64)),
                evaluated: Ok(real!(-1f64)),
            },
            // Variables stay symbolic under simplification but evaluate to their values.
            TestCase {
                expression: Expression::Variable("foo".to_owned()),
                variables: &variables,
                memory_references: &empty_memory,
                simplified: Expression::Variable("foo".to_owned()),
                evaluated: Ok(real!(10f64)),
            },
            TestCase {
                expression: Expression::from_str("%foo + %bar").unwrap(),
                variables: &variables,
                memory_references: &empty_memory,
                simplified: Expression::from_str("%foo + %bar").unwrap(),
                evaluated: Ok(real!(110f64)),
            },
            // A function call on a constant folds to a number.
            TestCase {
                expression: Expression::FunctionCall(FunctionCallExpression {
                    function: ExpressionFunction::Sine,
                    expression: ArcIntern::new(Expression::Number(real!(PI / 2f64))),
                }),
                variables: &variables,
                memory_references: &empty_memory,
                simplified: Number(real!(1f64)),
                evaluated: Ok(real!(1f64)),
            },
            // Memory references stay symbolic under simplification but evaluate by lookup.
            TestCase {
                expression: Expression::from_str("theta[1] * beta[0]").unwrap(),
                variables: &empty_variables,
                memory_references: &memory_references,
                simplified: Expression::from_str("theta[1] * beta[0]").unwrap(),
                evaluated: Ok(real!(6.0)),
            },
        ];

        for mut case in cases {
            let evaluated = case
                .expression
                .evaluate(case.variables, case.memory_references);
            assert_eq!(evaluated, case.evaluated);

            case.expression.simplify();
            assert_eq!(case.expression, case.simplified);
        }
    }
1108
1109 fn parenthesized(expression: &Expression) -> String {
1111 use Expression::*;
1112 match expression {
1113 Address(memory_reference) => memory_reference.to_quil_or_debug(),
1114 FunctionCall(FunctionCallExpression {
1115 function,
1116 expression,
1117 }) => format!("({function}({}))", parenthesized(expression)),
1118 Infix(InfixExpression {
1119 left,
1120 operator,
1121 right,
1122 }) => format!(
1123 "({}{}{})",
1124 parenthesized(left),
1125 operator,
1126 parenthesized(right)
1127 ),
1128 Number(value) => format!("({})", format_complex(value)),
1129 PiConstant() => "pi".to_string(),
1130 Prefix(PrefixExpression {
1131 operator,
1132 expression,
1133 }) => format!("({}{})", operator, parenthesized(expression)),
1134 Variable(identifier) => format!("(%{identifier})"),
1135 }
1136 }
1137
1138 fn arb_name() -> impl Strategy<Value = String> {
1140 r"[a-z][a-zA-Z0-9]{1,10}".prop_filter("Exclude reserved tokens", |t| {
1141 ReservedToken::from_str(t).is_err() && !t.to_lowercase().starts_with("nan")
1142 })
1143 }
1144
1145 fn arb_memory_reference() -> impl Strategy<Value = MemoryReference> {
1147 (arb_name(), (u64::MIN..u32::MAX as u64))
1148 .prop_map(|(name, index)| MemoryReference { name, index })
1149 }
1150
1151 fn arb_complex64() -> impl Strategy<Value = Complex64> {
1153 let tau = std::f64::consts::TAU;
1154 ((-tau..tau), (-tau..tau)).prop_map(|(re, im)| Complex64 { re, im })
1155 }
1156
1157 fn nonzero(strat: impl Strategy<Value = Expression>) -> impl Strategy<Value = Expression> {
1159 strat.prop_filter("Exclude constantly-zero expressions", |expr| {
1160 expr.clone().into_simplified() != Expression::Number(Complex64::new(0.0, 0.0))
1161 })
1162 }
1163
    /// Strategy for arbitrary expression trees.
    ///
    /// Leaves are addresses, numbers, `pi` and variables. Interior nodes are function calls,
    /// infix operations (with nonzero divisors) and prefix minus.
    fn arb_expr() -> impl Strategy<Value = Expression> {
        use Expression::*;
        let leaf = prop_oneof![
            arb_memory_reference().prop_map(Address),
            arb_complex64().prop_map(Number),
            Just(PiConstant()),
            arb_name().prop_map(Variable),
        ];
        // `prop_recursive(depth, desired_size, expected_branch_size, recurse)`: at most 4 levels
        // of nesting, a target of about 64 nodes, and an expected branch size of 16.
        leaf.prop_recursive(
            4, 64, 16, |expr| {
                // A second handle on the recursive strategy, moved into the `prop_flat_map`
                // closure below to generate right-hand operands.
                let inner = expr.clone();
                prop_oneof![
                    (any::<ExpressionFunction>(), expr.clone()).prop_map(|(function, e)| {
                        Expression::FunctionCall(FunctionCallExpression {
                            function,
                            expression: ArcIntern::new(e),
                        })
                    }),
                    // The operator is chosen first so the right operand's strategy can depend on it.
                    (expr.clone(), any::<InfixOperator>())
                        .prop_flat_map(move |(left, operator)| (
                            Just(left),
                            Just(operator),
                            if let InfixOperator::Slash = operator {
                                // Never generate a divisor that simplifies to zero.
                                nonzero(inner.clone()).boxed()
                            } else {
                                inner.clone().boxed()
                            }
                        ))
                        .prop_map(|(l, operator, r)| Infix(InfixExpression {
                            left: ArcIntern::new(l),
                            operator,
                            right: ArcIntern::new(r)
                        })),
                    // Only `PrefixOperator::Minus` is generated.
                    expr.prop_map(|e| Prefix(PrefixExpression {
                        operator: PrefixOperator::Minus,
                        expression: ArcIntern::new(e)
                    }))
                ]
            },
        )
    }
1211
    // Property-based tests. Inside `proptest!`, each `#[test] fn` draws its arguments from the
    // strategy after `in` and is run against many generated cases.
    proptest! {

        // `PartialEq` is reflexive for an unevaluated sum, and that sum is never equal to the
        // number it would evaluate to.
        #[test]
        fn eq(a in any::<f64>(), b in any::<f64>()) {
            let first = Expression::Infix (InfixExpression {
                left: ArcIntern::new(Expression::Number(real!(a))),
                operator: InfixOperator::Plus,
                right: ArcIntern::new(Expression::Number(real!(b))),
            } );
            let differing = Expression::Number(real!(a + b));
            prop_assert_eq!(&first, &first);
            prop_assert_ne!(&first, &differing);
        }

        // A clone is found in a `HashSet` containing the original, while the numeric sum is not.
        #[test]
        fn hash(a in any::<f64>(), b in any::<f64>()) {
            let first = Expression::Infix (InfixExpression {
                left: ArcIntern::new(Expression::Number(real!(a))),
                operator: InfixOperator::Plus,
                right: ArcIntern::new(Expression::Number(real!(b))),
            });
            let matching = first.clone();
            let differing = Expression::Number(real!(a + b));
            let mut set = HashSet::new();
            set.insert(first);
            assert!(set.contains(&matching));
            assert!(!set.contains(&differing))
        }

        // `Hash` agrees with `PartialEq`. The "equal hashes => equal values" direction could
        // in principle fail on a 64-bit hash collision.
        #[test]
        fn eq_iff_hash_eq(x in arb_expr(), y in arb_expr()) {
            prop_assert_eq!(x == y, hash_to_u64(&x) == hash_to_u64(&y));
        }

        // A number with a zero imaginary part converts to its real part.
        #[test]
        fn reals_are_real(x in any::<f64>()) {
            prop_assert_eq!(Expression::Number(real!(x)).to_real(), Ok(x))
        }

        // A number converts to a real exactly when its imaginary part passes `is_small`.
        #[test]
        fn some_nums_are_real(re in any::<f64>(), im in any::<f64>()) {
            let result = Expression::Number(Complex64{re, im}).to_real();
            if is_small(im) {
                prop_assert_eq!(result, Ok(re))
            } else {
                prop_assert_eq!(result, Err(EvaluationError::NumberNotReal))
            }
        }

        // `to_real` does not evaluate: every expression other than a number or `pi` is rejected.
        #[test]
        fn no_other_exps_are_real(expr in arb_expr().prop_filter("Not numbers", |e| !matches!(e, Expression::Number(_) | Expression::PiConstant()))) {
            prop_assert_eq!(expr.to_real(), Err(EvaluationError::NotANumber))
        }

        // `format_complex` output parses back (after simplification) to the same number.
        #[test]
        fn complexes_are_parseable_as_expressions(value in arb_complex64()) {
            let parsed = Expression::from_str(&format_complex(&value));
            assert!(parsed.is_ok());
            let simple = parsed.unwrap().into_simplified();
            assert_eq!(Expression::Number(value), simple);
        }

        // The operator overloads below build an unevaluated `Infix` node with the matching
        // operator, both for `a op b` and for `a op= b`. `^` maps to `InfixOperator::Caret`.
        #[test]
        fn exponentiation_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Caret, right: ArcIntern::new(right.clone()) } );
            prop_assert_eq!(left ^ right, expected);
        }

        #[test]
        fn in_place_exponentiation_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Caret, right: ArcIntern::new(right.clone()) } );
            let mut x = left;
            x ^= right;
            prop_assert_eq!(x, expected);
        }

        #[test]
        fn addition_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Plus, right: ArcIntern::new(right.clone()) } );
            prop_assert_eq!(left + right, expected);
        }

        #[test]
        fn in_place_addition_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Plus, right: ArcIntern::new(right.clone()) } );
            let mut x = left;
            x += right;
            prop_assert_eq!(x, expected);
        }

        #[test]
        fn subtraction_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Minus, right: ArcIntern::new(right.clone()) } );
            prop_assert_eq!(left - right, expected);
        }

        #[test]
        fn in_place_subtraction_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Minus, right: ArcIntern::new(right.clone()) } );
            let mut x = left;
            x -= right;
            prop_assert_eq!(x, expected);
        }

        #[test]
        fn multiplication_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Star, right: ArcIntern::new(right.clone()) } );
            prop_assert_eq!(left * right, expected);
        }

        #[test]
        fn in_place_multiplication_works_as_expected(left in arb_expr(), right in arb_expr()) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Star, right: ArcIntern::new(right.clone()) } );
            let mut x = left;
            x *= right;
            prop_assert_eq!(x, expected);
        }


        // Divisors are drawn from `nonzero` strategies, as in `arb_expr`.
        #[test]
        fn division_works_as_expected(left in arb_expr(), right in nonzero(arb_expr())) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Slash, right: ArcIntern::new(right.clone()) } );
            prop_assert_eq!(left / right, expected);
        }

        // Divisors are drawn from `nonzero` strategies, as in `arb_expr`.
        #[test]
        fn in_place_division_works_as_expected(left in arb_expr(), right in nonzero(arb_expr())) {
            let expected = Expression::Infix (InfixExpression { left: ArcIntern::new(left.clone()), operator: InfixOperator::Slash, right: ArcIntern::new(right.clone()) } );
            let mut x = left;
            x /= right;
            prop_assert_eq!(x, expected);
        }

        // Printing with full parenthesization and re-parsing must give an expression that
        // simplifies to the same result as the original. The lint allowance covers the clones
        // kept for the failure message.
        #[allow(clippy::redundant_clone)]
        #[test]
        fn round_trip(e in arb_expr()) {
            let simple_e = e.clone().into_simplified();
            let s = parenthesized(&e);
            let p = Expression::from_str(&s);
            prop_assert!(p.is_ok());
            let p = p.unwrap();
            let simple_p = p.clone().into_simplified();

            prop_assert_eq!(
                &simple_p,
                &simple_e,
                "Simplified expressions should be equal:\nparenthesized {p} ({p:?}) extracted from {s} simplified to {simple_p}\nvs original {e} ({e:?}) simplified to {simple_e}",
                p=p.to_quil_or_debug(),
                s=s,
                e=e.to_quil_or_debug(),
                simple_p=simple_p.to_quil_or_debug(),
                simple_e=simple_e.to_quil_or_debug()
            );
        }

    }
1371
// Each input is already in the exact textual form `to_quil_or_debug` emits, so
// parsing and re-printing it must reproduce the input string byte-for-byte.
#[test]
fn specific_round_trip_tests() {
    for input in &[
        "-1*(phases[0]+phases[1])",
        "(-1*(phases[0]+phases[1]))+(-1*(phases[0]+phases[1]))",
    ] {
        // Name the failing input; a bare `unwrap()` would only show the error.
        let parsed = Expression::from_str(input)
            .unwrap_or_else(|err| panic!("failed to parse {:?}: {:?}", input, err));
        let restring = parsed.to_quil_or_debug();
        assert_eq!(input, &restring);
    }
}
1386
// A NaN-valued `Expression::Number` must compare equal to its own clone,
// unlike a raw `f64` NaN, which never equals itself.
#[test]
fn test_nan_is_equal() {
    let original = Expression::Number(Complex64::new(f64::NAN, 0.0));
    let copy = original.clone();
    assert_eq!(original, copy);
}
1393
// Parsing each input and then simplifying it must yield the paired expression.
// - The division-by-zero case expects a NaN number. That comparison only passes
//   because `Expression` equality treats NaN as equal to NaN.
// - 6.283185307179586 is 2*pi, so the last input simplifies down to `a[0]`.
#[test]
fn specific_simplification_tests() {
    let number = |value: f64| Expression::Number(value.into());
    let cases = [
        ("pi", number(PI)),
        ("pi/2", number(PI / 2.0)),
        ("pi * pi", number(PI.powi(2))),
        ("1.0/(1.0-1.0)", number(f64::NAN)),
        (
            "(a[0]*2*pi)/6.283185307179586",
            Expression::Address(MemoryReference {
                name: "a".to_string(),
                index: 0,
            }),
        ),
    ];

    for (input, expected) in cases {
        let simplified = Expression::from_str(input).unwrap().into_simplified();
        assert_eq!(simplified, expected);
    }
}
1415
// Expected `to_real` results:
// - the pi constant gives `PI`;
// - `1 + 0i` gives `1.0`;
// - `1 + 1i` fails with `NumberNotReal`;
// - a variable fails with `NotANumber`.
#[test]
fn specific_to_real_tests() {
    let cases = [
        (Expression::PiConstant(), Ok(PI)),
        (Expression::Number(Complex64::new(1.0, 0.0)), Ok(1.0)),
        (
            Expression::Number(Complex64::new(1.0, 1.0)),
            Err(EvaluationError::NumberNotReal),
        ),
        (
            Expression::Variable(String::from("Not a number")),
            Err(EvaluationError::NotANumber),
        ),
    ];

    for (expression, expected) in cases {
        let actual = expression.to_real();
        assert_eq!(actual, expected);
    }
}
1433
// Pins `format_complex` output. The cases cover:
// - every signed zero collapsing to "0";
// - purely real and purely imaginary values;
// - each sign combination of a full complex number;
// - exponent notation for very large and very small parts.
#[test]
fn specific_format_complex_tests() {
    let cases = [
        (Complex64::new(0.0, 0.0), "0"),
        (Complex64::new(-0.0, 0.0), "0"),
        (Complex64::new(-0.0, -0.0), "0"),
        (Complex64::new(0.0, 1.0), "1.0i"),
        (Complex64::new(1.0, -1.0), "1-1.0i"),
        (Complex64::new(1.234, 0.0), "1.234"),
        (Complex64::new(0.0, 1.234), "1.234i"),
        (Complex64::new(-1.234, 0.0), "-1.234"),
        (Complex64::new(0.0, -1.234), "-1.234i"),
        (Complex64::new(1.234, 5.678), "1.234+5.678i"),
        (Complex64::new(-1.234, 5.678), "-1.234+5.678i"),
        (Complex64::new(1.234, -5.678), "1.234-5.678i"),
        (Complex64::new(-1.234, -5.678), "-1.234-5.678i"),
        (Complex64::new(1e100, 2e-100), "1e100+2.0e-100i"),
    ];

    for (value, expected) in cases {
        assert_eq!(format_complex(&value), expected);
    }
}
1455}