1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{
4 constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
5 ParameterLike,
6 },
7 traits::ReadWrite,
8 Complex, Float, LadduError,
9};
10use numpy::{PyArray1, PyArray2};
11use pyo3::{
12 exceptions::PyTypeError,
13 prelude::*,
14 types::{PyBytes, PyList},
15};
16#[cfg(feature = "rayon")]
17use rayon::ThreadPoolBuilder;
18
#[pyclass(name = "AmplitudeID", module = "laddu")]
/// Python handle to an [`AmplitudeID`], as returned by `Manager.register`.
///
/// Handles can be combined with each other and with `Expression`s through `+` and `*`,
/// or turned into expressions with `real`, `imag` and `norm_sqr`.
#[derive(Clone)]
pub struct PyAmplitudeID(AmplitudeID);
28
#[pyclass(name = "Expression", module = "laddu")]
/// Python wrapper around an [`Expression`], a symbolic combination of registered
/// amplitudes built with `+`, `*`, `real`, `imag` and `norm_sqr`.
///
/// Pass it to `Manager.model` to obtain a `Model`.
#[derive(Clone)]
pub struct PyExpression(Expression);
34
35#[pyfunction(name = "amplitude_sum")]
38pub fn py_amplitude_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
39 if terms.is_empty() {
40 return Ok(PyExpression(Expression::Zero));
41 }
42 if terms.len() == 1 {
43 let term = &terms[0];
44 if let Ok(expression) = term.extract::<PyExpression>() {
45 return Ok(expression);
46 }
47 if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
48 return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
49 }
50 return Err(PyTypeError::new_err(
51 "Item is neither a PyExpression nor a PyAmplitudeID",
52 ));
53 }
54 let mut iter = terms.iter();
55 let Some(first_term) = iter.next() else {
56 return Ok(PyExpression(Expression::Zero));
57 };
58 if let Ok(first_expression) = first_term.extract::<PyExpression>() {
59 let mut summation = first_expression.clone();
60 for term in iter {
61 summation = summation.__add__(term)?;
62 }
63 return Ok(summation);
64 }
65 if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
66 let mut summation = PyExpression(Expression::Amp(first_amplitude_id.0));
67 for term in iter {
68 summation = summation.__add__(term)?;
69 }
70 return Ok(summation);
71 }
72 Err(PyTypeError::new_err(
73 "Elements must be PyExpression or PyAmplitudeID",
74 ))
75}
76
77#[pyfunction(name = "amplitude_product")]
80pub fn py_amplitude_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
81 if terms.is_empty() {
82 return Ok(PyExpression(Expression::One));
83 }
84 if terms.len() == 1 {
85 let term = &terms[0];
86 if let Ok(expression) = term.extract::<PyExpression>() {
87 return Ok(expression);
88 }
89 if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
90 return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
91 }
92 return Err(PyTypeError::new_err(
93 "Item is neither a PyExpression nor a PyAmplitudeID",
94 ));
95 }
96 let mut iter = terms.iter();
97 let Some(first_term) = iter.next() else {
98 return Ok(PyExpression(Expression::One));
99 };
100 if let Ok(first_expression) = first_term.extract::<PyExpression>() {
101 let mut product = first_expression.clone();
102 for term in iter {
103 product = product.__mul__(term)?;
104 }
105 return Ok(product);
106 }
107 if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
108 let mut product = PyExpression(Expression::Amp(first_amplitude_id.0));
109 for term in iter {
110 product = product.__mul__(term)?;
111 }
112 return Ok(product);
113 }
114 Err(PyTypeError::new_err(
115 "Elements must be PyExpression or PyAmplitudeID",
116 ))
117}
118
119#[pyfunction(name = "AmplitudeZero")]
122pub fn py_amplitude_zero() -> PyExpression {
123 PyExpression(Expression::Zero)
124}
125
126#[pyfunction(name = "AmplitudeOne")]
129pub fn py_amplitude_one() -> PyExpression {
130 PyExpression(Expression::One)
131}
132
133#[pymethods]
134impl PyAmplitudeID {
135 fn real(&self) -> PyExpression {
143 PyExpression(self.0.real())
144 }
145 fn imag(&self) -> PyExpression {
153 PyExpression(self.0.imag())
154 }
155 fn norm_sqr(&self) -> PyExpression {
165 PyExpression(self.0.norm_sqr())
166 }
167 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
168 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
169 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
170 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
171 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
172 } else if let Ok(other_int) = other.extract::<usize>() {
173 if other_int == 0 {
174 Ok(PyExpression(Expression::Amp(self.0.clone())))
175 } else {
176 Err(PyTypeError::new_err(
177 "Addition with an integer for this type is only defined for 0",
178 ))
179 }
180 } else {
181 Err(PyTypeError::new_err("Unsupported operand type for +"))
182 }
183 }
184 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
185 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
186 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
187 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
188 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
189 } else if let Ok(other_int) = other.extract::<usize>() {
190 if other_int == 0 {
191 Ok(PyExpression(Expression::Amp(self.0.clone())))
192 } else {
193 Err(PyTypeError::new_err(
194 "Addition with an integer for this type is only defined for 0",
195 ))
196 }
197 } else {
198 Err(PyTypeError::new_err("Unsupported operand type for +"))
199 }
200 }
201 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
202 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
203 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
204 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
205 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
206 } else {
207 Err(PyTypeError::new_err("Unsupported operand type for *"))
208 }
209 }
210 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
211 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
212 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
213 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
214 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
215 } else {
216 Err(PyTypeError::new_err("Unsupported operand type for *"))
217 }
218 }
219 fn __str__(&self) -> String {
220 format!("{}", self.0)
221 }
222 fn __repr__(&self) -> String {
223 format!("{:?}", self.0)
224 }
225}
226
227#[pymethods]
228impl PyExpression {
229 fn real(&self) -> PyExpression {
237 PyExpression(self.0.real())
238 }
239 fn imag(&self) -> PyExpression {
247 PyExpression(self.0.imag())
248 }
249 fn norm_sqr(&self) -> PyExpression {
259 PyExpression(self.0.norm_sqr())
260 }
261 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
262 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
263 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
264 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
265 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
266 } else if let Ok(other_int) = other.extract::<usize>() {
267 if other_int == 0 {
268 Ok(PyExpression(self.0.clone()))
269 } else {
270 Err(PyTypeError::new_err(
271 "Addition with an integer for this type is only defined for 0",
272 ))
273 }
274 } else {
275 Err(PyTypeError::new_err("Unsupported operand type for +"))
276 }
277 }
278 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
279 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
280 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
281 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
282 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
283 } else if let Ok(other_int) = other.extract::<usize>() {
284 if other_int == 0 {
285 Ok(PyExpression(self.0.clone()))
286 } else {
287 Err(PyTypeError::new_err(
288 "Addition with an integer for this type is only defined for 0",
289 ))
290 }
291 } else {
292 Err(PyTypeError::new_err("Unsupported operand type for +"))
293 }
294 }
295 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
296 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
297 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
298 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
299 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
300 } else {
301 Err(PyTypeError::new_err("Unsupported operand type for *"))
302 }
303 }
304 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
305 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
306 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
307 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
308 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
309 } else {
310 Err(PyTypeError::new_err("Unsupported operand type for *"))
311 }
312 }
313 fn __str__(&self) -> String {
314 format!("{}", self.0)
315 }
316 fn __repr__(&self) -> String {
317 format!("{:?}", self.0)
318 }
319}
320
#[pyclass(name = "Manager", module = "laddu")]
/// Python wrapper around a [`Manager`]: amplitudes are registered with `register`,
/// and a `Model` is built from an expression over them with `model`.
pub struct PyManager(Manager);
325
326#[pymethods]
327impl PyManager {
328 #[new]
329 fn new() -> Self {
330 Self(Manager::default())
331 }
332 #[getter]
340 fn parameters(&self) -> Vec<String> {
341 self.0.parameters()
342 }
343 fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
362 Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
363 }
364 fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
390 let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
391 Ok(expression.0)
392 } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
393 Ok(Expression::Amp(aid.0))
394 } else {
395 Err(PyTypeError::new_err(
396 "'expression' must either by an Expression or AmplitudeID",
397 ))
398 }?;
399 Ok(PyModel(self.0.model(&expression)))
400 }
401}
402
#[pyclass(name = "Model", module = "laddu")]
/// Python wrapper around a [`Model`]. A model can be bound to a dataset with `load`,
/// written to and read from disk (`save_as` / `load_from`), and pickled via bincode.
pub struct PyModel(pub Model);
407
408#[pymethods]
409impl PyModel {
410 #[getter]
418 fn parameters(&self) -> Vec<String> {
419 self.0.parameters()
420 }
421 fn load(&self, dataset: &PyDataset) -> PyEvaluator {
442 PyEvaluator(self.0.load(&dataset.0))
443 }
444 fn save_as(&self, path: &str) -> PyResult<()> {
457 self.0.save_as(path)?;
458 Ok(())
459 }
460 #[staticmethod]
478 fn load_from(path: &str) -> PyResult<Self> {
479 Ok(PyModel(Model::load_from(path)?))
480 }
481 #[new]
482 fn new() -> Self {
483 PyModel(Model::create_null())
484 }
485 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
486 Ok(PyBytes::new(
487 py,
488 bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
489 .map_err(LadduError::EncodeError)?
490 .as_slice(),
491 ))
492 }
493 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
494 *self = PyModel(
495 bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
496 .map_err(LadduError::DecodeError)?
497 .0,
498 );
499 Ok(())
500 }
501}
502
#[pyclass(name = "Amplitude", module = "laddu")]
/// Python wrapper around a boxed [`Amplitude`] trait object.
///
/// `Manager.register` clones the box, so one Python `Amplitude` may be registered
/// more than once.
pub struct PyAmplitude(pub Box<dyn Amplitude>);
511
#[pyclass(name = "Evaluator", module = "laddu")]
/// Python wrapper around an [`Evaluator`], obtained from `Model.load(dataset)`.
///
/// Amplitudes can be switched on and off by name (`activate`, `deactivate`, `isolate`)
/// before calling `evaluate` or `evaluate_gradient`.
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
521
522#[pymethods]
523impl PyEvaluator {
524 #[getter]
532 fn parameters(&self) -> Vec<String> {
533 self.0.parameters()
534 }
535 fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
550 if let Ok(string_arg) = arg.extract::<String>() {
551 self.0.activate(&string_arg)?;
552 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
553 let vec: Vec<String> = list_arg.extract()?;
554 self.0.activate_many(&vec)?;
555 } else {
556 return Err(PyTypeError::new_err(
557 "Argument must be either a string or a list of strings",
558 ));
559 }
560 Ok(())
561 }
562 fn activate_all(&self) {
565 self.0.activate_all();
566 }
567 fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
584 if let Ok(string_arg) = arg.extract::<String>() {
585 self.0.deactivate(&string_arg)?;
586 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
587 let vec: Vec<String> = list_arg.extract()?;
588 self.0.deactivate_many(&vec)?;
589 } else {
590 return Err(PyTypeError::new_err(
591 "Argument must be either a string or a list of strings",
592 ));
593 }
594 Ok(())
595 }
596 fn deactivate_all(&self) {
599 self.0.deactivate_all();
600 }
601 fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
618 if let Ok(string_arg) = arg.extract::<String>() {
619 self.0.isolate(&string_arg)?;
620 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
621 let vec: Vec<String> = list_arg.extract()?;
622 self.0.isolate_many(&vec)?;
623 } else {
624 return Err(PyTypeError::new_err(
625 "Argument must be either a string or a list of strings",
626 ));
627 }
628 Ok(())
629 }
630 #[pyo3(signature = (parameters, *, threads=None))]
650 fn evaluate<'py>(
651 &self,
652 py: Python<'py>,
653 parameters: Vec<Float>,
654 threads: Option<usize>,
655 ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
656 #[cfg(feature = "rayon")]
657 {
658 Ok(PyArray1::from_slice(
659 py,
660 &ThreadPoolBuilder::new()
661 .num_threads(threads.unwrap_or_else(num_cpus::get))
662 .build()
663 .map_err(LadduError::from)?
664 .install(|| self.0.evaluate(¶meters)),
665 ))
666 }
667 #[cfg(not(feature = "rayon"))]
668 {
669 Ok(PyArray1::from_slice(py, &self.0.evaluate(¶meters)))
670 }
671 }
672 #[pyo3(signature = (parameters, *, threads=None))]
693 fn evaluate_gradient<'py>(
694 &self,
695 py: Python<'py>,
696 parameters: Vec<Float>,
697 threads: Option<usize>,
698 ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
699 #[cfg(feature = "rayon")]
700 {
701 Ok(PyArray2::from_vec2(
702 py,
703 &ThreadPoolBuilder::new()
704 .num_threads(threads.unwrap_or_else(num_cpus::get))
705 .build()
706 .map_err(LadduError::from)?
707 .install(|| {
708 self.0
709 .evaluate_gradient(¶meters)
710 .iter()
711 .map(|grad| grad.data.as_vec().to_vec())
712 .collect::<Vec<Vec<Complex<Float>>>>()
713 }),
714 )
715 .map_err(LadduError::NumpyError)?)
716 }
717 #[cfg(not(feature = "rayon"))]
718 {
719 Ok(PyArray2::from_vec2(
720 py,
721 &self
722 .0
723 .evaluate_gradient(¶meters)
724 .iter()
725 .map(|grad| grad.data.as_vec().to_vec())
726 .collect::<Vec<Vec<Complex<Float>>>>(),
727 )
728 .map_err(LadduError::NumpyError)?)
729 }
730 }
731}
732
#[pyclass(name = "ParameterLike", module = "laddu")]
/// Python wrapper around a [`ParameterLike`], produced by the `parameter(name)` and
/// `constant(value)` functions.
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
744
745#[pyfunction(name = "parameter")]
762pub fn py_parameter(name: &str) -> PyParameterLike {
763 PyParameterLike(parameter(name))
764}
765
766#[pyfunction(name = "constant")]
779pub fn py_constant(value: Float) -> PyParameterLike {
780 PyParameterLike(constant(value))
781}