1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
4 f64, LadduError, ReadWrite,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9 exceptions::PyTypeError,
10 prelude::*,
11 types::{PyBytes, PyList},
12};
13#[cfg(feature = "rayon")]
14use rayon::ThreadPoolBuilder;
15use std::collections::HashMap;
16
/// Python wrapper around a `laddu_core` [`Expression`].
///
/// Exposed to Python as `laddu.Expression`. The wrapped `Expression` is a public
/// tuple field, so Rust code in this crate can reach it as `.0`.
#[pyclass(name = "Expression", module = "laddu")]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
22
23#[pyfunction(name = "expr_sum")]
26pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
27 if terms.is_empty() {
28 return Ok(PyExpression(Expression::zero()));
29 }
30 if terms.len() == 1 {
31 let term = &terms[0];
32 if let Ok(expression) = term.extract::<PyExpression>() {
33 return Ok(expression);
34 }
35 return Err(PyTypeError::new_err("Item is not a PyExpression"));
36 }
37 let mut iter = terms.iter();
38 let Some(first_term) = iter.next() else {
39 return Ok(PyExpression(Expression::zero()));
40 };
41 let PyExpression(mut summation) = first_term
42 .extract::<PyExpression>()
43 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
44 for term in iter {
45 let PyExpression(expr) = term
46 .extract::<PyExpression>()
47 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
48 summation = summation + expr;
49 }
50 Ok(PyExpression(summation))
51}
52
53#[pyfunction(name = "expr_product")]
56pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
57 if terms.is_empty() {
58 return Ok(PyExpression(Expression::one()));
59 }
60 if terms.len() == 1 {
61 let term = &terms[0];
62 if let Ok(expression) = term.extract::<PyExpression>() {
63 return Ok(expression);
64 }
65 return Err(PyTypeError::new_err("Item is not a PyExpression"));
66 }
67 let mut iter = terms.iter();
68 let Some(first_term) = iter.next() else {
69 return Ok(PyExpression(Expression::one()));
70 };
71 let PyExpression(mut product) = first_term
72 .extract::<PyExpression>()
73 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
74 for term in iter {
75 let PyExpression(expr) = term
76 .extract::<PyExpression>()
77 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
78 product = product * expr;
79 }
80 Ok(PyExpression(product))
81}
82
83#[pyfunction(name = "Zero")]
86pub fn py_expr_zero() -> PyExpression {
87 PyExpression(Expression::zero())
88}
89
90#[pyfunction(name = "One")]
93pub fn py_expr_one() -> PyExpression {
94 PyExpression(Expression::one())
95}
96
97#[pymethods]
98impl PyExpression {
99 #[getter]
106 fn parameters(&self) -> Vec<String> {
107 self.0.parameters()
108 }
109 #[getter]
111 fn free_parameters(&self) -> Vec<String> {
112 self.0.free_parameters()
113 }
114 #[getter]
116 fn fixed_parameters(&self) -> Vec<String> {
117 self.0.fixed_parameters()
118 }
119 #[getter]
121 fn n_free(&self) -> usize {
122 self.0.n_free()
123 }
124 #[getter]
126 fn n_fixed(&self) -> usize {
127 self.0.n_fixed()
128 }
129 #[getter]
131 fn n_parameters(&self) -> usize {
132 self.0.n_parameters()
133 }
134 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
147 Ok(PyEvaluator(self.0.load(&dataset.0)?))
148 }
149 fn real(&self) -> PyExpression {
151 PyExpression(self.0.real())
152 }
153 fn imag(&self) -> PyExpression {
155 PyExpression(self.0.imag())
156 }
157 fn conj(&self) -> PyExpression {
159 PyExpression(self.0.conj())
160 }
161 fn norm_sqr(&self) -> PyExpression {
163 PyExpression(self.0.norm_sqr())
164 }
165 fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
167 Ok(PyExpression(self.0.fix(name, value)?))
168 }
169 fn free(&self, name: &str) -> PyResult<PyExpression> {
171 Ok(PyExpression(self.0.free(name)?))
172 }
173 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
175 Ok(PyExpression(self.0.rename_parameter(old, new)?))
176 }
177 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
179 Ok(PyExpression(self.0.rename_parameters(&mapping)?))
180 }
181 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
182 if let Ok(other_expr) = other.extract::<PyExpression>() {
183 Ok(PyExpression(self.0.clone() + other_expr.0))
184 } else if let Ok(other_int) = other.extract::<usize>() {
185 if other_int == 0 {
186 Ok(PyExpression(self.0.clone()))
187 } else {
188 Err(PyTypeError::new_err(
189 "Addition with an integer for this type is only defined for 0",
190 ))
191 }
192 } else {
193 Err(PyTypeError::new_err("Unsupported operand type for +"))
194 }
195 }
196 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
197 if let Ok(other_expr) = other.extract::<PyExpression>() {
198 Ok(PyExpression(other_expr.0 + self.0.clone()))
199 } else if let Ok(other_int) = other.extract::<usize>() {
200 if other_int == 0 {
201 Ok(PyExpression(self.0.clone()))
202 } else {
203 Err(PyTypeError::new_err(
204 "Addition with an integer for this type is only defined for 0",
205 ))
206 }
207 } else {
208 Err(PyTypeError::new_err("Unsupported operand type for +"))
209 }
210 }
211 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
212 if let Ok(other_expr) = other.extract::<PyExpression>() {
213 Ok(PyExpression(self.0.clone() - other_expr.0))
214 } else {
215 Err(PyTypeError::new_err("Unsupported operand type for -"))
216 }
217 }
218 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
219 if let Ok(other_expr) = other.extract::<PyExpression>() {
220 Ok(PyExpression(other_expr.0 - self.0.clone()))
221 } else {
222 Err(PyTypeError::new_err("Unsupported operand type for -"))
223 }
224 }
225 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
226 if let Ok(other_expr) = other.extract::<PyExpression>() {
227 Ok(PyExpression(self.0.clone() * other_expr.0))
228 } else {
229 Err(PyTypeError::new_err("Unsupported operand type for *"))
230 }
231 }
232 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
233 if let Ok(other_expr) = other.extract::<PyExpression>() {
234 Ok(PyExpression(other_expr.0 * self.0.clone()))
235 } else {
236 Err(PyTypeError::new_err("Unsupported operand type for *"))
237 }
238 }
239 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
240 if let Ok(other_expr) = other.extract::<PyExpression>() {
241 Ok(PyExpression(self.0.clone() / other_expr.0))
242 } else {
243 Err(PyTypeError::new_err("Unsupported operand type for /"))
244 }
245 }
246 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
247 if let Ok(other_expr) = other.extract::<PyExpression>() {
248 Ok(PyExpression(other_expr.0 / self.0.clone()))
249 } else {
250 Err(PyTypeError::new_err("Unsupported operand type for /"))
251 }
252 }
253 fn __neg__(&self) -> PyExpression {
254 PyExpression(-self.0.clone())
255 }
256 fn __str__(&self) -> String {
257 format!("{}", self.0)
258 }
259 fn __repr__(&self) -> String {
260 format!("{:?}", self.0)
261 }
262
263 #[new]
264 fn new() -> Self {
265 Self(Expression::create_null())
266 }
267 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
268 Ok(PyBytes::new(
269 py,
270 bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
271 .map_err(LadduError::EncodeError)?
272 .as_slice(),
273 ))
274 }
275 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
276 *self = Self(
277 bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
278 .map_err(LadduError::DecodeError)?
279 .0,
280 );
281 Ok(())
282 }
283}
284
/// Python wrapper around a `laddu_core` [`Evaluator`].
///
/// Exposed to Python as `laddu.Evaluator`. Instances are produced by
/// `Expression.load(dataset)`.
#[pyclass(name = "Evaluator", module = "laddu")]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
294
295#[pymethods]
296impl PyEvaluator {
297 #[getter]
305 fn parameters(&self) -> Vec<String> {
306 self.0.parameters()
307 }
308 #[getter]
310 fn free_parameters(&self) -> Vec<String> {
311 self.0.free_parameters()
312 }
313 #[getter]
315 fn fixed_parameters(&self) -> Vec<String> {
316 self.0.fixed_parameters()
317 }
318 #[getter]
320 fn n_free(&self) -> usize {
321 self.0.n_free()
322 }
323 #[getter]
325 fn n_fixed(&self) -> usize {
326 self.0.n_fixed()
327 }
328 #[getter]
330 fn n_parameters(&self) -> usize {
331 self.0.n_parameters()
332 }
333 fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
335 Ok(PyEvaluator(self.0.fix(name, value)?))
336 }
337 fn free(&self, name: &str) -> PyResult<PyEvaluator> {
339 Ok(PyEvaluator(self.0.free(name)?))
340 }
341 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
343 Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
344 }
345 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
347 Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
348 }
349 #[pyo3(signature = (arg, *, strict=true))]
366 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
367 if let Ok(string_arg) = arg.extract::<String>() {
368 if strict {
369 self.0.activate_strict(&string_arg)?;
370 } else {
371 self.0.activate(&string_arg);
372 }
373 } else if let Ok(list_arg) = arg.cast::<PyList>() {
374 let vec: Vec<String> = list_arg.extract()?;
375 if strict {
376 self.0.activate_many_strict(&vec)?;
377 } else {
378 self.0.activate_many(&vec);
379 }
380 } else {
381 return Err(PyTypeError::new_err(
382 "Argument must be either a string or a list of strings",
383 ));
384 }
385 Ok(())
386 }
387 fn activate_all(&self) {
390 self.0.activate_all();
391 }
392 #[pyo3(signature = (arg, *, strict=true))]
411 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
412 if let Ok(string_arg) = arg.extract::<String>() {
413 if strict {
414 self.0.deactivate_strict(&string_arg)?;
415 } else {
416 self.0.deactivate(&string_arg);
417 }
418 } else if let Ok(list_arg) = arg.cast::<PyList>() {
419 let vec: Vec<String> = list_arg.extract()?;
420 if strict {
421 self.0.deactivate_many_strict(&vec)?;
422 } else {
423 self.0.deactivate_many(&vec);
424 }
425 } else {
426 return Err(PyTypeError::new_err(
427 "Argument must be either a string or a list of strings",
428 ));
429 }
430 Ok(())
431 }
432 fn deactivate_all(&self) {
435 self.0.deactivate_all();
436 }
437 #[pyo3(signature = (arg, *, strict=true))]
456 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
457 if let Ok(string_arg) = arg.extract::<String>() {
458 if strict {
459 self.0.isolate_strict(&string_arg)?;
460 } else {
461 self.0.isolate(&string_arg);
462 }
463 } else if let Ok(list_arg) = arg.cast::<PyList>() {
464 let vec: Vec<String> = list_arg.extract()?;
465 if strict {
466 self.0.isolate_many_strict(&vec)?;
467 } else {
468 self.0.isolate_many(&vec);
469 }
470 } else {
471 return Err(PyTypeError::new_err(
472 "Argument must be either a string or a list of strings",
473 ));
474 }
475 Ok(())
476 }
477 #[pyo3(signature = (parameters, *, threads=None))]
497 #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
498 fn evaluate<'py>(
499 &self,
500 py: Python<'py>,
501 parameters: Vec<f64>,
502 threads: Option<usize>,
503 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
504 #[cfg(feature = "rayon")]
505 {
506 Ok(PyArray1::from_slice(
507 py,
508 &ThreadPoolBuilder::new()
509 .num_threads(threads.unwrap_or_else(num_cpus::get))
510 .build()
511 .map_err(LadduError::from)?
512 .install(|| self.0.evaluate(¶meters)),
513 ))
514 }
515 #[cfg(not(feature = "rayon"))]
516 {
517 Ok(PyArray1::from_slice(py, &self.0.evaluate(¶meters)))
518 }
519 }
520 #[pyo3(signature = (parameters, indices, *, threads=None))]
542 #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
543 fn evaluate_batch<'py>(
544 &self,
545 py: Python<'py>,
546 parameters: Vec<f64>,
547 indices: Vec<usize>,
548 threads: Option<usize>,
549 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
550 #[cfg(feature = "rayon")]
551 {
552 Ok(PyArray1::from_slice(
553 py,
554 &ThreadPoolBuilder::new()
555 .num_threads(threads.unwrap_or_else(num_cpus::get))
556 .build()
557 .map_err(LadduError::from)?
558 .install(|| self.0.evaluate_batch(¶meters, &indices)),
559 ))
560 }
561 #[cfg(not(feature = "rayon"))]
562 {
563 Ok(PyArray1::from_slice(
564 py,
565 &self.0.evaluate_batch(¶meters, &indices),
566 ))
567 }
568 }
569 #[pyo3(signature = (parameters, *, threads=None))]
590 #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
591 fn evaluate_gradient<'py>(
592 &self,
593 py: Python<'py>,
594 parameters: Vec<f64>,
595 threads: Option<usize>,
596 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
597 #[cfg(feature = "rayon")]
598 {
599 Ok(PyArray2::from_vec2(
600 py,
601 &ThreadPoolBuilder::new()
602 .num_threads(threads.unwrap_or_else(num_cpus::get))
603 .build()
604 .map_err(LadduError::from)?
605 .install(|| {
606 self.0
607 .evaluate_gradient(¶meters)
608 .iter()
609 .map(|grad| grad.data.as_vec().to_vec())
610 .collect::<Vec<Vec<Complex64>>>()
611 }),
612 )
613 .map_err(LadduError::NumpyError)?)
614 }
615 #[cfg(not(feature = "rayon"))]
616 {
617 Ok(PyArray2::from_vec2(
618 py,
619 &self
620 .0
621 .evaluate_gradient(¶meters)
622 .iter()
623 .map(|grad| grad.data.as_vec().to_vec())
624 .collect::<Vec<Vec<Complex64>>>(),
625 )
626 .map_err(LadduError::NumpyError)?)
627 }
628 }
629 #[pyo3(signature = (parameters, indices, *, threads=None))]
652 #[cfg_attr(not(feature = "rayon"), allow(unused_variables))]
653 fn evaluate_gradient_batch<'py>(
654 &self,
655 py: Python<'py>,
656 parameters: Vec<f64>,
657 indices: Vec<usize>,
658 threads: Option<usize>,
659 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
660 #[cfg(feature = "rayon")]
661 {
662 Ok(PyArray2::from_vec2(
663 py,
664 &ThreadPoolBuilder::new()
665 .num_threads(threads.unwrap_or_else(num_cpus::get))
666 .build()
667 .map_err(LadduError::from)?
668 .install(|| {
669 self.0
670 .evaluate_gradient_batch(¶meters, &indices)
671 .iter()
672 .map(|grad| grad.data.as_vec().to_vec())
673 .collect::<Vec<Vec<Complex64>>>()
674 }),
675 )
676 .map_err(LadduError::NumpyError)?)
677 }
678 #[cfg(not(feature = "rayon"))]
679 {
680 Ok(PyArray2::from_vec2(
681 py,
682 &self
683 .0
684 .evaluate_gradient_batch(¶meters, &indices)
685 .iter()
686 .map(|grad| grad.data.as_vec().to_vec())
687 .collect::<Vec<Vec<Complex64>>>(),
688 )
689 .map_err(LadduError::NumpyError)?)
690 }
691 }
692}
693
/// Python wrapper around a `laddu_core` [`ParameterLike`].
///
/// Exposed to Python as `laddu.ParameterLike`. Instances are built by the
/// `parameter` and `constant` functions in this module.
#[pyclass(name = "ParameterLike", module = "laddu")]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
705
706#[pyfunction(name = "parameter")]
723pub fn py_parameter(name: &str) -> PyParameterLike {
724 PyParameterLike(parameter(name))
725}
726
727#[pyfunction(name = "constant")]
742pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
743 PyParameterLike(constant(name, value))
744}
745
746#[pyfunction(name = "TestAmplitude")]
748pub fn py_test_amplitude(
749 name: &str,
750 re: PyParameterLike,
751 im: PyParameterLike,
752) -> PyResult<PyExpression> {
753 Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
754}