1use std::collections::HashMap;
2
3use quil_rs::expression::{
4 Expression, ExpressionFunction, FunctionCallExpression, InfixExpression, InfixOperator,
5 PrefixExpression, PrefixOperator,
6};
7
8use rigetti_pyo3::{
9 create_init_submodule, impl_from_str, impl_hash, impl_parse, impl_repr, impl_str,
10 num_complex::Complex64,
11 py_wrap_data_struct, py_wrap_error, py_wrap_simple_enum, py_wrap_union_enum,
12 pyo3::{
13 exceptions::PyValueError,
14 pymethods,
15 types::{PyComplex, PyString},
16 Py, PyResult, Python,
17 },
18 wrap_error, PyTryFrom, PyWrapper, PyWrapperMut, ToPython, ToPythonError,
19};
20
21use internment::ArcIntern;
22
23use crate::{impl_eq, impl_to_quil, instruction::PyMemoryReference};
24
25wrap_error!(RustEvaluationError(quil_rs::expression::EvaluationError));
26py_wrap_error!(quil, RustEvaluationError, EvaluationError, PyValueError);
27wrap_error!(RustParseExpressionError(quil_rs::program::ParseProgramError<Expression>));
28py_wrap_error!(
29 quil,
30 RustParseExpressionError,
31 ParseExpressionError,
32 PyValueError
33);
34
35py_wrap_union_enum! {
36 #[derive(Debug, Hash, PartialEq, Eq)]
37 #[pyo3(module="quil.expression")]
38 PyExpression(Expression) as "Expression" {
39 address: Address => PyMemoryReference,
40 function_call: FunctionCall => PyFunctionCallExpression,
41 infix: Infix => PyInfixExpression,
42 number: Number => Py<PyComplex>,
43 pi: PiConstant,
44 prefix: Prefix => PyPrefixExpression,
45 variable: Variable => Py<PyString>
46 }
47}
48impl_repr!(PyExpression);
49impl_to_quil!(PyExpression);
50impl_from_str!(PyExpression, RustParseExpressionError);
51impl_hash!(PyExpression);
52impl_parse!(PyExpression);
53impl_eq!(PyExpression);
54
55#[pymethods]
56impl PyExpression {
57 pub fn simplify(&mut self) {
58 self.as_inner_mut().simplify()
59 }
60
61 pub fn into_simplified(&self, py: Python<'_>) -> PyResult<Self> {
62 self.as_inner().clone().into_simplified().to_python(py)
63 }
64
65 pub fn evaluate(
66 &self,
67 variables: HashMap<String, Complex64>,
68 memory_references: HashMap<&str, Vec<f64>>,
69 ) -> PyResult<Complex64> {
70 self.as_inner()
71 .evaluate(&variables, &memory_references)
72 .map_err(RustEvaluationError::from)
73 .map_err(RustEvaluationError::to_py_err)
74 }
75
76 pub fn substitute_variables(
77 &self,
78 py: Python<'_>,
79 variable_values: HashMap<String, PyExpression>,
80 ) -> PyResult<Self> {
81 Ok(PyExpression(self.as_inner().clone().substitute_variables(
82 &HashMap::<String, Expression>::py_try_from(py, &variable_values)?,
83 )))
84 }
85
86 pub fn to_real(&self) -> PyResult<f64> {
87 self.as_inner()
88 .to_real()
89 .map_err(RustEvaluationError::from)
90 .map_err(RustEvaluationError::to_py_err)
91 }
92
93 pub fn __add__(&self, other: PyExpression) -> Self {
94 PyExpression(self.as_inner().clone() + other.as_inner().clone())
95 }
96
97 pub fn __sub__(&self, other: PyExpression) -> Self {
98 PyExpression(self.as_inner().clone() - other.as_inner().clone())
99 }
100
101 pub fn __mul__(&self, other: PyExpression) -> Self {
102 PyExpression(self.as_inner().clone() * other.as_inner().clone())
103 }
104
105 pub fn __truediv__(&self, other: PyExpression) -> Self {
106 PyExpression(self.as_inner().clone() / other.as_inner().clone())
107 }
108}
109
110py_wrap_data_struct! {
111 #[pyo3(subclass)]
112 #[derive(Debug)]
113 PyFunctionCallExpression(FunctionCallExpression) as "FunctionCallExpression" {
114 function: ExpressionFunction => PyExpressionFunction,
115 expression: ArcIntern<Expression> => PyExpression
116 }
117}
118impl_repr!(PyFunctionCallExpression);
119
120#[pymethods]
121impl PyFunctionCallExpression {
122 #[new]
123 pub fn new(
124 py: Python<'_>,
125 function: PyExpressionFunction,
126 expression: PyExpression,
127 ) -> PyResult<Self> {
128 Ok(PyFunctionCallExpression(FunctionCallExpression::new(
129 ExpressionFunction::py_try_from(py, &function)?,
130 ArcIntern::<Expression>::py_try_from(py, &expression)?,
131 )))
132 }
133}
134
135py_wrap_data_struct! {
136 #[derive(Debug)]
137 #[pyo3(subclass)]
138 PyInfixExpression(InfixExpression) as "InfixExpression" {
139 left: ArcIntern<Expression> => PyExpression,
140 operator: InfixOperator => PyInfixOperator,
141 right: ArcIntern<Expression> => PyExpression
142 }
143}
144impl_repr!(PyInfixExpression);
145
146#[pymethods]
147impl PyInfixExpression {
148 #[new]
149 pub fn new(
150 py: Python<'_>,
151 left: PyExpression,
152 operator: PyInfixOperator,
153 right: PyExpression,
154 ) -> PyResult<Self> {
155 Ok(PyInfixExpression(InfixExpression::new(
156 ArcIntern::<Expression>::py_try_from(py, &left)?,
157 InfixOperator::py_try_from(py, &operator)?,
158 ArcIntern::<Expression>::py_try_from(py, &right)?,
159 )))
160 }
161}
162
163py_wrap_data_struct! {
164 #[derive(Debug)]
165 #[pyo3(subclass)]
166 PyPrefixExpression(PrefixExpression) as "PrefixExpression" {
167 operator: PrefixOperator => PyPrefixOperator,
168 expression: ArcIntern<Expression> => PyExpression
169 }
170}
171
172#[pymethods]
173impl PyPrefixExpression {
174 #[new]
175 pub fn new(
176 py: Python<'_>,
177 operator: PyPrefixOperator,
178 expression: PyExpression,
179 ) -> PyResult<Self> {
180 Ok(PyPrefixExpression(PrefixExpression::new(
181 PrefixOperator::py_try_from(py, &operator)?,
182 ArcIntern::<Expression>::py_try_from(py, &expression)?,
183 )))
184 }
185}
186
187py_wrap_simple_enum! {
188 #[derive(Debug, PartialEq, Eq, Hash)]
189 PyExpressionFunction(ExpressionFunction) as "ExpressionFunction" {
190 Cis,
191 Cosine,
192 Exponent,
193 Sine,
194 SquareRoot
195 }
196}
197impl_repr!(PyExpressionFunction);
198impl_str!(PyExpressionFunction);
199impl_hash!(PyExpressionFunction);
200impl_eq!(PyExpressionFunction);
201
202py_wrap_simple_enum! {
203 #[derive(Debug, PartialEq, Eq, Hash)]
204 PyPrefixOperator(PrefixOperator) as "PrefixOperator" {
205 Plus,
206 Minus
207 }
208}
209impl_repr!(PyPrefixOperator);
210impl_str!(PyPrefixOperator);
211impl_hash!(PyPrefixOperator);
212impl_eq!(PyPrefixOperator);
213
214py_wrap_simple_enum! {
215 #[derive(Debug, PartialEq, Eq, Hash)]
216 PyInfixOperator(InfixOperator) as "InfixOperator" {
217 Caret,
218 Plus,
219 Minus,
220 Slash,
221 Star
222 }
223}
224impl_repr!(PyInfixOperator);
225impl_str!(PyInfixOperator);
226impl_hash!(PyInfixOperator);
227impl_eq!(PyInfixOperator);
228
229create_init_submodule! {
230 classes: [PyExpression, PyFunctionCallExpression, PyInfixExpression, PyPrefixExpression, PyExpressionFunction, PyPrefixOperator, PyInfixOperator],
231 errors: [EvaluationError, ParseExpressionError],
232}