1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{Evaluator, Expression, Parameter, TestAmplitude},
4 CompiledExpression, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9 exceptions::PyTypeError,
10 prelude::*,
11 types::{PyBytes, PyList, PyTuple},
12};
13use std::collections::HashMap;
14
15fn install_with_threads<R: Send>(
16 threads: Option<usize>,
17 op: impl FnOnce() -> R + Send,
18) -> LadduResult<R> {
19 ThreadPoolManager::shared().install(threads, op)
20}
21
/// Python-visible wrapper around a laddu [`Expression`], registered as
/// `laddu.Expression`.
///
/// The automatic `FromPyObject` derivation is skipped (`skip_from_py_object`)
/// because this file provides a hand-written implementation that also accepts
/// plain Python floats and complex numbers.
#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
27
28impl<'py> FromPyObject<'_, 'py> for PyExpression {
29 type Error = PyErr;
30
31 fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
32 if let Ok(obj) = obj.cast::<PyExpression>() {
33 Ok(obj.borrow().clone())
34 } else if let Ok(obj) = obj.extract::<f64>() {
35 Ok(Self(obj.into()))
36 } else if let Ok(obj) = obj.extract::<Complex64>() {
37 Ok(Self(obj.into()))
38 } else {
39 Err(PyTypeError::new_err("Failed to extract Expression"))
40 }
41 }
42}
43
44#[pyfunction(name = "expr_sum")]
47pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
48 if terms.is_empty() {
49 return Ok(PyExpression(Expression::zero()));
50 }
51 if terms.len() == 1 {
52 let term = &terms[0];
53 if let Ok(expression) = term.extract::<PyExpression>() {
54 return Ok(expression);
55 }
56 return Err(PyTypeError::new_err("Item is not a PyExpression"));
57 }
58 let mut iter = terms.iter();
59 let Some(first_term) = iter.next() else {
60 return Ok(PyExpression(Expression::zero()));
61 };
62 let PyExpression(mut summation) = first_term
63 .extract::<PyExpression>()
64 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
65 for term in iter {
66 let PyExpression(expr) = term
67 .extract::<PyExpression>()
68 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
69 summation = summation + expr;
70 }
71 Ok(PyExpression(summation))
72}
73
74#[pyfunction(name = "expr_product")]
77pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
78 if terms.is_empty() {
79 return Ok(PyExpression(Expression::one()));
80 }
81 if terms.len() == 1 {
82 let term = &terms[0];
83 if let Ok(expression) = term.extract::<PyExpression>() {
84 return Ok(expression);
85 }
86 return Err(PyTypeError::new_err("Item is not a PyExpression"));
87 }
88 let mut iter = terms.iter();
89 let Some(first_term) = iter.next() else {
90 return Ok(PyExpression(Expression::one()));
91 };
92 let PyExpression(mut product) = first_term
93 .extract::<PyExpression>()
94 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
95 for term in iter {
96 let PyExpression(expr) = term
97 .extract::<PyExpression>()
98 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
99 product = product * expr;
100 }
101 Ok(PyExpression(product))
102}
103
104#[pyfunction(name = "Zero")]
107pub fn py_expr_zero() -> PyExpression {
108 PyExpression(Expression::zero())
109}
110
111#[pyfunction(name = "One")]
114pub fn py_expr_one() -> PyExpression {
115 PyExpression(Expression::one())
116}
117
118#[pymethods]
119impl PyExpression {
120 #[getter]
127 fn parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
128 PyTuple::new(py, self.0.parameters())
129 }
130 #[getter]
132 fn free_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
133 PyTuple::new(py, self.0.free_parameters())
134 }
135 #[getter]
137 fn fixed_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
138 PyTuple::new(py, self.0.fixed_parameters())
139 }
140 #[getter]
142 fn n_free(&self) -> usize {
143 self.0.n_free()
144 }
145 #[getter]
147 fn n_fixed(&self) -> usize {
148 self.0.n_fixed()
149 }
150 #[getter]
152 fn n_parameters(&self) -> usize {
153 self.0.n_parameters()
154 }
155 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
168 Ok(PyEvaluator(self.0.load(&dataset.0)?))
169 }
170 fn real(&self) -> PyExpression {
172 PyExpression(self.0.real())
173 }
174 fn imag(&self) -> PyExpression {
176 PyExpression(self.0.imag())
177 }
178 fn conj(&self) -> PyExpression {
180 PyExpression(self.0.conj())
181 }
182 fn norm_sqr(&self) -> PyExpression {
184 PyExpression(self.0.norm_sqr())
185 }
186 fn sqrt(&self) -> PyExpression {
188 PyExpression(self.0.sqrt())
189 }
190 fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
192 if let Ok(value) = power.extract::<i32>() {
193 Ok(PyExpression(self.0.powi(value)))
194 } else if let Ok(value) = power.extract::<f64>() {
195 Ok(PyExpression(self.0.powf(value)))
196 } else if let Ok(expression) = power.extract::<PyExpression>() {
197 Ok(PyExpression(self.0.pow(&expression.0)))
198 } else {
199 Err(PyTypeError::new_err(
200 "power must be an int, float, or Expression",
201 ))
202 }
203 }
204 fn exp(&self) -> PyExpression {
206 PyExpression(self.0.exp())
207 }
208 fn sin(&self) -> PyExpression {
210 PyExpression(self.0.sin())
211 }
212 fn cos(&self) -> PyExpression {
214 PyExpression(self.0.cos())
215 }
216 fn log(&self) -> PyExpression {
218 PyExpression(self.0.log())
219 }
220 fn cis(&self) -> PyExpression {
222 PyExpression(self.0.cis())
223 }
224 fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
226 Ok(self.0.fix_parameter(name, value)?)
227 }
228 fn free_parameter(&self, name: &str) -> PyResult<()> {
230 Ok(self.0.free_parameter(name)?)
231 }
232 fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
234 Ok(self.0.rename_parameter(old, new)?)
235 }
236 fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
238 Ok(self.0.rename_parameters(&mapping)?)
239 }
240 #[getter]
242 fn compiled_expression(&self) -> PyCompiledExpression {
243 PyCompiledExpression(self.0.compiled_expression())
244 }
245 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
246 if let Ok(other_expr) = other.extract::<PyExpression>() {
247 Ok(PyExpression(self.0.clone() + other_expr.0))
248 } else {
249 Err(PyTypeError::new_err("Unsupported operand type for +"))
250 }
251 }
252 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
253 if let Ok(other_expr) = other.extract::<PyExpression>() {
254 Ok(PyExpression(other_expr.0 + self.0.clone()))
255 } else {
256 Err(PyTypeError::new_err("Unsupported operand type for +"))
257 }
258 }
259 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
260 if let Ok(other_expr) = other.extract::<PyExpression>() {
261 Ok(PyExpression(self.0.clone() - other_expr.0))
262 } else {
263 Err(PyTypeError::new_err("Unsupported operand type for -"))
264 }
265 }
266 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
267 if let Ok(other_expr) = other.extract::<PyExpression>() {
268 Ok(PyExpression(other_expr.0 - self.0.clone()))
269 } else {
270 Err(PyTypeError::new_err("Unsupported operand type for -"))
271 }
272 }
273 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
274 if let Ok(other_expr) = other.extract::<PyExpression>() {
275 Ok(PyExpression(self.0.clone() * other_expr.0))
276 } else {
277 Err(PyTypeError::new_err("Unsupported operand type for *"))
278 }
279 }
280 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
281 if let Ok(other_expr) = other.extract::<PyExpression>() {
282 Ok(PyExpression(other_expr.0 * self.0.clone()))
283 } else {
284 Err(PyTypeError::new_err("Unsupported operand type for *"))
285 }
286 }
287 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
288 if let Ok(other_expr) = other.extract::<PyExpression>() {
289 Ok(PyExpression(self.0.clone() / other_expr.0))
290 } else {
291 Err(PyTypeError::new_err("Unsupported operand type for /"))
292 }
293 }
294 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
295 if let Ok(other_expr) = other.extract::<PyExpression>() {
296 Ok(PyExpression(other_expr.0 / self.0.clone()))
297 } else {
298 Err(PyTypeError::new_err("Unsupported operand type for /"))
299 }
300 }
301 fn __neg__(&self) -> PyExpression {
302 PyExpression(-self.0.clone())
303 }
304 fn __str__(&self) -> String {
305 format!("{}", self.0)
306 }
307 fn __repr__(&self) -> String {
308 format!("{:?}", self.0)
309 }
310
311 #[new]
312 fn new() -> Self {
313 Self(Expression::create_null())
314 }
315 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
316 Ok(PyBytes::new(
317 py,
318 serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
319 .map_err(LadduError::PickleError)?
320 .as_slice(),
321 ))
322 }
323 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
324 *self = Self(
325 serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
326 .map_err(LadduError::PickleError)?,
327 );
328 Ok(())
329 }
330}
331
/// Python-visible wrapper around a laddu [`Evaluator`], registered as
/// `laddu.Evaluator`.
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
341
342#[pymethods]
343impl PyEvaluator {
344 #[getter]
352 fn parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
353 PyTuple::new(py, self.0.parameters())
354 }
355 #[getter]
357 fn free_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
358 PyTuple::new(py, self.0.free_parameters())
359 }
360 #[getter]
362 fn fixed_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
363 PyTuple::new(py, self.0.fixed_parameters())
364 }
365 #[getter]
367 fn n_free(&self) -> usize {
368 self.0.n_free()
369 }
370 #[getter]
372 fn n_fixed(&self) -> usize {
373 self.0.n_fixed()
374 }
375 #[getter]
377 fn n_parameters(&self) -> usize {
378 self.0.n_parameters()
379 }
380 fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
382 Ok(self.0.fix_parameter(name, value)?)
383 }
384 fn free_parameter(&self, name: &str) -> PyResult<()> {
386 Ok(self.0.free_parameter(name)?)
387 }
388 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
390 Ok(self.0.rename_parameter(old, new)?)
391 }
392 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
394 Ok(self.0.rename_parameters(&mapping)?)
395 }
396 #[pyo3(signature = (arg, *, strict=true))]
413 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
414 if let Ok(string_arg) = arg.extract::<String>() {
415 if strict {
416 self.0.activate_strict(&string_arg)?;
417 } else {
418 self.0.activate(&string_arg);
419 }
420 } else if let Ok(list_arg) = arg.cast::<PyList>() {
421 let vec: Vec<String> = list_arg.extract()?;
422 if strict {
423 self.0.activate_many_strict(&vec)?;
424 } else {
425 self.0.activate_many(&vec);
426 }
427 } else {
428 return Err(PyTypeError::new_err(
429 "Argument must be either a string or a list of strings",
430 ));
431 }
432 Ok(())
433 }
434 fn activate_all(&self) {
437 self.0.activate_all();
438 }
439 #[pyo3(signature = (arg, *, strict=true))]
458 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
459 if let Ok(string_arg) = arg.extract::<String>() {
460 if strict {
461 self.0.deactivate_strict(&string_arg)?;
462 } else {
463 self.0.deactivate(&string_arg);
464 }
465 } else if let Ok(list_arg) = arg.cast::<PyList>() {
466 let vec: Vec<String> = list_arg.extract()?;
467 if strict {
468 self.0.deactivate_many_strict(&vec)?;
469 } else {
470 self.0.deactivate_many(&vec);
471 }
472 } else {
473 return Err(PyTypeError::new_err(
474 "Argument must be either a string or a list of strings",
475 ));
476 }
477 Ok(())
478 }
479 fn deactivate_all(&self) {
482 self.0.deactivate_all();
483 }
484 #[pyo3(signature = (arg, *, strict=true))]
503 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
504 if let Ok(string_arg) = arg.extract::<String>() {
505 if strict {
506 self.0.isolate_strict(&string_arg)?;
507 } else {
508 self.0.isolate(&string_arg);
509 }
510 } else if let Ok(list_arg) = arg.cast::<PyList>() {
511 let vec: Vec<String> = list_arg.extract()?;
512 if strict {
513 self.0.isolate_many_strict(&vec)?;
514 } else {
515 self.0.isolate_many(&vec);
516 }
517 } else {
518 return Err(PyTypeError::new_err(
519 "Argument must be either a string or a list of strings",
520 ));
521 }
522 Ok(())
523 }
524
525 #[getter]
527 fn active_mask(&self) -> Vec<bool> {
528 self.0.active_mask()
529 }
530
531 fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
533 self.0.set_active_mask(&mask)?;
534 Ok(())
535 }
536
537 #[getter]
539 fn compiled_expression(&self) -> PyCompiledExpression {
540 PyCompiledExpression(self.0.compiled_expression())
541 }
542
543 #[getter]
545 fn expression(&self) -> PyExpression {
546 PyExpression(self.0.expression())
547 }
548
549 #[pyo3(signature = (parameters, *, threads=None))]
571 fn evaluate<'py>(
572 &self,
573 py: Python<'py>,
574 parameters: Vec<f64>,
575 threads: Option<usize>,
576 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
577 let values = install_with_threads(threads, || self.0.evaluate(¶meters))?;
578 Ok(PyArray1::from_slice(py, &values?))
579 }
580 #[pyo3(signature = (parameters, indices, *, threads=None))]
604 fn evaluate_batch<'py>(
605 &self,
606 py: Python<'py>,
607 parameters: Vec<f64>,
608 indices: Vec<usize>,
609 threads: Option<usize>,
610 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
611 let values =
612 install_with_threads(threads, || self.0.evaluate_batch(¶meters, &indices))?;
613 Ok(PyArray1::from_slice(py, &values?))
614 }
615 #[pyo3(signature = (parameters, *, threads=None))]
638 fn evaluate_gradient<'py>(
639 &self,
640 py: Python<'py>,
641 parameters: Vec<f64>,
642 threads: Option<usize>,
643 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
644 let gradients: LadduResult<_> = install_with_threads(threads, || {
645 Ok(self
646 .0
647 .evaluate_gradient(¶meters)?
648 .iter()
649 .map(|grad| grad.data.as_vec().to_vec())
650 .collect::<Vec<Vec<Complex64>>>())
651 })?;
652 Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
653 }
654 #[pyo3(signature = (parameters, indices, *, threads=None))]
679 fn evaluate_gradient_batch<'py>(
680 &self,
681 py: Python<'py>,
682 parameters: Vec<f64>,
683 indices: Vec<usize>,
684 threads: Option<usize>,
685 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
686 let gradients: LadduResult<_> = install_with_threads(threads, || {
687 Ok(self
688 .0
689 .evaluate_gradient_batch(¶meters, &indices)?
690 .iter()
691 .map(|grad| grad.data.as_vec().to_vec())
692 .collect::<Vec<Vec<Complex64>>>())
693 })?;
694 Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
695 }
696}
697
/// Python-visible wrapper around a laddu [`CompiledExpression`], registered as
/// `laddu.CompiledExpression`.
#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyCompiledExpression(pub CompiledExpression);
707
708#[pymethods]
709impl PyCompiledExpression {
710 fn __str__(&self) -> String {
711 format!("{}", self.0)
712 }
713 fn __repr__(&self) -> String {
714 format!("{:?}", self.0)
715 }
716}
717
/// Python-visible wrapper around a laddu [`Parameter`], registered as
/// `laddu.Parameter`.
#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameter(pub Parameter);
721
722#[pymethods]
723impl PyParameter {
724 #[getter]
725 fn name(&self) -> String {
726 self.0.name()
727 }
728 #[getter]
729 fn fixed(&self) -> Option<f64> {
730 self.0.fixed()
731 }
732 #[getter]
733 fn initial(&self) -> Option<f64> {
734 self.0.initial()
735 }
736 #[getter]
737 fn bounds(&self) -> (Option<f64>, Option<f64>) {
738 self.0.bounds()
739 }
740 #[getter]
741 fn unit(&self) -> Option<String> {
742 self.0.unit()
743 }
744 #[getter]
745 fn latex(&self) -> Option<String> {
746 self.0.latex()
747 }
748 #[getter]
749 fn description(&self) -> Option<String> {
750 self.0.description()
751 }
752}
753
754#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
786pub fn py_parameter(
787 name: &str,
788 fixed: Option<f64>,
789 initial: Option<f64>,
790 bounds: (Option<f64>, Option<f64>),
791 unit: Option<&str>,
792 latex: Option<&str>,
793 description: Option<&str>,
794) -> PyParameter {
795 let par = Parameter::new(name);
796 if let Some(value) = initial {
797 par.set_initial(value);
798 }
799 if let Some(value) = fixed {
800 par.set_fixed_value(Some(value)); }
802 par.set_bounds(bounds.0, bounds.1);
803 if let Some(unit) = unit {
804 par.set_unit(unit);
805 }
806 if let Some(latex) = latex {
807 par.set_latex(latex);
808 }
809 if let Some(description) = description {
810 par.set_description(description);
811 }
812 PyParameter(par)
813}
814
815#[pyfunction(name = "TestAmplitude")]
817pub fn py_test_amplitude(name: &str, re: PyParameter, im: PyParameter) -> PyResult<PyExpression> {
818 Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
819}