1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
4 f64, CompiledExpression, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9 exceptions::PyTypeError,
10 prelude::*,
11 types::{PyBytes, PyList},
12};
13use std::collections::HashMap;
14
15fn install_with_threads<R: Send>(
16 threads: Option<usize>,
17 op: impl FnOnce() -> R + Send,
18) -> LadduResult<R> {
19 ThreadPoolManager::shared().install(threads, op)
20}
21
/// Python-facing wrapper around a core [`Expression`].
///
/// Exposed to Python as `laddu.Expression`. The wrapped value is `pub`, so Rust code in
/// this crate can read or move the inner [`Expression`] directly.
#[pyclass(name = "Expression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
27
28#[pyfunction(name = "expr_sum")]
31pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
32 if terms.is_empty() {
33 return Ok(PyExpression(Expression::zero()));
34 }
35 if terms.len() == 1 {
36 let term = &terms[0];
37 if let Ok(expression) = term.extract::<PyExpression>() {
38 return Ok(expression);
39 }
40 return Err(PyTypeError::new_err("Item is not a PyExpression"));
41 }
42 let mut iter = terms.iter();
43 let Some(first_term) = iter.next() else {
44 return Ok(PyExpression(Expression::zero()));
45 };
46 let PyExpression(mut summation) = first_term
47 .extract::<PyExpression>()
48 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
49 for term in iter {
50 let PyExpression(expr) = term
51 .extract::<PyExpression>()
52 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
53 summation = summation + expr;
54 }
55 Ok(PyExpression(summation))
56}
57
58#[pyfunction(name = "expr_product")]
61pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
62 if terms.is_empty() {
63 return Ok(PyExpression(Expression::one()));
64 }
65 if terms.len() == 1 {
66 let term = &terms[0];
67 if let Ok(expression) = term.extract::<PyExpression>() {
68 return Ok(expression);
69 }
70 return Err(PyTypeError::new_err("Item is not a PyExpression"));
71 }
72 let mut iter = terms.iter();
73 let Some(first_term) = iter.next() else {
74 return Ok(PyExpression(Expression::one()));
75 };
76 let PyExpression(mut product) = first_term
77 .extract::<PyExpression>()
78 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
79 for term in iter {
80 let PyExpression(expr) = term
81 .extract::<PyExpression>()
82 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
83 product = product * expr;
84 }
85 Ok(PyExpression(product))
86}
87
88#[pyfunction(name = "Zero")]
91pub fn py_expr_zero() -> PyExpression {
92 PyExpression(Expression::zero())
93}
94
95#[pyfunction(name = "One")]
98pub fn py_expr_one() -> PyExpression {
99 PyExpression(Expression::one())
100}
101
102#[pymethods]
103impl PyExpression {
104 #[getter]
111 fn parameters(&self) -> Vec<String> {
112 self.0.parameters()
113 }
114 #[getter]
116 fn free_parameters(&self) -> Vec<String> {
117 self.0.free_parameters()
118 }
119 #[getter]
121 fn fixed_parameters(&self) -> Vec<String> {
122 self.0.fixed_parameters()
123 }
124 #[getter]
126 fn n_free(&self) -> usize {
127 self.0.n_free()
128 }
129 #[getter]
131 fn n_fixed(&self) -> usize {
132 self.0.n_fixed()
133 }
134 #[getter]
136 fn n_parameters(&self) -> usize {
137 self.0.n_parameters()
138 }
139 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
152 Ok(PyEvaluator(self.0.load(&dataset.0)?))
153 }
154 fn real(&self) -> PyExpression {
156 PyExpression(self.0.real())
157 }
158 fn imag(&self) -> PyExpression {
160 PyExpression(self.0.imag())
161 }
162 fn conj(&self) -> PyExpression {
164 PyExpression(self.0.conj())
165 }
166 fn norm_sqr(&self) -> PyExpression {
168 PyExpression(self.0.norm_sqr())
169 }
170 fn sqrt(&self) -> PyExpression {
172 PyExpression(self.0.sqrt())
173 }
174 fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
176 if let Ok(expression) = power.extract::<PyExpression>() {
177 Ok(PyExpression(self.0.pow(&expression.0)))
178 } else if let Ok(value) = power.extract::<i32>() {
179 Ok(PyExpression(self.0.powi(value)))
180 } else if let Ok(value) = power.extract::<f64>() {
181 Ok(PyExpression(self.0.powf(value)))
182 } else {
183 Err(PyTypeError::new_err(
184 "power must be an int, float, or Expression",
185 ))
186 }
187 }
188 fn exp(&self) -> PyExpression {
190 PyExpression(self.0.exp())
191 }
192 fn sin(&self) -> PyExpression {
194 PyExpression(self.0.sin())
195 }
196 fn cos(&self) -> PyExpression {
198 PyExpression(self.0.cos())
199 }
200 fn log(&self) -> PyExpression {
202 PyExpression(self.0.log())
203 }
204 fn cis(&self) -> PyExpression {
206 PyExpression(self.0.cis())
207 }
208 fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
210 Ok(PyExpression(self.0.fix(name, value)?))
211 }
212 fn free(&self, name: &str) -> PyResult<PyExpression> {
214 Ok(PyExpression(self.0.free(name)?))
215 }
216 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
218 Ok(PyExpression(self.0.rename_parameter(old, new)?))
219 }
220 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
222 Ok(PyExpression(self.0.rename_parameters(&mapping)?))
223 }
224 #[getter]
226 fn compiled_expression(&self) -> PyCompiledExpression {
227 PyCompiledExpression(self.0.compiled_expression())
228 }
229 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
230 if let Ok(other_expr) = other.extract::<PyExpression>() {
231 Ok(PyExpression(self.0.clone() + other_expr.0))
232 } else if let Ok(other_int) = other.extract::<usize>() {
233 if other_int == 0 {
234 Ok(PyExpression(self.0.clone()))
235 } else {
236 Err(PyTypeError::new_err(
237 "Addition with an integer for this type is only defined for 0",
238 ))
239 }
240 } else {
241 Err(PyTypeError::new_err("Unsupported operand type for +"))
242 }
243 }
244 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
245 if let Ok(other_expr) = other.extract::<PyExpression>() {
246 Ok(PyExpression(other_expr.0 + self.0.clone()))
247 } else if let Ok(other_int) = other.extract::<usize>() {
248 if other_int == 0 {
249 Ok(PyExpression(self.0.clone()))
250 } else {
251 Err(PyTypeError::new_err(
252 "Addition with an integer for this type is only defined for 0",
253 ))
254 }
255 } else {
256 Err(PyTypeError::new_err("Unsupported operand type for +"))
257 }
258 }
259 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
260 if let Ok(other_expr) = other.extract::<PyExpression>() {
261 Ok(PyExpression(self.0.clone() - other_expr.0))
262 } else {
263 Err(PyTypeError::new_err("Unsupported operand type for -"))
264 }
265 }
266 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
267 if let Ok(other_expr) = other.extract::<PyExpression>() {
268 Ok(PyExpression(other_expr.0 - self.0.clone()))
269 } else {
270 Err(PyTypeError::new_err("Unsupported operand type for -"))
271 }
272 }
273 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
274 if let Ok(other_expr) = other.extract::<PyExpression>() {
275 Ok(PyExpression(self.0.clone() * other_expr.0))
276 } else {
277 Err(PyTypeError::new_err("Unsupported operand type for *"))
278 }
279 }
280 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
281 if let Ok(other_expr) = other.extract::<PyExpression>() {
282 Ok(PyExpression(other_expr.0 * self.0.clone()))
283 } else {
284 Err(PyTypeError::new_err("Unsupported operand type for *"))
285 }
286 }
287 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
288 if let Ok(other_expr) = other.extract::<PyExpression>() {
289 Ok(PyExpression(self.0.clone() / other_expr.0))
290 } else {
291 Err(PyTypeError::new_err("Unsupported operand type for /"))
292 }
293 }
294 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
295 if let Ok(other_expr) = other.extract::<PyExpression>() {
296 Ok(PyExpression(other_expr.0 / self.0.clone()))
297 } else {
298 Err(PyTypeError::new_err("Unsupported operand type for /"))
299 }
300 }
301 fn __neg__(&self) -> PyExpression {
302 PyExpression(-self.0.clone())
303 }
304 fn __str__(&self) -> String {
305 format!("{}", self.0)
306 }
307 fn __repr__(&self) -> String {
308 format!("{:?}", self.0)
309 }
310
311 #[new]
312 fn new() -> Self {
313 Self(Expression::create_null())
314 }
315 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
316 Ok(PyBytes::new(
317 py,
318 serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
319 .map_err(LadduError::PickleError)?
320 .as_slice(),
321 ))
322 }
323 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
324 *self = Self(
325 serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
326 .map_err(LadduError::PickleError)?,
327 );
328 Ok(())
329 }
330}
331
/// Python-facing wrapper around a core [`Evaluator`].
///
/// Exposed to Python as `laddu.Evaluator`; `Expression.load` returns one of these.
/// The wrapped value is `pub`, so Rust code in this crate can access it directly.
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
341
342#[pymethods]
343impl PyEvaluator {
344 #[getter]
352 fn parameters(&self) -> Vec<String> {
353 self.0.parameters()
354 }
355 #[getter]
357 fn free_parameters(&self) -> Vec<String> {
358 self.0.free_parameters()
359 }
360 #[getter]
362 fn fixed_parameters(&self) -> Vec<String> {
363 self.0.fixed_parameters()
364 }
365 #[getter]
367 fn n_free(&self) -> usize {
368 self.0.n_free()
369 }
370 #[getter]
372 fn n_fixed(&self) -> usize {
373 self.0.n_fixed()
374 }
375 #[getter]
377 fn n_parameters(&self) -> usize {
378 self.0.n_parameters()
379 }
380 fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
382 Ok(PyEvaluator(self.0.fix(name, value)?))
383 }
384 fn free(&self, name: &str) -> PyResult<PyEvaluator> {
386 Ok(PyEvaluator(self.0.free(name)?))
387 }
388 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
390 Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
391 }
392 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
394 Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
395 }
396 #[pyo3(signature = (arg, *, strict=true))]
413 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
414 if let Ok(string_arg) = arg.extract::<String>() {
415 if strict {
416 self.0.activate_strict(&string_arg)?;
417 } else {
418 self.0.activate(&string_arg);
419 }
420 } else if let Ok(list_arg) = arg.cast::<PyList>() {
421 let vec: Vec<String> = list_arg.extract()?;
422 if strict {
423 self.0.activate_many_strict(&vec)?;
424 } else {
425 self.0.activate_many(&vec);
426 }
427 } else {
428 return Err(PyTypeError::new_err(
429 "Argument must be either a string or a list of strings",
430 ));
431 }
432 Ok(())
433 }
434 fn activate_all(&self) {
437 self.0.activate_all();
438 }
439 #[pyo3(signature = (arg, *, strict=true))]
458 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
459 if let Ok(string_arg) = arg.extract::<String>() {
460 if strict {
461 self.0.deactivate_strict(&string_arg)?;
462 } else {
463 self.0.deactivate(&string_arg);
464 }
465 } else if let Ok(list_arg) = arg.cast::<PyList>() {
466 let vec: Vec<String> = list_arg.extract()?;
467 if strict {
468 self.0.deactivate_many_strict(&vec)?;
469 } else {
470 self.0.deactivate_many(&vec);
471 }
472 } else {
473 return Err(PyTypeError::new_err(
474 "Argument must be either a string or a list of strings",
475 ));
476 }
477 Ok(())
478 }
479 fn deactivate_all(&self) {
482 self.0.deactivate_all();
483 }
484 #[pyo3(signature = (arg, *, strict=true))]
503 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
504 if let Ok(string_arg) = arg.extract::<String>() {
505 if strict {
506 self.0.isolate_strict(&string_arg)?;
507 } else {
508 self.0.isolate(&string_arg);
509 }
510 } else if let Ok(list_arg) = arg.cast::<PyList>() {
511 let vec: Vec<String> = list_arg.extract()?;
512 if strict {
513 self.0.isolate_many_strict(&vec)?;
514 } else {
515 self.0.isolate_many(&vec);
516 }
517 } else {
518 return Err(PyTypeError::new_err(
519 "Argument must be either a string or a list of strings",
520 ));
521 }
522 Ok(())
523 }
524
525 #[getter]
527 fn active_mask(&self) -> Vec<bool> {
528 self.0.active_mask()
529 }
530
531 fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
533 self.0.set_active_mask(&mask)?;
534 Ok(())
535 }
536
537 #[getter]
539 fn compiled_expression(&self) -> PyCompiledExpression {
540 PyCompiledExpression(self.0.compiled_expression())
541 }
542
543 #[getter]
545 fn expression(&self) -> PyExpression {
546 PyExpression(self.0.expression())
547 }
548
549 #[pyo3(signature = (parameters, *, threads=None))]
571 fn evaluate<'py>(
572 &self,
573 py: Python<'py>,
574 parameters: Vec<f64>,
575 threads: Option<usize>,
576 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
577 let values = install_with_threads(threads, || self.0.evaluate(¶meters))?;
578 Ok(PyArray1::from_slice(py, &values))
579 }
580 #[pyo3(signature = (parameters, indices, *, threads=None))]
604 fn evaluate_batch<'py>(
605 &self,
606 py: Python<'py>,
607 parameters: Vec<f64>,
608 indices: Vec<usize>,
609 threads: Option<usize>,
610 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
611 let values =
612 install_with_threads(threads, || self.0.evaluate_batch(¶meters, &indices))?;
613 Ok(PyArray1::from_slice(py, &values))
614 }
615 #[pyo3(signature = (parameters, *, threads=None))]
638 fn evaluate_gradient<'py>(
639 &self,
640 py: Python<'py>,
641 parameters: Vec<f64>,
642 threads: Option<usize>,
643 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
644 let gradients = install_with_threads(threads, || {
645 self.0
646 .evaluate_gradient(¶meters)
647 .iter()
648 .map(|grad| grad.data.as_vec().to_vec())
649 .collect::<Vec<Vec<Complex64>>>()
650 })?;
651 Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
652 }
653 #[pyo3(signature = (parameters, indices, *, threads=None))]
678 fn evaluate_gradient_batch<'py>(
679 &self,
680 py: Python<'py>,
681 parameters: Vec<f64>,
682 indices: Vec<usize>,
683 threads: Option<usize>,
684 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
685 let gradients = install_with_threads(threads, || {
686 self.0
687 .evaluate_gradient_batch(¶meters, &indices)
688 .iter()
689 .map(|grad| grad.data.as_vec().to_vec())
690 .collect::<Vec<Vec<Complex64>>>()
691 })?;
692 Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
693 }
694}
695
/// Python-facing wrapper around a core [`CompiledExpression`].
///
/// Exposed to Python as `laddu.CompiledExpression`. Both `Expression` and `Evaluator`
/// return it from their `compiled_expression` getters.
#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyCompiledExpression(pub CompiledExpression);
705
706#[pymethods]
707impl PyCompiledExpression {
708 fn __str__(&self) -> String {
709 format!("{}", self.0)
710 }
711 fn __repr__(&self) -> String {
712 format!("{:?}", self.0)
713 }
714}
715
/// Python-facing wrapper around a core [`ParameterLike`].
///
/// Exposed to Python as `laddu.ParameterLike`. Instances come from the `parameter` and
/// `constant` functions and are consumed by amplitude constructors such as `TestAmplitude`.
#[pyclass(name = "ParameterLike", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
727
728#[pyfunction(name = "parameter")]
745pub fn py_parameter(name: &str) -> PyParameterLike {
746 PyParameterLike(parameter(name))
747}
748
749#[pyfunction(name = "constant")]
764pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
765 PyParameterLike(constant(name, value))
766}
767
768#[pyfunction(name = "TestAmplitude")]
770pub fn py_test_amplitude(
771 name: &str,
772 re: PyParameterLike,
773 im: PyParameterLike,
774) -> PyResult<PyExpression> {
775 Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
776}