Skip to main content

quil/
expression.rs

1use std::collections::HashMap;
2
3use quil_rs::expression::{
4    Expression, ExpressionFunction, FunctionCallExpression, InfixExpression, InfixOperator,
5    PrefixExpression, PrefixOperator,
6};
7
8use rigetti_pyo3::{
9    create_init_submodule, impl_from_str, impl_hash, impl_parse, impl_repr, impl_str,
10    num_complex::Complex64,
11    py_wrap_data_struct, py_wrap_error, py_wrap_simple_enum, py_wrap_union_enum,
12    pyo3::{
13        exceptions::PyValueError,
14        pymethods,
15        types::{PyComplex, PyString},
16        Py, PyResult, Python,
17    },
18    wrap_error, PyTryFrom, PyWrapper, PyWrapperMut, ToPython, ToPythonError,
19};
20
21use internment::ArcIntern;
22
23use crate::{impl_eq, impl_to_quil, instruction::PyMemoryReference};
24
25wrap_error!(RustEvaluationError(quil_rs::expression::EvaluationError));
26py_wrap_error!(quil, RustEvaluationError, EvaluationError, PyValueError);
27wrap_error!(RustParseExpressionError(quil_rs::program::ParseProgramError<Expression>));
28py_wrap_error!(
29    quil,
30    RustParseExpressionError,
31    ParseExpressionError,
32    PyValueError
33);
34
35py_wrap_union_enum! {
36    #[derive(Debug, Hash, PartialEq, Eq)]
37    #[pyo3(module="quil.expression")]
38    PyExpression(Expression) as "Expression" {
39        address: Address => PyMemoryReference,
40        function_call: FunctionCall => PyFunctionCallExpression,
41        infix: Infix => PyInfixExpression,
42        number: Number => Py<PyComplex>,
43        pi: PiConstant,
44        prefix: Prefix => PyPrefixExpression,
45        variable: Variable => Py<PyString>
46    }
47}
48impl_repr!(PyExpression);
49impl_to_quil!(PyExpression);
50impl_from_str!(PyExpression, RustParseExpressionError);
51impl_hash!(PyExpression);
52impl_parse!(PyExpression);
53impl_eq!(PyExpression);
54
55#[pymethods]
56impl PyExpression {
57    pub fn simplify(&mut self) {
58        self.as_inner_mut().simplify()
59    }
60
61    pub fn into_simplified(&self, py: Python<'_>) -> PyResult<Self> {
62        self.as_inner().clone().into_simplified().to_python(py)
63    }
64
65    pub fn evaluate(
66        &self,
67        variables: HashMap<String, Complex64>,
68        memory_references: HashMap<&str, Vec<f64>>,
69    ) -> PyResult<Complex64> {
70        self.as_inner()
71            .evaluate(&variables, &memory_references)
72            .map_err(RustEvaluationError::from)
73            .map_err(RustEvaluationError::to_py_err)
74    }
75
76    pub fn substitute_variables(
77        &self,
78        py: Python<'_>,
79        variable_values: HashMap<String, PyExpression>,
80    ) -> PyResult<Self> {
81        Ok(PyExpression(self.as_inner().clone().substitute_variables(
82            &HashMap::<String, Expression>::py_try_from(py, &variable_values)?,
83        )))
84    }
85
86    pub fn to_real(&self) -> PyResult<f64> {
87        self.as_inner()
88            .to_real()
89            .map_err(RustEvaluationError::from)
90            .map_err(RustEvaluationError::to_py_err)
91    }
92
93    pub fn __add__(&self, other: PyExpression) -> Self {
94        PyExpression(self.as_inner().clone() + other.as_inner().clone())
95    }
96
97    pub fn __sub__(&self, other: PyExpression) -> Self {
98        PyExpression(self.as_inner().clone() - other.as_inner().clone())
99    }
100
101    pub fn __mul__(&self, other: PyExpression) -> Self {
102        PyExpression(self.as_inner().clone() * other.as_inner().clone())
103    }
104
105    pub fn __truediv__(&self, other: PyExpression) -> Self {
106        PyExpression(self.as_inner().clone() / other.as_inner().clone())
107    }
108}
109
110py_wrap_data_struct! {
111    #[pyo3(subclass)]
112    #[derive(Debug)]
113    PyFunctionCallExpression(FunctionCallExpression) as "FunctionCallExpression" {
114        function: ExpressionFunction => PyExpressionFunction,
115        expression: ArcIntern<Expression> => PyExpression
116    }
117}
118impl_repr!(PyFunctionCallExpression);
119
120#[pymethods]
121impl PyFunctionCallExpression {
122    #[new]
123    pub fn new(
124        py: Python<'_>,
125        function: PyExpressionFunction,
126        expression: PyExpression,
127    ) -> PyResult<Self> {
128        Ok(PyFunctionCallExpression(FunctionCallExpression::new(
129            ExpressionFunction::py_try_from(py, &function)?,
130            ArcIntern::<Expression>::py_try_from(py, &expression)?,
131        )))
132    }
133}
134
135py_wrap_data_struct! {
136    #[derive(Debug)]
137    #[pyo3(subclass)]
138    PyInfixExpression(InfixExpression) as "InfixExpression" {
139        left: ArcIntern<Expression> => PyExpression,
140        operator: InfixOperator => PyInfixOperator,
141        right: ArcIntern<Expression> => PyExpression
142    }
143}
144impl_repr!(PyInfixExpression);
145
146#[pymethods]
147impl PyInfixExpression {
148    #[new]
149    pub fn new(
150        py: Python<'_>,
151        left: PyExpression,
152        operator: PyInfixOperator,
153        right: PyExpression,
154    ) -> PyResult<Self> {
155        Ok(PyInfixExpression(InfixExpression::new(
156            ArcIntern::<Expression>::py_try_from(py, &left)?,
157            InfixOperator::py_try_from(py, &operator)?,
158            ArcIntern::<Expression>::py_try_from(py, &right)?,
159        )))
160    }
161}
162
163py_wrap_data_struct! {
164    #[derive(Debug)]
165    #[pyo3(subclass)]
166    PyPrefixExpression(PrefixExpression) as "PrefixExpression" {
167        operator: PrefixOperator => PyPrefixOperator,
168        expression: ArcIntern<Expression> => PyExpression
169    }
170}
171
172#[pymethods]
173impl PyPrefixExpression {
174    #[new]
175    pub fn new(
176        py: Python<'_>,
177        operator: PyPrefixOperator,
178        expression: PyExpression,
179    ) -> PyResult<Self> {
180        Ok(PyPrefixExpression(PrefixExpression::new(
181            PrefixOperator::py_try_from(py, &operator)?,
182            ArcIntern::<Expression>::py_try_from(py, &expression)?,
183        )))
184    }
185}
186
187py_wrap_simple_enum! {
188    #[derive(Debug, PartialEq, Eq, Hash)]
189    PyExpressionFunction(ExpressionFunction) as "ExpressionFunction" {
190        Cis,
191        Cosine,
192        Exponent,
193        Sine,
194        SquareRoot
195    }
196}
197impl_repr!(PyExpressionFunction);
198impl_str!(PyExpressionFunction);
199impl_hash!(PyExpressionFunction);
200impl_eq!(PyExpressionFunction);
201
202py_wrap_simple_enum! {
203    #[derive(Debug, PartialEq, Eq, Hash)]
204    PyPrefixOperator(PrefixOperator) as "PrefixOperator" {
205        Plus,
206        Minus
207    }
208}
209impl_repr!(PyPrefixOperator);
210impl_str!(PyPrefixOperator);
211impl_hash!(PyPrefixOperator);
212impl_eq!(PyPrefixOperator);
213
214py_wrap_simple_enum! {
215    #[derive(Debug, PartialEq, Eq, Hash)]
216    PyInfixOperator(InfixOperator) as "InfixOperator" {
217        Caret,
218        Plus,
219        Minus,
220        Slash,
221        Star
222    }
223}
224impl_repr!(PyInfixOperator);
225impl_str!(PyInfixOperator);
226impl_hash!(PyInfixOperator);
227impl_eq!(PyInfixOperator);
228
229create_init_submodule! {
230    classes: [PyExpression, PyFunctionCallExpression, PyInfixExpression, PyPrefixExpression, PyExpressionFunction, PyPrefixOperator, PyInfixOperator],
231    errors: [EvaluationError, ParseExpressionError],
232}