1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{
4 constant, parameter, Amplitude, AmplitudeID, Evaluator, Expression, Manager, Model,
5 ParameterLike,
6 },
7 traits::ReadWrite,
8 Complex, Float, LadduError,
9};
10use numpy::{PyArray1, PyArray2};
11use pyo3::{
12 exceptions::PyTypeError,
13 prelude::*,
14 types::{PyBytes, PyList},
15};
16#[cfg(feature = "rayon")]
17use rayon::ThreadPoolBuilder;
18
#[pyclass(name = "AmplitudeID", module = "laddu")]
/// Python handle to an amplitude registered with a `Manager`.
///
/// Instances are returned by `Manager.register` and can be combined with other
/// `AmplitudeID`s / `Expression`s through arithmetic operators to build an `Expression`.
#[derive(Clone)]
pub struct PyAmplitudeID(AmplitudeID);
28
#[pyclass(name = "Expression", module = "laddu")]
/// Python wrapper around an `Expression`: an arithmetic combination of registered
/// amplitudes (sums, differences, products, quotients, negation, real/imag/conj/norm_sqr).
///
/// Pass one to `Manager.model` to obtain a `Model`.
#[derive(Clone)]
pub struct PyExpression(Expression);
34
35#[pyfunction(name = "amplitude_sum")]
38pub fn py_amplitude_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
39 if terms.is_empty() {
40 return Ok(PyExpression(Expression::Zero));
41 }
42 if terms.len() == 1 {
43 let term = &terms[0];
44 if let Ok(expression) = term.extract::<PyExpression>() {
45 return Ok(expression);
46 }
47 if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
48 return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
49 }
50 return Err(PyTypeError::new_err(
51 "Item is neither a PyExpression nor a PyAmplitudeID",
52 ));
53 }
54 let mut iter = terms.iter();
55 let Some(first_term) = iter.next() else {
56 return Ok(PyExpression(Expression::Zero));
57 };
58 if let Ok(first_expression) = first_term.extract::<PyExpression>() {
59 let mut summation = first_expression.clone();
60 for term in iter {
61 summation = summation.__add__(term)?;
62 }
63 return Ok(summation);
64 }
65 if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
66 let mut summation = PyExpression(Expression::Amp(first_amplitude_id.0));
67 for term in iter {
68 summation = summation.__add__(term)?;
69 }
70 return Ok(summation);
71 }
72 Err(PyTypeError::new_err(
73 "Elements must be PyExpression or PyAmplitudeID",
74 ))
75}
76
77#[pyfunction(name = "amplitude_product")]
80pub fn py_amplitude_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
81 if terms.is_empty() {
82 return Ok(PyExpression(Expression::One));
83 }
84 if terms.len() == 1 {
85 let term = &terms[0];
86 if let Ok(expression) = term.extract::<PyExpression>() {
87 return Ok(expression);
88 }
89 if let Ok(py_amplitude_id) = term.extract::<PyAmplitudeID>() {
90 return Ok(PyExpression(Expression::Amp(py_amplitude_id.0)));
91 }
92 return Err(PyTypeError::new_err(
93 "Item is neither a PyExpression nor a PyAmplitudeID",
94 ));
95 }
96 let mut iter = terms.iter();
97 let Some(first_term) = iter.next() else {
98 return Ok(PyExpression(Expression::One));
99 };
100 if let Ok(first_expression) = first_term.extract::<PyExpression>() {
101 let mut product = first_expression.clone();
102 for term in iter {
103 product = product.__mul__(term)?;
104 }
105 return Ok(product);
106 }
107 if let Ok(first_amplitude_id) = first_term.extract::<PyAmplitudeID>() {
108 let mut product = PyExpression(Expression::Amp(first_amplitude_id.0));
109 for term in iter {
110 product = product.__mul__(term)?;
111 }
112 return Ok(product);
113 }
114 Err(PyTypeError::new_err(
115 "Elements must be PyExpression or PyAmplitudeID",
116 ))
117}
118
119#[pyfunction(name = "AmplitudeZero")]
122pub fn py_amplitude_zero() -> PyExpression {
123 PyExpression(Expression::Zero)
124}
125
126#[pyfunction(name = "AmplitudeOne")]
129pub fn py_amplitude_one() -> PyExpression {
130 PyExpression(Expression::One)
131}
132
133#[pymethods]
134impl PyAmplitudeID {
135 fn real(&self) -> PyExpression {
143 PyExpression(self.0.real())
144 }
145 fn imag(&self) -> PyExpression {
153 PyExpression(self.0.imag())
154 }
155 fn conj(&self) -> PyExpression {
163 PyExpression(self.0.conj())
164 }
165 fn norm_sqr(&self) -> PyExpression {
175 PyExpression(self.0.norm_sqr())
176 }
177 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
178 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
179 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
180 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
181 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
182 } else if let Ok(other_int) = other.extract::<usize>() {
183 if other_int == 0 {
184 Ok(PyExpression(Expression::Amp(self.0.clone())))
185 } else {
186 Err(PyTypeError::new_err(
187 "Addition with an integer for this type is only defined for 0",
188 ))
189 }
190 } else {
191 Err(PyTypeError::new_err("Unsupported operand type for +"))
192 }
193 }
194 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
195 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
196 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
197 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
198 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
199 } else if let Ok(other_int) = other.extract::<usize>() {
200 if other_int == 0 {
201 Ok(PyExpression(Expression::Amp(self.0.clone())))
202 } else {
203 Err(PyTypeError::new_err(
204 "Addition with an integer for this type is only defined for 0",
205 ))
206 }
207 } else {
208 Err(PyTypeError::new_err("Unsupported operand type for +"))
209 }
210 }
211 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
212 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
213 Ok(PyExpression(self.0.clone() - other_aid.0.clone()))
214 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
215 Ok(PyExpression(self.0.clone() - other_expr.0.clone()))
216 } else {
217 Err(PyTypeError::new_err("Unsupported operand type for -"))
218 }
219 }
220 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
221 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
222 Ok(PyExpression(other_aid.0.clone() - self.0.clone()))
223 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
224 Ok(PyExpression(other_expr.0.clone() - self.0.clone()))
225 } else {
226 Err(PyTypeError::new_err("Unsupported operand type for -"))
227 }
228 }
229 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
230 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
231 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
232 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
233 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
234 } else {
235 Err(PyTypeError::new_err("Unsupported operand type for *"))
236 }
237 }
238 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
239 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
240 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
241 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
242 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
243 } else {
244 Err(PyTypeError::new_err("Unsupported operand type for *"))
245 }
246 }
247 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
248 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
249 Ok(PyExpression(self.0.clone() / other_aid.0.clone()))
250 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
251 Ok(PyExpression(self.0.clone() / other_expr.0.clone()))
252 } else {
253 Err(PyTypeError::new_err("Unsupported operand type for /"))
254 }
255 }
256 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
257 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
258 Ok(PyExpression(other_aid.0.clone() / self.0.clone()))
259 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
260 Ok(PyExpression(other_expr.0.clone() / self.0.clone()))
261 } else {
262 Err(PyTypeError::new_err("Unsupported operand type for /"))
263 }
264 }
265 fn __neg__(&self) -> PyExpression {
266 PyExpression(-self.0.clone())
267 }
268 fn __str__(&self) -> String {
269 format!("{}", self.0)
270 }
271 fn __repr__(&self) -> String {
272 format!("{:?}", self.0)
273 }
274}
275
276#[pymethods]
277impl PyExpression {
278 fn real(&self) -> PyExpression {
286 PyExpression(self.0.real())
287 }
288 fn imag(&self) -> PyExpression {
296 PyExpression(self.0.imag())
297 }
298 fn conj(&self) -> PyExpression {
306 PyExpression(self.0.conj())
307 }
308 fn norm_sqr(&self) -> PyExpression {
318 PyExpression(self.0.norm_sqr())
319 }
320 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
321 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
322 Ok(PyExpression(self.0.clone() + other_aid.0.clone()))
323 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
324 Ok(PyExpression(self.0.clone() + other_expr.0.clone()))
325 } else if let Ok(other_int) = other.extract::<usize>() {
326 if other_int == 0 {
327 Ok(PyExpression(self.0.clone()))
328 } else {
329 Err(PyTypeError::new_err(
330 "Addition with an integer for this type is only defined for 0",
331 ))
332 }
333 } else {
334 Err(PyTypeError::new_err("Unsupported operand type for +"))
335 }
336 }
337 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
338 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
339 Ok(PyExpression(other_aid.0.clone() + self.0.clone()))
340 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
341 Ok(PyExpression(other_expr.0.clone() + self.0.clone()))
342 } else if let Ok(other_int) = other.extract::<usize>() {
343 if other_int == 0 {
344 Ok(PyExpression(self.0.clone()))
345 } else {
346 Err(PyTypeError::new_err(
347 "Addition with an integer for this type is only defined for 0",
348 ))
349 }
350 } else {
351 Err(PyTypeError::new_err("Unsupported operand type for +"))
352 }
353 }
354 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
355 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
356 Ok(PyExpression(self.0.clone() - other_aid.0.clone()))
357 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
358 Ok(PyExpression(self.0.clone() - other_expr.0.clone()))
359 } else {
360 Err(PyTypeError::new_err("Unsupported operand type for -"))
361 }
362 }
363 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
364 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
365 Ok(PyExpression(other_aid.0.clone() - self.0.clone()))
366 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
367 Ok(PyExpression(other_expr.0.clone() - self.0.clone()))
368 } else {
369 Err(PyTypeError::new_err("Unsupported operand type for -"))
370 }
371 }
372 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
373 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
374 Ok(PyExpression(self.0.clone() * other_aid.0.clone()))
375 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
376 Ok(PyExpression(self.0.clone() * other_expr.0.clone()))
377 } else {
378 Err(PyTypeError::new_err("Unsupported operand type for *"))
379 }
380 }
381 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
382 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
383 Ok(PyExpression(other_aid.0.clone() * self.0.clone()))
384 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
385 Ok(PyExpression(other_expr.0.clone() * self.0.clone()))
386 } else {
387 Err(PyTypeError::new_err("Unsupported operand type for *"))
388 }
389 }
390 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
391 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
392 Ok(PyExpression(self.0.clone() / other_aid.0.clone()))
393 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
394 Ok(PyExpression(self.0.clone() / other_expr.0.clone()))
395 } else {
396 Err(PyTypeError::new_err("Unsupported operand type for /"))
397 }
398 }
399 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
400 if let Ok(other_aid) = other.extract::<PyRef<PyAmplitudeID>>() {
401 Ok(PyExpression(other_aid.0.clone() / self.0.clone()))
402 } else if let Ok(other_expr) = other.extract::<PyExpression>() {
403 Ok(PyExpression(other_expr.0.clone() / self.0.clone()))
404 } else {
405 Err(PyTypeError::new_err("Unsupported operand type for /"))
406 }
407 }
408 fn __neg__(&self) -> PyExpression {
409 PyExpression(-self.0.clone())
410 }
411 fn __str__(&self) -> String {
412 format!("{}", self.0)
413 }
414 fn __repr__(&self) -> String {
415 format!("{:?}", self.0)
416 }
417}
418
#[pyclass(name = "Manager", module = "laddu")]
/// Python wrapper around `Manager`, which registers amplitudes (yielding
/// `AmplitudeID`s) and turns an `Expression` over them into a `Model`.
pub struct PyManager(Manager);
423
424#[pymethods]
425impl PyManager {
426 #[new]
427 fn new() -> Self {
428 Self(Manager::default())
429 }
430 #[getter]
438 fn parameters(&self) -> Vec<String> {
439 self.0.parameters()
440 }
441 fn register(&mut self, amplitude: &PyAmplitude) -> PyResult<PyAmplitudeID> {
460 Ok(PyAmplitudeID(self.0.register(amplitude.0.clone())?))
461 }
462 fn model(&self, expression: &Bound<'_, PyAny>) -> PyResult<PyModel> {
488 let expression = if let Ok(expression) = expression.extract::<PyExpression>() {
489 Ok(expression.0)
490 } else if let Ok(aid) = expression.extract::<PyAmplitudeID>() {
491 Ok(Expression::Amp(aid.0))
492 } else {
493 Err(PyTypeError::new_err(
494 "'expression' must either by an Expression or AmplitudeID",
495 ))
496 }?;
497 Ok(PyModel(self.0.model(&expression)))
498 }
499}
500
#[pyclass(name = "Model", module = "laddu")]
/// Python wrapper around `Model`: a `Manager`'s amplitudes plus an expression over them.
/// It can be loaded onto a `Dataset` to get an `Evaluator`, saved and loaded from disk,
/// and pickled.
pub struct PyModel(pub Model);
505
506#[pymethods]
507impl PyModel {
508 #[getter]
516 fn parameters(&self) -> Vec<String> {
517 self.0.parameters()
518 }
519 fn load(&self, dataset: &PyDataset) -> PyEvaluator {
540 PyEvaluator(self.0.load(&dataset.0))
541 }
542 fn save_as(&self, path: &str) -> PyResult<()> {
555 self.0.save_as(path)?;
556 Ok(())
557 }
558 #[staticmethod]
576 fn load_from(path: &str) -> PyResult<Self> {
577 Ok(PyModel(Model::load_from(path)?))
578 }
579 #[new]
580 fn new() -> Self {
581 PyModel(Model::create_null())
582 }
583 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
584 Ok(PyBytes::new(
585 py,
586 bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
587 .map_err(LadduError::EncodeError)?
588 .as_slice(),
589 ))
590 }
591 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
592 *self = PyModel(
593 bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
594 .map_err(LadduError::DecodeError)?
595 .0,
596 );
597 Ok(())
598 }
599}
600
#[pyclass(name = "Amplitude", module = "laddu")]
/// Type-erased amplitude, as accepted by `Manager.register`.
///
/// Holds a boxed `dyn Amplitude`. `Manager.register` clones the box, so one
/// `Amplitude` object may be registered more than once.
pub struct PyAmplitude(pub Box<dyn Amplitude>);
609
#[pyclass(name = "Evaluator", module = "laddu")]
/// Python wrapper around an `Evaluator`, obtained from `Model.load(dataset)`.
///
/// It evaluates the model (and its gradient) for given parameter values, and lets
/// named components be activated, deactivated, or isolated.
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
619
620#[pymethods]
621impl PyEvaluator {
622 #[getter]
630 fn parameters(&self) -> Vec<String> {
631 self.0.parameters()
632 }
633 fn activate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
648 if let Ok(string_arg) = arg.extract::<String>() {
649 self.0.activate(&string_arg)?;
650 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
651 let vec: Vec<String> = list_arg.extract()?;
652 self.0.activate_many(&vec)?;
653 } else {
654 return Err(PyTypeError::new_err(
655 "Argument must be either a string or a list of strings",
656 ));
657 }
658 Ok(())
659 }
660 fn activate_all(&self) {
663 self.0.activate_all();
664 }
665 fn deactivate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
682 if let Ok(string_arg) = arg.extract::<String>() {
683 self.0.deactivate(&string_arg)?;
684 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
685 let vec: Vec<String> = list_arg.extract()?;
686 self.0.deactivate_many(&vec)?;
687 } else {
688 return Err(PyTypeError::new_err(
689 "Argument must be either a string or a list of strings",
690 ));
691 }
692 Ok(())
693 }
694 fn deactivate_all(&self) {
697 self.0.deactivate_all();
698 }
699 fn isolate(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
716 if let Ok(string_arg) = arg.extract::<String>() {
717 self.0.isolate(&string_arg)?;
718 } else if let Ok(list_arg) = arg.downcast::<PyList>() {
719 let vec: Vec<String> = list_arg.extract()?;
720 self.0.isolate_many(&vec)?;
721 } else {
722 return Err(PyTypeError::new_err(
723 "Argument must be either a string or a list of strings",
724 ));
725 }
726 Ok(())
727 }
728 #[pyo3(signature = (parameters, *, threads=None))]
748 fn evaluate<'py>(
749 &self,
750 py: Python<'py>,
751 parameters: Vec<Float>,
752 threads: Option<usize>,
753 ) -> PyResult<Bound<'py, PyArray1<Complex<Float>>>> {
754 #[cfg(feature = "rayon")]
755 {
756 Ok(PyArray1::from_slice(
757 py,
758 &ThreadPoolBuilder::new()
759 .num_threads(threads.unwrap_or_else(num_cpus::get))
760 .build()
761 .map_err(LadduError::from)?
762 .install(|| self.0.evaluate(¶meters)),
763 ))
764 }
765 #[cfg(not(feature = "rayon"))]
766 {
767 Ok(PyArray1::from_slice(py, &self.0.evaluate(¶meters)))
768 }
769 }
770 #[pyo3(signature = (parameters, *, threads=None))]
791 fn evaluate_gradient<'py>(
792 &self,
793 py: Python<'py>,
794 parameters: Vec<Float>,
795 threads: Option<usize>,
796 ) -> PyResult<Bound<'py, PyArray2<Complex<Float>>>> {
797 #[cfg(feature = "rayon")]
798 {
799 Ok(PyArray2::from_vec2(
800 py,
801 &ThreadPoolBuilder::new()
802 .num_threads(threads.unwrap_or_else(num_cpus::get))
803 .build()
804 .map_err(LadduError::from)?
805 .install(|| {
806 self.0
807 .evaluate_gradient(¶meters)
808 .iter()
809 .map(|grad| grad.data.as_vec().to_vec())
810 .collect::<Vec<Vec<Complex<Float>>>>()
811 }),
812 )
813 .map_err(LadduError::NumpyError)?)
814 }
815 #[cfg(not(feature = "rayon"))]
816 {
817 Ok(PyArray2::from_vec2(
818 py,
819 &self
820 .0
821 .evaluate_gradient(¶meters)
822 .iter()
823 .map(|grad| grad.data.as_vec().to_vec())
824 .collect::<Vec<Vec<Complex<Float>>>>(),
825 )
826 .map_err(LadduError::NumpyError)?)
827 }
828 }
829}
830
#[pyclass(name = "ParameterLike", module = "laddu")]
/// Python wrapper around `ParameterLike`, as produced by `parameter(name)` and
/// `constant(value)`.
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
842
843#[pyfunction(name = "parameter")]
860pub fn py_parameter(name: &str) -> PyParameterLike {
861 PyParameterLike(parameter(name))
862}
863
864#[pyfunction(name = "constant")]
877pub fn py_constant(value: Float) -> PyParameterLike {
878 PyParameterLike(constant(value))
879}