1use crate::data::PyDataset;
2use laddu_core::{
3 amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
4 f64, LadduError, ReadWrite,
5};
6use num::complex::Complex64;
7use numpy::{PyArray1, PyArray2};
8use pyo3::{
9 exceptions::PyTypeError,
10 prelude::*,
11 types::{PyBytes, PyList},
12};
13#[cfg(feature = "rayon")]
14use rayon::ThreadPoolBuilder;
15use std::collections::HashMap;
16
#[pyclass(name = "Expression", module = "laddu")]
/// Python-facing wrapper around a laddu `Expression`, exposed to Python as `laddu.Expression`.
///
/// The inner `Expression` is public so other binding modules can unwrap it directly.
#[derive(Clone)]
pub struct PyExpression(pub Expression);
22
23#[pyfunction(name = "expr_sum")]
26pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
27 if terms.is_empty() {
28 return Ok(PyExpression(Expression::zero()));
29 }
30 if terms.len() == 1 {
31 let term = &terms[0];
32 if let Ok(expression) = term.extract::<PyExpression>() {
33 return Ok(expression);
34 }
35 return Err(PyTypeError::new_err("Item is not a PyExpression"));
36 }
37 let mut iter = terms.iter();
38 let Some(first_term) = iter.next() else {
39 return Ok(PyExpression(Expression::zero()));
40 };
41 let PyExpression(mut summation) = first_term
42 .extract::<PyExpression>()
43 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
44 for term in iter {
45 let PyExpression(expr) = term
46 .extract::<PyExpression>()
47 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
48 summation = summation + expr;
49 }
50 Ok(PyExpression(summation))
51}
52
53#[pyfunction(name = "expr_product")]
56pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
57 if terms.is_empty() {
58 return Ok(PyExpression(Expression::one()));
59 }
60 if terms.len() == 1 {
61 let term = &terms[0];
62 if let Ok(expression) = term.extract::<PyExpression>() {
63 return Ok(expression);
64 }
65 return Err(PyTypeError::new_err("Item is not a PyExpression"));
66 }
67 let mut iter = terms.iter();
68 let Some(first_term) = iter.next() else {
69 return Ok(PyExpression(Expression::one()));
70 };
71 let PyExpression(mut product) = first_term
72 .extract::<PyExpression>()
73 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
74 for term in iter {
75 let PyExpression(expr) = term
76 .extract::<PyExpression>()
77 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
78 product = product * expr;
79 }
80 Ok(PyExpression(product))
81}
82
83#[pyfunction(name = "Zero")]
86pub fn py_expr_zero() -> PyExpression {
87 PyExpression(Expression::zero())
88}
89
90#[pyfunction(name = "One")]
93pub fn py_expr_one() -> PyExpression {
94 PyExpression(Expression::one())
95}
96
97#[pymethods]
98impl PyExpression {
99 #[getter]
106 fn parameters(&self) -> Vec<String> {
107 self.0.parameters()
108 }
109 #[getter]
111 fn free_parameters(&self) -> Vec<String> {
112 self.0.free_parameters()
113 }
114 #[getter]
116 fn fixed_parameters(&self) -> Vec<String> {
117 self.0.fixed_parameters()
118 }
119 #[getter]
121 fn n_free(&self) -> usize {
122 self.0.n_free()
123 }
124 #[getter]
126 fn n_fixed(&self) -> usize {
127 self.0.n_fixed()
128 }
129 #[getter]
131 fn n_parameters(&self) -> usize {
132 self.0.n_parameters()
133 }
134 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
147 Ok(PyEvaluator(self.0.load(&dataset.0)?))
148 }
149 fn real(&self) -> PyExpression {
151 PyExpression(self.0.real())
152 }
153 fn imag(&self) -> PyExpression {
155 PyExpression(self.0.imag())
156 }
157 fn conj(&self) -> PyExpression {
159 PyExpression(self.0.conj())
160 }
161 fn norm_sqr(&self) -> PyExpression {
163 PyExpression(self.0.norm_sqr())
164 }
165 fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
167 Ok(PyExpression(self.0.fix(name, value)?))
168 }
169 fn free(&self, name: &str) -> PyResult<PyExpression> {
171 Ok(PyExpression(self.0.free(name)?))
172 }
173 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
175 Ok(PyExpression(self.0.rename_parameter(old, new)?))
176 }
177 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
179 Ok(PyExpression(self.0.rename_parameters(&mapping)?))
180 }
181 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
182 if let Ok(other_expr) = other.extract::<PyExpression>() {
183 Ok(PyExpression(self.0.clone() + other_expr.0))
184 } else if let Ok(other_int) = other.extract::<usize>() {
185 if other_int == 0 {
186 Ok(PyExpression(self.0.clone()))
187 } else {
188 Err(PyTypeError::new_err(
189 "Addition with an integer for this type is only defined for 0",
190 ))
191 }
192 } else {
193 Err(PyTypeError::new_err("Unsupported operand type for +"))
194 }
195 }
196 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
197 if let Ok(other_expr) = other.extract::<PyExpression>() {
198 Ok(PyExpression(other_expr.0 + self.0.clone()))
199 } else if let Ok(other_int) = other.extract::<usize>() {
200 if other_int == 0 {
201 Ok(PyExpression(self.0.clone()))
202 } else {
203 Err(PyTypeError::new_err(
204 "Addition with an integer for this type is only defined for 0",
205 ))
206 }
207 } else {
208 Err(PyTypeError::new_err("Unsupported operand type for +"))
209 }
210 }
211 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
212 if let Ok(other_expr) = other.extract::<PyExpression>() {
213 Ok(PyExpression(self.0.clone() - other_expr.0))
214 } else {
215 Err(PyTypeError::new_err("Unsupported operand type for -"))
216 }
217 }
218 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
219 if let Ok(other_expr) = other.extract::<PyExpression>() {
220 Ok(PyExpression(other_expr.0 - self.0.clone()))
221 } else {
222 Err(PyTypeError::new_err("Unsupported operand type for -"))
223 }
224 }
225 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
226 if let Ok(other_expr) = other.extract::<PyExpression>() {
227 Ok(PyExpression(self.0.clone() * other_expr.0))
228 } else {
229 Err(PyTypeError::new_err("Unsupported operand type for *"))
230 }
231 }
232 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
233 if let Ok(other_expr) = other.extract::<PyExpression>() {
234 Ok(PyExpression(other_expr.0 * self.0.clone()))
235 } else {
236 Err(PyTypeError::new_err("Unsupported operand type for *"))
237 }
238 }
239 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
240 if let Ok(other_expr) = other.extract::<PyExpression>() {
241 Ok(PyExpression(self.0.clone() / other_expr.0))
242 } else {
243 Err(PyTypeError::new_err("Unsupported operand type for /"))
244 }
245 }
246 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
247 if let Ok(other_expr) = other.extract::<PyExpression>() {
248 Ok(PyExpression(other_expr.0 / self.0.clone()))
249 } else {
250 Err(PyTypeError::new_err("Unsupported operand type for /"))
251 }
252 }
253 fn __neg__(&self) -> PyExpression {
254 PyExpression(-self.0.clone())
255 }
256 fn __str__(&self) -> String {
257 format!("{}", self.0)
258 }
259 fn __repr__(&self) -> String {
260 format!("{:?}", self.0)
261 }
262
263 #[new]
264 fn new() -> Self {
265 Self(Expression::create_null())
266 }
267 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
268 Ok(PyBytes::new(
269 py,
270 bincode::serde::encode_to_vec(&self.0, bincode::config::standard())
271 .map_err(LadduError::EncodeError)?
272 .as_slice(),
273 ))
274 }
275 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
276 *self = Self(
277 bincode::serde::decode_from_slice(state.as_bytes(), bincode::config::standard())
278 .map_err(LadduError::DecodeError)?
279 .0,
280 );
281 Ok(())
282 }
283}
284
#[pyclass(name = "Evaluator", module = "laddu")]
/// Python-facing wrapper around a laddu `Evaluator`, exposed to Python as `laddu.Evaluator`.
///
/// Instances are produced by `Expression.load(dataset)`.
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
294
295#[pymethods]
296impl PyEvaluator {
297 #[getter]
305 fn parameters(&self) -> Vec<String> {
306 self.0.parameters()
307 }
308 #[getter]
310 fn free_parameters(&self) -> Vec<String> {
311 self.0.free_parameters()
312 }
313 #[getter]
315 fn fixed_parameters(&self) -> Vec<String> {
316 self.0.fixed_parameters()
317 }
318 #[getter]
320 fn n_free(&self) -> usize {
321 self.0.n_free()
322 }
323 #[getter]
325 fn n_fixed(&self) -> usize {
326 self.0.n_fixed()
327 }
328 #[getter]
330 fn n_parameters(&self) -> usize {
331 self.0.n_parameters()
332 }
333 fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
335 Ok(PyEvaluator(self.0.fix(name, value)?))
336 }
337 fn free(&self, name: &str) -> PyResult<PyEvaluator> {
339 Ok(PyEvaluator(self.0.free(name)?))
340 }
341 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
343 Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
344 }
345 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
347 Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
348 }
349 #[pyo3(signature = (arg, *, strict=true))]
366 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
367 if let Ok(string_arg) = arg.extract::<String>() {
368 if strict {
369 self.0.activate_strict(&string_arg)?;
370 } else {
371 self.0.activate(&string_arg);
372 }
373 } else if let Ok(list_arg) = arg.cast::<PyList>() {
374 let vec: Vec<String> = list_arg.extract()?;
375 if strict {
376 self.0.activate_many_strict(&vec)?;
377 } else {
378 self.0.activate_many(&vec);
379 }
380 } else {
381 return Err(PyTypeError::new_err(
382 "Argument must be either a string or a list of strings",
383 ));
384 }
385 Ok(())
386 }
387 fn activate_all(&self) {
390 self.0.activate_all();
391 }
392 #[pyo3(signature = (arg, *, strict=true))]
411 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
412 if let Ok(string_arg) = arg.extract::<String>() {
413 if strict {
414 self.0.deactivate_strict(&string_arg)?;
415 } else {
416 self.0.deactivate(&string_arg);
417 }
418 } else if let Ok(list_arg) = arg.cast::<PyList>() {
419 let vec: Vec<String> = list_arg.extract()?;
420 if strict {
421 self.0.deactivate_many_strict(&vec)?;
422 } else {
423 self.0.deactivate_many(&vec);
424 }
425 } else {
426 return Err(PyTypeError::new_err(
427 "Argument must be either a string or a list of strings",
428 ));
429 }
430 Ok(())
431 }
432 fn deactivate_all(&self) {
435 self.0.deactivate_all();
436 }
437 #[pyo3(signature = (arg, *, strict=true))]
456 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
457 if let Ok(string_arg) = arg.extract::<String>() {
458 if strict {
459 self.0.isolate_strict(&string_arg)?;
460 } else {
461 self.0.isolate(&string_arg);
462 }
463 } else if let Ok(list_arg) = arg.cast::<PyList>() {
464 let vec: Vec<String> = list_arg.extract()?;
465 if strict {
466 self.0.isolate_many_strict(&vec)?;
467 } else {
468 self.0.isolate_many(&vec);
469 }
470 } else {
471 return Err(PyTypeError::new_err(
472 "Argument must be either a string or a list of strings",
473 ));
474 }
475 Ok(())
476 }
477 #[pyo3(signature = (parameters, *, threads=None))]
497 fn evaluate<'py>(
498 &self,
499 py: Python<'py>,
500 parameters: Vec<f64>,
501 threads: Option<usize>,
502 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
503 #[cfg(feature = "rayon")]
504 {
505 Ok(PyArray1::from_slice(
506 py,
507 &ThreadPoolBuilder::new()
508 .num_threads(threads.unwrap_or_else(num_cpus::get))
509 .build()
510 .map_err(LadduError::from)?
511 .install(|| self.0.evaluate(¶meters)),
512 ))
513 }
514 #[cfg(not(feature = "rayon"))]
515 {
516 Ok(PyArray1::from_slice(py, &self.0.evaluate(¶meters)))
517 }
518 }
519 #[pyo3(signature = (parameters, indices, *, threads=None))]
541 fn evaluate_batch<'py>(
542 &self,
543 py: Python<'py>,
544 parameters: Vec<f64>,
545 indices: Vec<usize>,
546 threads: Option<usize>,
547 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
548 #[cfg(feature = "rayon")]
549 {
550 Ok(PyArray1::from_slice(
551 py,
552 &ThreadPoolBuilder::new()
553 .num_threads(threads.unwrap_or_else(num_cpus::get))
554 .build()
555 .map_err(LadduError::from)?
556 .install(|| self.0.evaluate_batch(¶meters, &indices)),
557 ))
558 }
559 #[cfg(not(feature = "rayon"))]
560 {
561 Ok(PyArray1::from_slice(
562 py,
563 &self.0.evaluate_batch(¶meters, &indices),
564 ))
565 }
566 }
567 #[pyo3(signature = (parameters, *, threads=None))]
588 fn evaluate_gradient<'py>(
589 &self,
590 py: Python<'py>,
591 parameters: Vec<f64>,
592 threads: Option<usize>,
593 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
594 #[cfg(feature = "rayon")]
595 {
596 Ok(PyArray2::from_vec2(
597 py,
598 &ThreadPoolBuilder::new()
599 .num_threads(threads.unwrap_or_else(num_cpus::get))
600 .build()
601 .map_err(LadduError::from)?
602 .install(|| {
603 self.0
604 .evaluate_gradient(¶meters)
605 .iter()
606 .map(|grad| grad.data.as_vec().to_vec())
607 .collect::<Vec<Vec<Complex64>>>()
608 }),
609 )
610 .map_err(LadduError::NumpyError)?)
611 }
612 #[cfg(not(feature = "rayon"))]
613 {
614 Ok(PyArray2::from_vec2(
615 py,
616 &self
617 .0
618 .evaluate_gradient(¶meters)
619 .iter()
620 .map(|grad| grad.data.as_vec().to_vec())
621 .collect::<Vec<Vec<Complex64>>>(),
622 )
623 .map_err(LadduError::NumpyError)?)
624 }
625 }
626 #[pyo3(signature = (parameters, indices, *, threads=None))]
649 fn evaluate_gradient_batch<'py>(
650 &self,
651 py: Python<'py>,
652 parameters: Vec<f64>,
653 indices: Vec<usize>,
654 threads: Option<usize>,
655 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
656 #[cfg(feature = "rayon")]
657 {
658 Ok(PyArray2::from_vec2(
659 py,
660 &ThreadPoolBuilder::new()
661 .num_threads(threads.unwrap_or_else(num_cpus::get))
662 .build()
663 .map_err(LadduError::from)?
664 .install(|| {
665 self.0
666 .evaluate_gradient_batch(¶meters, &indices)
667 .iter()
668 .map(|grad| grad.data.as_vec().to_vec())
669 .collect::<Vec<Vec<Complex64>>>()
670 }),
671 )
672 .map_err(LadduError::NumpyError)?)
673 }
674 #[cfg(not(feature = "rayon"))]
675 {
676 Ok(PyArray2::from_vec2(
677 py,
678 &self
679 .0
680 .evaluate_gradient_batch(¶meters, &indices)
681 .iter()
682 .map(|grad| grad.data.as_vec().to_vec())
683 .collect::<Vec<Vec<Complex64>>>(),
684 )
685 .map_err(LadduError::NumpyError)?)
686 }
687 }
688}
689
#[pyclass(name = "ParameterLike", module = "laddu")]
/// Python-facing wrapper around a laddu `ParameterLike`, exposed to Python as
/// `laddu.ParameterLike`.
///
/// Instances are produced by `parameter(...)` and `constant(...)` and consumed by amplitude
/// constructors such as `TestAmplitude`.
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
701
702#[pyfunction(name = "parameter")]
719pub fn py_parameter(name: &str) -> PyParameterLike {
720 PyParameterLike(parameter(name))
721}
722
723#[pyfunction(name = "constant")]
738pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
739 PyParameterLike(constant(name, value))
740}
741
742#[pyfunction(name = "TestAmplitude")]
744pub fn py_test_amplitude(
745 name: &str,
746 re: PyParameterLike,
747 im: PyParameterLike,
748) -> PyResult<PyExpression> {
749 Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
750}