1use crate::data::PyDataset;
2use bincode::{deserialize, serialize};
3use laddu_core::{
4 amplitudes::{
5 constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
6 ParameterLike,
7 },
8 traits::ReadWrite,
9 Complex, Float, LadduError,
10};
11use numpy::{PyArray1, PyArray2};
12use pyo3::{
13 exceptions::PyTypeError,
14 prelude::*,
15 types::{PyBytes, PyList},
16};
17#[cfg(feature = "rayon")]
18use rayon::ThreadPoolBuilder;
19
/// Python handle (`laddu.AmplitudeID`) wrapping a core `AmplitudeID`.
///
/// Instances are returned by `Manager.register` and can be combined with other
/// `AmplitudeID`s and `Expression`s via `+` and `*` to build an `Expression`.
#[pyclass(name = "AmplitudeID", module = "laddu")]
#[derive(Clone)]
pub struct PyAmplitudeID(AmplitudeID);
29
/// Python wrapper (`laddu.Expression`) around a core amplitude `Expression`.
///
/// Produced by arithmetic on `AmplitudeID`s/`Expression`s and by `amplitude_sum`,
/// `amplitude_product`, `AmplitudeZero` and `AmplitudeOne`; consumed by `Manager.model`.
#[pyclass(name = "Expression", module = "laddu")]
#[derive(Clone)]
pub struct PyExpression(Expression);
35
36#[pyfunction(name = "amplitude_sum")]
39pub fn py_amplitude_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
40 if terms.is_empty() {
41 return Ok(PyExpression(Expression::Zero));
42 }
43 if terms.len() == 1 {
44 let term = &terms[0];
45 if let Ok(expression) = term.extract::<PyExpression>() {
46 return Ok(expression);
47 }
48 if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
49 return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
50 }
51 return Err(PyTypeError::new_err(
52 "Item is neither a PyExpression nor a PyAmplitudeID",
53 ));
54 }
55 let mut iter = terms.iter();
56 let Some(first_term) = iter.next() else {
57 return Ok(PyExpression(Expression::Zero));
58 };
59 if let Ok(first_expression) = first_term.extract::<PyExpression>() {
60 let mut summation = first_expression.clone();
61 for term in iter {
62 summation = summation.__add__(term)?;
63 }
64 return Ok(summation);
65 }
66 if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
67 let mut summation = PyExpression(Expression::Amp(first_amplitude_id.0));
68 for term in iter {
69 summation = summation.__add__(term)?;
70 }
71 return Ok(summation);
72 }
73 Err(PyTypeError::new_err(
74 "Elements must be PyExpression or PyAmplitudeID",
75 ))
76}
77
78#[pyfunction(name = "amplitude_product")]
81pub fn py_amplitude_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
82 if terms.is_empty() {
83 return Ok(PyExpression(Expression::One));
84 }
85 if terms.len() == 1 {
86 let term = &terms[0];
87 if let Ok(expression) = term.extract::<PyExpression>() {
88 return Ok(expression);
89 }
90 if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
91 return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
92 }
93 return Err(PyTypeError::new_err(
94 "Item is neither a PyExpression nor a PyAmplitudeID",
95 ));
96 }
97 let mut iter = terms.iter();
98 let Some(first_term) = iter.next() else {
99 return Ok(PyExpression(Expression::One));
100 };
101 if let Ok(first_expression) = first_term.extract::<PyExpression>() {
102 let mut product = first_expression.clone();
103 for term in iter {
104 product = product.__mul__(term)?;
105 }
106 return Ok(product);
107 }
108 if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
109 let mut product = PyExpression(Expression::Amp(first_amplitude_id.0));
110 for term in iter {
111 product = product.__mul__(term)?;
112 }
113 return Ok(product);
114 }
115 Err(PyTypeError::new_err(
116 "Elements must be PyExpression or PyAmplitudeID",
117 ))
118}
119
120#[pyfunction(name = "AmplitudeZero")]
123pub fn py_amplitude_zero() -> PyExpression {
124 PyExpression(Expression::Zero)
125}
126
127#[pyfunction(name = "AmplitudeOne")]
130pub fn py_amplitude_one() -> PyExpression {
131 PyExpression(Expression::One)
132}
133
134#[pymethods]
135impl PyAmplitudeID {
136 fn real(&self) -> PyExpression {
144 PyExpression(self.0.real())
145 }
146 fn imag(&self) -> PyExpression {
154 PyExpression(self.0.imag())
155 }
156 fn norm_sqr(&self) -> PyExpression {
166 PyExpression(self.0.norm_sqr())
167 }
168 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
169 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
170 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
171 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
172 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
173 } else if let Ok(other_int) = other.extract::<usize>() {
174 if other_int == 0 {
175 Ok(PyExpression(Expression::Amp(self.0.clone())))
176 } else {
177 Err(PyTypeError::new_err(
178 "Addition with an integer for this type is only defined for 0",
179 ))
180 }
181 } else {
182 Err(PyTypeError::new_err("Unsupported operand type for +"))
183 }
184 }
185 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
186 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
187 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
188 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
189 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
190 } else if let Ok(other_int) = other.extract::<usize>() {
191 if other_int == 0 {
192 Ok(PyExpression(Expression::Amp(self.0.clone())))
193 } else {
194 Err(PyTypeError::new_err(
195 "Addition with an integer for this type is only defined for 0",
196 ))
197 }
198 } else {
199 Err(PyTypeError::new_err("Unsupported operand type for +"))
200 }
201 }
202 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
203 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
204 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
205 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
206 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
207 } else {
208 Err(PyTypeError::new_err("Unsupported operand type for *"))
209 }
210 }
211 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
212 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
213 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
214 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
215 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
216 } else {
217 Err(PyTypeError::new_err("Unsupported operand type for *"))
218 }
219 }
220 fn __str__(&self) -> String {
221 format!("{}", self.0)
222 }
223 fn __repr__(&self) -> String {
224 format!("{:?}", self.0)
225 }
226}
227
228#[pymethods]
229impl PyExpression {
230 fn real(&self) -> PyExpression {
238 PyExpression(self.0.real())
239 }
240 fn imag(&self) -> PyExpression {
248 PyExpression(self.0.imag())
249 }
250 fn norm_sqr(&self) -> PyExpression {
260 PyExpression(self.0.norm_sqr())
261 }
262 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
263 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
264 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
265 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
266 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
267 } else if let Ok(other_int) = other.extract::<usize>() {
268 if other_int == 0 {
269 Ok(PyExpression(self.0.clone()))
270 } else {
271 Err(PyTypeError::new_err(
272 "Addition with an integer for this type is only defined for 0",
273 ))
274 }
275 } else {
276 Err(PyTypeError::new_err("Unsupported operand type for +"))
277 }
278 }
279 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
280 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
281 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
282 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
283 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
284 } else if let Ok(other_int) = other.extract::<usize>() {
285 if other_int == 0 {
286 Ok(PyExpression(self.0.clone()))
287 } else {
288 Err(PyTypeError::new_err(
289 "Addition with an integer for this type is only defined for 0",
290 ))
291 }
292 } else {
293 Err(PyTypeError::new_err("Unsupported operand type for +"))
294 }
295 }
296 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
297 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
298 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
299 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
300 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
301 } else {
302 Err(PyTypeError::new_err("Unsupported operand type for *"))
303 }
304 }
305 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
306 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
307 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
308 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
309 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
310 } else {
311 Err(PyTypeError::new_err("Unsupported operand type for *"))
312 }
313 }
314 fn __str__(&self) -> String {
315 format!("{}", self.0)
316 }
317 fn __repr__(&self) -> String {
318 format!("{:?}", self.0)
319 }
320}
321
/// Python wrapper (`laddu.Manager`) around a core `Manager`, which registers `Amplitude`s
/// (yielding `AmplitudeID`s) and builds `Model`s from expressions over them.
#[pyclass(name = "Manager", module = "laddu")]
pub struct PyManager(Manager);
326
327#[pymethods]
328impl PyManager {
329 #[new]
330 fn new() -> Self {
331 Self(Manager::default())
332 }
333 #[getter]
341 fn parameters(&self) -> Vec<String> {
342 self.0.parameters()
343 }
344 fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
363 Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
364 }
365 fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
391 let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
392 Ok(expression.0)
393 } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
394 Ok(Expression::Amp(aid.0))
395 } else {
396 Err(PyTypeError::new_err(
397 "'expression' must either by an Expression or AmplitudeID",
398 ))
399 }?;
400 Ok(PyModel(self.0.model(&expression)))
401 }
402}
403
/// Python wrapper (`laddu.Model`) around a core `Model`, as produced by `Manager.model`.
#[pyclass(name = "Model", module = "laddu")]
pub struct PyModel(pub Model);
408
409#[pymethods]
410impl PyModel {
411 #[getter]
419 fn parameters(&self) -> Vec<String> {
420 self.0.parameters()
421 }
422 fn load(&self, dataset: &PyDataset) -> PyEvaluator {
443 PyEvaluator(self.0.load(&dataset.0))
444 }
445 fn save_as(&self, path: &str) -> PyResult<()> {
458 self.0.save_as(path)?;
459 Ok(())
460 }
461 #[staticmethod]
479 fn load_from(path: &str) -> PyResult<Self> {
480 Ok(PyModel(Model::load_from(path)?))
481 }
482 #[new]
483 fn new() -> Self {
484 PyModel(Model::create_null())
485 }
486 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
487 Ok(PyBytes::new(
488 py,
489 serialize(&self.0)
490 .map_err(LadduError::SerdeError)?
491 .as_slice(),
492 ))
493 }
494 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
495 *self = PyModel(deserialize(state.as_bytes()).map_err(LadduError::SerdeError)?);
496 Ok(())
497 }
498}
499
/// Python wrapper (`laddu.Amplitude`) around a boxed `Amplitude` trait object, to be passed
/// to `Manager.register`.
#[pyclass(name = "Amplitude", module = "laddu")]
pub struct PyAmplitude(pub Box<dyn Amplitude>);
508
/// Python wrapper (`laddu.Evaluator`) around a core `Evaluator`, as returned by `Model.load`.
#[pyclass(name = "Evaluator", module = "laddu")]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
518
519#[pymethods]
520impl PyEvaluator {
521 #[getter]
529 fn parameters(&self) -> Vec<String> {
530 self.0.parameters()
531 }
532 fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
547 if let Ok(string_arg) = arg.extract::<String>() {
548 self.0.activate(&string_arg)?;
549 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
550 let vec: Vec<String> = list_arg.extract()?;
551 self.0.activate_many(&vec)?;
552 } else {
553 return Err(PyTypeError::new_err(
554 "Argument must be either a string or a list of strings",
555 ));
556 }
557 Ok(())
558 }
559 fn activate_all(&self) {
562 self.0.activate_all();
563 }
564 fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
581 if let Ok(string_arg) = arg.extract::<String>() {
582 self.0.deactivate(&string_arg)?;
583 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
584 let vec: Vec<String> = list_arg.extract()?;
585 self.0.deactivate_many(&vec)?;
586 } else {
587 return Err(PyTypeError::new_err(
588 "Argument must be either a string or a list of strings",
589 ));
590 }
591 Ok(())
592 }
593 fn deactivate_all(&self) {
596 self.0.deactivate_all();
597 }
598 fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
615 if let Ok(string_arg) = arg.extract::<String>() {
616 self.0.isolate(&string_arg)?;
617 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
618 let vec: Vec<String> = list_arg.extract()?;
619 self.0.isolate_many(&vec)?;
620 } else {
621 return Err(PyTypeError::new_err(
622 "Argument must be either a string or a list of strings",
623 ));
624 }
625 Ok(())
626 }
627 #[pyo3(signature = (parameters, *, threads=None))]
647 fn evaluate<'py>(
648 &self,
649 py: Python<'py>,
650 parameters: Vec<Float>,
651 threads: Option<usize>,
652 ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
653 #[cfg(feature = "rayon")]
654 {
655 Ok(PyArray1::from_slice(
656 py,
657 &ThreadPoolBuilder::new()
658 .num_threads(threads.unwrap_or_else(num_cpus::get))
659 .build()
660 .map_err(LadduError::from)?
661 .install(|| self.0.evaluate(¶meters)),
662 ))
663 }
664 #[cfg(not(feature = "rayon"))]
665 {
666 Ok(PyArray1::from_slice(py, &self.0.evaluate(¶meters)))
667 }
668 }
669 #[pyo3(signature = (parameters, *, threads=None))]
690 fn evaluate_gradient<'py>(
691 &self,
692 py: Python<'py>,
693 parameters: Vec<Float>,
694 threads: Option<usize>,
695 ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
696 #[cfg(feature = "rayon")]
697 {
698 Ok(PyArray2::from_vec2(
699 py,
700 &ThreadPoolBuilder::new()
701 .num_threads(threads.unwrap_or_else(num_cpus::get))
702 .build()
703 .map_err(LadduError::from)?
704 .install(|| {
705 self.0
706 .evaluate_gradient(¶meters)
707 .iter()
708 .map(|grad| grad.data.as_vec().to_vec())
709 .collect::<Vec<Vec<Complex<Float>>>>()
710 }),
711 )
712 .map_err(LadduError::NumpyError)?)
713 }
714 #[cfg(not(feature = "rayon"))]
715 {
716 Ok(PyArray2::from_vec2(
717 py,
718 &self
719 .0
720 .evaluate_gradient(¶meters)
721 .iter()
722 .map(|grad| grad.data.as_vec().to_vec())
723 .collect::<Vec<Vec<Complex<Float>>>>(),
724 )
725 .map_err(LadduError::NumpyError)?)
726 }
727 }
728}
729
/// Python wrapper (`laddu.ParameterLike`) around a core `ParameterLike`, created by the
/// `parameter` and `constant` functions.
#[pyclass(name = "ParameterLike", module = "laddu")]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
741
742#[pyfunction(name = "parameter")]
759pub fn py_parameter(name: &str) -> PyParameterLike {
760 PyParameterLike(parameter(name))
761}
762
763#[pyfunction(name = "constant")]
776pub fn py_constant(value: Float) -> PyParameterLike {
777 PyParameterLike(constant(value))
778}