1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
4 f64, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9 exceptions::PyTypeError,
10 prelude::*,
11 types::{PyBytes, PyList},
12};
13use std::collections::HashMap;
14
15fn install_with_threads<R: Send>(
16 threads: Option<usize>,
17 op: impl FnOnce() -> R + Send,
18) -> LadduResult<R> {
19 ThreadPoolManager::shared().install(threads, op)
20}
21
#[pyclass(name = "Expression", module = "laddu", from_py_object)]
/// Python-facing wrapper (`laddu.Expression`) around a `laddu_core` [`Expression`].
///
/// Instances support `+`, `-`, `*`, `/` and unary `-` with other expressions, and can be
/// bound to a dataset with `load` to obtain an `Evaluator`.
// NOTE(review): `from_py_object` together with `derive(Clone)` is what lets the
// `extract::<PyExpression>()` calls in this module return owned copies — confirm against the
// PyO3 version in use.
#[derive(Clone)]
pub struct PyExpression(pub Expression);
27
28#[pyfunction(name = "expr_sum")]
31pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
32 if terms.is_empty() {
33 return Ok(PyExpression(Expression::zero()));
34 }
35 if terms.len() == 1 {
36 let term = &terms[0];
37 if let Ok(expression) = term.extract::<PyExpression>() {
38 return Ok(expression);
39 }
40 return Err(PyTypeError::new_err("Item is not a PyExpression"));
41 }
42 let mut iter = terms.iter();
43 let Some(first_term) = iter.next() else {
44 return Ok(PyExpression(Expression::zero()));
45 };
46 let PyExpression(mut summation) = first_term
47 .extract::<PyExpression>()
48 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
49 for term in iter {
50 let PyExpression(expr) = term
51 .extract::<PyExpression>()
52 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
53 summation = summation + expr;
54 }
55 Ok(PyExpression(summation))
56}
57
58#[pyfunction(name = "expr_product")]
61pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
62 if terms.is_empty() {
63 return Ok(PyExpression(Expression::one()));
64 }
65 if terms.len() == 1 {
66 let term = &terms[0];
67 if let Ok(expression) = term.extract::<PyExpression>() {
68 return Ok(expression);
69 }
70 return Err(PyTypeError::new_err("Item is not a PyExpression"));
71 }
72 let mut iter = terms.iter();
73 let Some(first_term) = iter.next() else {
74 return Ok(PyExpression(Expression::one()));
75 };
76 let PyExpression(mut product) = first_term
77 .extract::<PyExpression>()
78 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
79 for term in iter {
80 let PyExpression(expr) = term
81 .extract::<PyExpression>()
82 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
83 product = product * expr;
84 }
85 Ok(PyExpression(product))
86}
87
88#[pyfunction(name = "Zero")]
91pub fn py_expr_zero() -> PyExpression {
92 PyExpression(Expression::zero())
93}
94
95#[pyfunction(name = "One")]
98pub fn py_expr_one() -> PyExpression {
99 PyExpression(Expression::one())
100}
101
102#[pymethods]
103impl PyExpression {
104 #[getter]
111 fn parameters(&self) -> Vec<String> {
112 self.0.parameters()
113 }
114 #[getter]
116 fn free_parameters(&self) -> Vec<String> {
117 self.0.free_parameters()
118 }
119 #[getter]
121 fn fixed_parameters(&self) -> Vec<String> {
122 self.0.fixed_parameters()
123 }
124 #[getter]
126 fn n_free(&self) -> usize {
127 self.0.n_free()
128 }
129 #[getter]
131 fn n_fixed(&self) -> usize {
132 self.0.n_fixed()
133 }
134 #[getter]
136 fn n_parameters(&self) -> usize {
137 self.0.n_parameters()
138 }
139 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
152 Ok(PyEvaluator(self.0.load(&dataset.0)?))
153 }
154 fn real(&self) -> PyExpression {
156 PyExpression(self.0.real())
157 }
158 fn imag(&self) -> PyExpression {
160 PyExpression(self.0.imag())
161 }
162 fn conj(&self) -> PyExpression {
164 PyExpression(self.0.conj())
165 }
166 fn norm_sqr(&self) -> PyExpression {
168 PyExpression(self.0.norm_sqr())
169 }
170 fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
172 Ok(PyExpression(self.0.fix(name, value)?))
173 }
174 fn free(&self, name: &str) -> PyResult<PyExpression> {
176 Ok(PyExpression(self.0.free(name)?))
177 }
178 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
180 Ok(PyExpression(self.0.rename_parameter(old, new)?))
181 }
182 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
184 Ok(PyExpression(self.0.rename_parameters(&mapping)?))
185 }
186 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
187 if let Ok(other_expr) = other.extract::<PyExpression>() {
188 Ok(PyExpression(self.0.clone() + other_expr.0))
189 } else if let Ok(other_int) = other.extract::<usize>() {
190 if other_int == 0 {
191 Ok(PyExpression(self.0.clone()))
192 } else {
193 Err(PyTypeError::new_err(
194 "Addition with an integer for this type is only defined for 0",
195 ))
196 }
197 } else {
198 Err(PyTypeError::new_err("Unsupported operand type for +"))
199 }
200 }
201 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
202 if let Ok(other_expr) = other.extract::<PyExpression>() {
203 Ok(PyExpression(other_expr.0 + self.0.clone()))
204 } else if let Ok(other_int) = other.extract::<usize>() {
205 if other_int == 0 {
206 Ok(PyExpression(self.0.clone()))
207 } else {
208 Err(PyTypeError::new_err(
209 "Addition with an integer for this type is only defined for 0",
210 ))
211 }
212 } else {
213 Err(PyTypeError::new_err("Unsupported operand type for +"))
214 }
215 }
216 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
217 if let Ok(other_expr) = other.extract::<PyExpression>() {
218 Ok(PyExpression(self.0.clone() - other_expr.0))
219 } else {
220 Err(PyTypeError::new_err("Unsupported operand type for -"))
221 }
222 }
223 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
224 if let Ok(other_expr) = other.extract::<PyExpression>() {
225 Ok(PyExpression(other_expr.0 - self.0.clone()))
226 } else {
227 Err(PyTypeError::new_err("Unsupported operand type for -"))
228 }
229 }
230 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
231 if let Ok(other_expr) = other.extract::<PyExpression>() {
232 Ok(PyExpression(self.0.clone() * other_expr.0))
233 } else {
234 Err(PyTypeError::new_err("Unsupported operand type for *"))
235 }
236 }
237 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
238 if let Ok(other_expr) = other.extract::<PyExpression>() {
239 Ok(PyExpression(other_expr.0 * self.0.clone()))
240 } else {
241 Err(PyTypeError::new_err("Unsupported operand type for *"))
242 }
243 }
244 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
245 if let Ok(other_expr) = other.extract::<PyExpression>() {
246 Ok(PyExpression(self.0.clone() / other_expr.0))
247 } else {
248 Err(PyTypeError::new_err("Unsupported operand type for /"))
249 }
250 }
251 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
252 if let Ok(other_expr) = other.extract::<PyExpression>() {
253 Ok(PyExpression(other_expr.0 / self.0.clone()))
254 } else {
255 Err(PyTypeError::new_err("Unsupported operand type for /"))
256 }
257 }
258 fn __neg__(&self) -> PyExpression {
259 PyExpression(-self.0.clone())
260 }
261 fn __str__(&self) -> String {
262 format!("{}", self.0)
263 }
264 fn __repr__(&self) -> String {
265 format!("{:?}", self.0)
266 }
267
268 #[new]
269 fn new() -> Self {
270 Self(Expression::create_null())
271 }
272 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
273 Ok(PyBytes::new(
274 py,
275 serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
276 .map_err(LadduError::PickleError)?
277 .as_slice(),
278 ))
279 }
280 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
281 *self = Self(
282 serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
283 .map_err(LadduError::PickleError)?,
284 );
285 Ok(())
286 }
287}
288
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
/// Python-facing wrapper (`laddu.Evaluator`) around a `laddu_core` [`Evaluator`].
///
/// Instances are produced by `Expression.load(dataset)`. They expose parameter management,
/// amplitude activation control, and (gradient) evaluation returning NumPy arrays.
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
298
299#[pymethods]
300impl PyEvaluator {
301 #[getter]
309 fn parameters(&self) -> Vec<String> {
310 self.0.parameters()
311 }
312 #[getter]
314 fn free_parameters(&self) -> Vec<String> {
315 self.0.free_parameters()
316 }
317 #[getter]
319 fn fixed_parameters(&self) -> Vec<String> {
320 self.0.fixed_parameters()
321 }
322 #[getter]
324 fn n_free(&self) -> usize {
325 self.0.n_free()
326 }
327 #[getter]
329 fn n_fixed(&self) -> usize {
330 self.0.n_fixed()
331 }
332 #[getter]
334 fn n_parameters(&self) -> usize {
335 self.0.n_parameters()
336 }
337 fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
339 Ok(PyEvaluator(self.0.fix(name, value)?))
340 }
341 fn free(&self, name: &str) -> PyResult<PyEvaluator> {
343 Ok(PyEvaluator(self.0.free(name)?))
344 }
345 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
347 Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
348 }
349 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
351 Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
352 }
353 #[pyo3(signature = (arg, *, strict=true))]
370 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
371 if let Ok(string_arg) = arg.extract::<String>() {
372 if strict {
373 self.0.activate_strict(&string_arg)?;
374 } else {
375 self.0.activate(&string_arg);
376 }
377 } else if let Ok(list_arg) = arg.cast::<PyList>() {
378 let vec: Vec<String> = list_arg.extract()?;
379 if strict {
380 self.0.activate_many_strict(&vec)?;
381 } else {
382 self.0.activate_many(&vec);
383 }
384 } else {
385 return Err(PyTypeError::new_err(
386 "Argument must be either a string or a list of strings",
387 ));
388 }
389 Ok(())
390 }
391 fn activate_all(&self) {
394 self.0.activate_all();
395 }
396 #[pyo3(signature = (arg, *, strict=true))]
415 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
416 if let Ok(string_arg) = arg.extract::<String>() {
417 if strict {
418 self.0.deactivate_strict(&string_arg)?;
419 } else {
420 self.0.deactivate(&string_arg);
421 }
422 } else if let Ok(list_arg) = arg.cast::<PyList>() {
423 let vec: Vec<String> = list_arg.extract()?;
424 if strict {
425 self.0.deactivate_many_strict(&vec)?;
426 } else {
427 self.0.deactivate_many(&vec);
428 }
429 } else {
430 return Err(PyTypeError::new_err(
431 "Argument must be either a string or a list of strings",
432 ));
433 }
434 Ok(())
435 }
436 fn deactivate_all(&self) {
439 self.0.deactivate_all();
440 }
441 #[pyo3(signature = (arg, *, strict=true))]
460 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
461 if let Ok(string_arg) = arg.extract::<String>() {
462 if strict {
463 self.0.isolate_strict(&string_arg)?;
464 } else {
465 self.0.isolate(&string_arg);
466 }
467 } else if let Ok(list_arg) = arg.cast::<PyList>() {
468 let vec: Vec<String> = list_arg.extract()?;
469 if strict {
470 self.0.isolate_many_strict(&vec)?;
471 } else {
472 self.0.isolate_many(&vec);
473 }
474 } else {
475 return Err(PyTypeError::new_err(
476 "Argument must be either a string or a list of strings",
477 ));
478 }
479 Ok(())
480 }
481
482 #[getter]
484 fn active_mask(&self) -> Vec<bool> {
485 self.0.active_mask()
486 }
487
488 fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
490 self.0.set_active_mask(&mask)?;
491 Ok(())
492 }
493
494 #[pyo3(signature = (parameters, *, threads=None))]
516 fn evaluate<'py>(
517 &self,
518 py: Python<'py>,
519 parameters: Vec<f64>,
520 threads: Option<usize>,
521 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
522 let values = install_with_threads(threads, || self.0.evaluate(¶meters))?;
523 Ok(PyArray1::from_slice(py, &values))
524 }
525 #[pyo3(signature = (parameters, indices, *, threads=None))]
549 fn evaluate_batch<'py>(
550 &self,
551 py: Python<'py>,
552 parameters: Vec<f64>,
553 indices: Vec<usize>,
554 threads: Option<usize>,
555 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
556 let values =
557 install_with_threads(threads, || self.0.evaluate_batch(¶meters, &indices))?;
558 Ok(PyArray1::from_slice(py, &values))
559 }
560 #[pyo3(signature = (parameters, *, threads=None))]
583 fn evaluate_gradient<'py>(
584 &self,
585 py: Python<'py>,
586 parameters: Vec<f64>,
587 threads: Option<usize>,
588 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
589 let gradients = install_with_threads(threads, || {
590 self.0
591 .evaluate_gradient(¶meters)
592 .iter()
593 .map(|grad| grad.data.as_vec().to_vec())
594 .collect::<Vec<Vec<Complex64>>>()
595 })?;
596 Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
597 }
598 #[pyo3(signature = (parameters, indices, *, threads=None))]
623 fn evaluate_gradient_batch<'py>(
624 &self,
625 py: Python<'py>,
626 parameters: Vec<f64>,
627 indices: Vec<usize>,
628 threads: Option<usize>,
629 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
630 let gradients = install_with_threads(threads, || {
631 self.0
632 .evaluate_gradient_batch(¶meters, &indices)
633 .iter()
634 .map(|grad| grad.data.as_vec().to_vec())
635 .collect::<Vec<Vec<Complex64>>>()
636 })?;
637 Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
638 }
639}
640
#[pyclass(name = "ParameterLike", module = "laddu", from_py_object)]
/// Python-facing wrapper (`laddu.ParameterLike`) around a `laddu_core` [`ParameterLike`].
///
/// Values are created by the `parameter(...)` and `constant(...)` functions and consumed by
/// amplitude constructors such as `TestAmplitude`.
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
652
653#[pyfunction(name = "parameter")]
670pub fn py_parameter(name: &str) -> PyParameterLike {
671 PyParameterLike(parameter(name))
672}
673
674#[pyfunction(name = "constant")]
689pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
690 PyParameterLike(constant(name, value))
691}
692
693#[pyfunction(name = "TestAmplitude")]
695pub fn py_test_amplitude(
696 name: &str,
697 re: PyParameterLike,
698 im: PyParameterLike,
699) -> PyResult<PyExpression> {
700 Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
701}