use crate::data::PyDataset;
use laddu_core::{
amplitudes::{constant, parameter, Evaluator, Expression, ParameterLike, TestAmplitude},
f64, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
};
use num::complex::Complex64;
use numpy::{PyArray1, PyArray2};
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{PyBytes, PyList},
};
use std::collections::HashMap;
/// Run `op` inside the process-wide [`ThreadPoolManager`].
///
/// `threads` is forwarded verbatim to `ThreadPoolManager::install`; `None` presumably means
/// "use the manager's default pool size" — TODO confirm against `laddu_core`.
///
/// Any error reported by `ThreadPoolManager::install` is returned unchanged.
fn install_with_threads<R: Send>(
    threads: Option<usize>,
    op: impl FnOnce() -> R + Send,
) -> LadduResult<R> {
    let manager = ThreadPoolManager::shared();
    manager.install(threads, op)
}
/// Python-facing newtype around a `laddu_core` [`Expression`], exposed to Python as
/// `laddu.Expression`. `Clone` lets pyo3 extract it by value from Python arguments.
#[pyclass(name = "Expression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
/// Add together a list of `Expression`s (exposed to Python as `expr_sum`).
///
/// An empty list yields `Expression::zero()`, a single element is returned unchanged, and
/// longer lists are summed left to right.
///
/// # Errors
///
/// Raises `TypeError` if any element is not an `Expression`.
#[pyfunction(name = "expr_sum")]
pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
    let mut iter = terms.iter();
    // Emptiness is handled exactly once here; the old separate `is_empty` check made the
    // `else` arm of this `let … else` unreachable.
    let Some(first_term) = iter.next() else {
        return Ok(PyExpression(Expression::zero()));
    };
    // A lone term is passed straight through and keeps its own error message.
    if terms.len() == 1 {
        return first_term
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Item is not a PyExpression"));
    }
    let PyExpression(first) = first_term
        .extract::<PyExpression>()
        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
    // Stop at the first non-expression element, as the original loop did.
    let summation = iter.try_fold(first, |acc, term| {
        term.extract::<PyExpression>()
            .map(|PyExpression(expr)| acc + expr)
            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))
    })?;
    Ok(PyExpression(summation))
}
/// Multiply together a list of `Expression`s (exposed to Python as `expr_product`).
///
/// An empty list yields `Expression::one()`, a single element is returned unchanged, and
/// longer lists are multiplied left to right.
///
/// # Errors
///
/// Raises `TypeError` if any element is not an `Expression`.
#[pyfunction(name = "expr_product")]
pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
    let mut iter = terms.iter();
    // Emptiness is handled exactly once here; the old separate `is_empty` check made the
    // `else` arm of this `let … else` unreachable.
    let Some(first_term) = iter.next() else {
        return Ok(PyExpression(Expression::one()));
    };
    // A lone term is passed straight through and keeps its own error message.
    if terms.len() == 1 {
        return first_term
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Item is not a PyExpression"));
    }
    let PyExpression(first) = first_term
        .extract::<PyExpression>()
        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
    // Stop at the first non-expression element, as the original loop did.
    let product = iter.try_fold(first, |acc, term| {
        term.extract::<PyExpression>()
            .map(|PyExpression(expr)| acc * expr)
            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))
    })?;
    Ok(PyExpression(product))
}
/// Return `Expression::zero()`; the same value `expr_sum` produces for an empty list.
#[pyfunction(name = "Zero")]
pub fn py_expr_zero() -> PyExpression {
    let zero = Expression::zero();
    PyExpression(zero)
}
/// Return `Expression::one()`; the same value `expr_product` produces for an empty list.
#[pyfunction(name = "One")]
pub fn py_expr_one() -> PyExpression {
    let one = Expression::one();
    PyExpression(one)
}
#[pymethods]
impl PyExpression {
    /// Parameter names reported by `Expression::parameters`.
    #[getter]
    fn parameters(&self) -> Vec<String> {
        let Self(inner) = self;
        inner.parameters()
    }
    /// Parameter names reported by `Expression::free_parameters`.
    #[getter]
    fn free_parameters(&self) -> Vec<String> {
        let Self(inner) = self;
        inner.free_parameters()
    }
    /// Parameter names reported by `Expression::fixed_parameters`.
    #[getter]
    fn fixed_parameters(&self) -> Vec<String> {
        let Self(inner) = self;
        inner.fixed_parameters()
    }
    /// Count reported by `Expression::n_free`.
    #[getter]
    fn n_free(&self) -> usize {
        let Self(inner) = self;
        inner.n_free()
    }
    /// Count reported by `Expression::n_fixed`.
    #[getter]
    fn n_fixed(&self) -> usize {
        let Self(inner) = self;
        inner.n_fixed()
    }
    /// Count reported by `Expression::n_parameters`.
    #[getter]
    fn n_parameters(&self) -> usize {
        let Self(inner) = self;
        inner.n_parameters()
    }
    /// Bind this expression to `dataset`, producing an `Evaluator`.
    ///
    /// Errors from `Expression::load` are raised as Python exceptions.
    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
        let evaluator = self.0.load(&dataset.0)?;
        Ok(PyEvaluator(evaluator))
    }
    /// New expression built by `Expression::real`.
    fn real(&self) -> PyExpression {
        Self(self.0.real())
    }
    /// New expression built by `Expression::imag`.
    fn imag(&self) -> PyExpression {
        Self(self.0.imag())
    }
    /// New expression built by `Expression::conj`.
    fn conj(&self) -> PyExpression {
        Self(self.0.conj())
    }
    /// New expression built by `Expression::norm_sqr`.
    fn norm_sqr(&self) -> PyExpression {
        Self(self.0.norm_sqr())
    }
    /// Copy of this expression with parameter `name` fixed to `value`.
    fn fix(&self, name: &str, value: f64) -> PyResult<PyExpression> {
        let fixed = self.0.fix(name, value)?;
        Ok(Self(fixed))
    }
    /// Copy of this expression with parameter `name` freed.
    fn free(&self, name: &str) -> PyResult<PyExpression> {
        let freed = self.0.free(name)?;
        Ok(Self(freed))
    }
    /// Copy of this expression with parameter `old` renamed to `new`.
    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyExpression> {
        let renamed = self.0.rename_parameter(old, new)?;
        Ok(Self(renamed))
    }
    /// Copy of this expression with parameters renamed according to `mapping` (old -> new).
    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyExpression> {
        let renamed = self.0.rename_parameters(&mapping)?;
        Ok(Self(renamed))
    }
    /// `self + other`. `other` may be an `Expression` or the integer `0`; the latter lets
    /// Python's built-in `sum()` start from its default `0`.
    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        if let Ok(PyExpression(rhs)) = other.extract::<PyExpression>() {
            return Ok(Self(self.0.clone() + rhs));
        }
        match other.extract::<usize>() {
            Ok(0) => Ok(self.clone()),
            Ok(_) => Err(PyTypeError::new_err(
                "Addition with an integer for this type is only defined for 0",
            )),
            Err(_) => Err(PyTypeError::new_err("Unsupported operand type for +")),
        }
    }
    /// `other + self`, with the same `Expression`-or-`0` rule as `__add__`.
    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        if let Ok(PyExpression(lhs)) = other.extract::<PyExpression>() {
            return Ok(Self(lhs + self.0.clone()));
        }
        match other.extract::<usize>() {
            Ok(0) => Ok(self.clone()),
            Ok(_) => Err(PyTypeError::new_err(
                "Addition with an integer for this type is only defined for 0",
            )),
            Err(_) => Err(PyTypeError::new_err("Unsupported operand type for +")),
        }
    }
    /// `self - other`; `other` must be an `Expression`.
    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        let PyExpression(rhs) = other
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Unsupported operand type for -"))?;
        Ok(Self(self.0.clone() - rhs))
    }
    /// `other - self`; `other` must be an `Expression`.
    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        let PyExpression(lhs) = other
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Unsupported operand type for -"))?;
        Ok(Self(lhs - self.0.clone()))
    }
    /// `self * other`; `other` must be an `Expression`.
    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        let PyExpression(rhs) = other
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Unsupported operand type for *"))?;
        Ok(Self(self.0.clone() * rhs))
    }
    /// `other * self`; `other` must be an `Expression`.
    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        let PyExpression(lhs) = other
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Unsupported operand type for *"))?;
        Ok(Self(lhs * self.0.clone()))
    }
    /// `self / other`; `other` must be an `Expression`.
    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        let PyExpression(rhs) = other
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Unsupported operand type for /"))?;
        Ok(Self(self.0.clone() / rhs))
    }
    /// `other / self`; `other` must be an `Expression`.
    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        let PyExpression(lhs) = other
            .extract::<PyExpression>()
            .map_err(|_| PyTypeError::new_err("Unsupported operand type for /"))?;
        Ok(Self(lhs / self.0.clone()))
    }
    /// `-self`.
    fn __neg__(&self) -> PyExpression {
        Self(-self.0.clone())
    }
    /// `Display` rendering of the wrapped expression.
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    /// `Debug` rendering of the wrapped expression.
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
    /// Build a placeholder via `Expression::create_null`. Presumably this exists so that
    /// unpickling can allocate an instance before `__setstate__` fills it in — confirm.
    #[new]
    fn new() -> Self {
        Self(Expression::create_null())
    }
    /// Serialize the wrapped expression with `serde_pickle` for Python's pickle protocol.
    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let encoded = serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
            .map_err(LadduError::PickleError)?;
        Ok(PyBytes::new(py, &encoded))
    }
    /// Replace the wrapped expression with one decoded from `state` (see `__getstate__`).
    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
        let decoded: Expression =
            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
                .map_err(LadduError::PickleError)?;
        self.0 = decoded;
        Ok(())
    }
}
/// Python-facing newtype around a `laddu_core` [`Evaluator`], exposed to Python as
/// `laddu.Evaluator`. Instances are produced by `Expression.load(dataset)`.
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
#[pymethods]
impl PyEvaluator {
    /// Parameter names reported by `Evaluator::parameters`.
    #[getter]
    fn parameters(&self) -> Vec<String> {
        self.0.parameters()
    }
    /// Parameter names reported by `Evaluator::free_parameters`.
    #[getter]
    fn free_parameters(&self) -> Vec<String> {
        self.0.free_parameters()
    }
    /// Parameter names reported by `Evaluator::fixed_parameters`.
    #[getter]
    fn fixed_parameters(&self) -> Vec<String> {
        self.0.fixed_parameters()
    }
    /// Count reported by `Evaluator::n_free`.
    #[getter]
    fn n_free(&self) -> usize {
        self.0.n_free()
    }
    /// Count reported by `Evaluator::n_fixed`.
    #[getter]
    fn n_fixed(&self) -> usize {
        self.0.n_fixed()
    }
    /// Count reported by `Evaluator::n_parameters`.
    #[getter]
    fn n_parameters(&self) -> usize {
        self.0.n_parameters()
    }
    /// Copy of this evaluator with parameter `name` fixed to `value`.
    fn fix(&self, name: &str, value: f64) -> PyResult<PyEvaluator> {
        Ok(PyEvaluator(self.0.fix(name, value)?))
    }
    /// Copy of this evaluator with parameter `name` freed.
    fn free(&self, name: &str) -> PyResult<PyEvaluator> {
        Ok(PyEvaluator(self.0.free(name)?))
    }
    /// Copy of this evaluator with parameter `old` renamed to `new`.
    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<PyEvaluator> {
        Ok(PyEvaluator(self.0.rename_parameter(old, new)?))
    }
    /// Copy of this evaluator with parameters renamed according to `mapping` (old -> new).
    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<PyEvaluator> {
        Ok(PyEvaluator(self.0.rename_parameters(&mapping)?))
    }
    /// Activate the name(s) in `arg`, which must be a `str` or a `list` of `str`.
    ///
    /// With `strict=True` (the default) the `*_strict` core methods are used and their errors
    /// are raised. Those errors are presumably about unknown names; confirm in `laddu_core`.
    /// With `strict=False` the non-strict core methods are used and return no error.
    ///
    /// Raises `TypeError` if `arg` is neither a string nor a list. A list with non-string
    /// items raises pyo3's conversion error.
    #[pyo3(signature = (arg, *, strict=true))]
    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        if let Ok(string_arg) = arg.extract::<String>() {
            if strict {
                self.0.activate_strict(&string_arg)?;
            } else {
                self.0.activate(&string_arg);
            }
        } else if let Ok(list_arg) = arg.cast::<PyList>() {
            let vec: Vec<String> = list_arg.extract()?;
            if strict {
                self.0.activate_many_strict(&vec)?;
            } else {
                self.0.activate_many(&vec);
            }
        } else {
            return Err(PyTypeError::new_err(
                "Argument must be either a string or a list of strings",
            ));
        }
        Ok(())
    }
    /// Forward to `Evaluator::activate_all`.
    fn activate_all(&self) {
        self.0.activate_all();
    }
    /// Deactivate the name(s) in `arg`. The `str`/`list` and `strict` rules are the same as
    /// for `activate`.
    #[pyo3(signature = (arg, *, strict=true))]
    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        if let Ok(string_arg) = arg.extract::<String>() {
            if strict {
                self.0.deactivate_strict(&string_arg)?;
            } else {
                self.0.deactivate(&string_arg);
            }
        } else if let Ok(list_arg) = arg.cast::<PyList>() {
            let vec: Vec<String> = list_arg.extract()?;
            if strict {
                self.0.deactivate_many_strict(&vec)?;
            } else {
                self.0.deactivate_many(&vec);
            }
        } else {
            return Err(PyTypeError::new_err(
                "Argument must be either a string or a list of strings",
            ));
        }
        Ok(())
    }
    /// Forward to `Evaluator::deactivate_all`.
    fn deactivate_all(&self) {
        self.0.deactivate_all();
    }
    /// Isolate the name(s) in `arg` via `Evaluator::isolate*`. The `str`/`list` and `strict`
    /// rules are the same as for `activate`.
    #[pyo3(signature = (arg, *, strict=true))]
    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        if let Ok(string_arg) = arg.extract::<String>() {
            if strict {
                self.0.isolate_strict(&string_arg)?;
            } else {
                self.0.isolate(&string_arg);
            }
        } else if let Ok(list_arg) = arg.cast::<PyList>() {
            let vec: Vec<String> = list_arg.extract()?;
            if strict {
                self.0.isolate_many_strict(&vec)?;
            } else {
                self.0.isolate_many(&vec);
            }
        } else {
            return Err(PyTypeError::new_err(
                "Argument must be either a string or a list of strings",
            ));
        }
        Ok(())
    }
    /// Current mask reported by `Evaluator::active_mask`.
    #[getter]
    fn active_mask(&self) -> Vec<bool> {
        self.0.active_mask()
    }
    /// Replace the active mask. Errors from `Evaluator::set_active_mask` are raised.
    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
        self.0.set_active_mask(&mask)?;
        Ok(())
    }
    /// Evaluate at `parameters` and return a 1-D complex NumPy array.
    ///
    /// `threads` is forwarded to the shared thread-pool manager. Thread-pool errors are raised.
    #[pyo3(signature = (parameters, *, threads=None))]
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
        Ok(PyArray1::from_slice(py, &values))
    }
    /// Like `evaluate`, but only for the given `indices`. These are presumably dataset event
    /// indices; confirm.
    #[pyo3(signature = (parameters, indices, *, threads=None))]
    fn evaluate_batch<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        indices: Vec<usize>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
        let values =
            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
        Ok(PyArray1::from_slice(py, &values))
    }
    /// Evaluate the gradient at `parameters` and return a 2-D complex NumPy array.
    ///
    /// The array has one row per entry returned by `Evaluator::evaluate_gradient` (presumably
    /// one per event). Each row is that entry's gradient vector. Raises if numpy rejects the
    /// rows, e.g. when their lengths differ.
    #[pyo3(signature = (parameters, *, threads=None))]
    fn evaluate_gradient<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
        let gradients = install_with_threads(threads, || {
            self.0
                .evaluate_gradient(&parameters)
                .iter()
                .map(|grad| grad.data.as_vec().to_vec())
                .collect::<Vec<Vec<Complex64>>>()
        })?;
        Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
    }
    /// Like `evaluate_gradient`, but only for the given `indices`.
    #[pyo3(signature = (parameters, indices, *, threads=None))]
    fn evaluate_gradient_batch<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        indices: Vec<usize>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
        let gradients = install_with_threads(threads, || {
            self.0
                .evaluate_gradient_batch(&parameters, &indices)
                .iter()
                .map(|grad| grad.data.as_vec().to_vec())
                .collect::<Vec<Vec<Complex64>>>()
        })?;
        Ok(PyArray2::from_vec2(py, &gradients).map_err(LadduError::NumpyError)?)
    }
}
/// Python-facing newtype around a `laddu_core` [`ParameterLike`], exposed to Python as
/// `laddu.ParameterLike`. Built by the `parameter` and `constant` functions.
#[pyclass(name = "ParameterLike", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameterLike(pub ParameterLike);
/// Create a named parameter by wrapping `laddu_core::amplitudes::parameter(name)`.
#[pyfunction(name = "parameter")]
pub fn py_parameter(name: &str) -> PyParameterLike {
    let like = parameter(name);
    PyParameterLike(like)
}
/// Create a named constant with the given `value` by wrapping
/// `laddu_core::amplitudes::constant(name, value)`.
#[pyfunction(name = "constant")]
pub fn py_constant(name: &str, value: f64) -> PyParameterLike {
    let like = constant(name, value);
    PyParameterLike(like)
}
#[pyfunction(name = "TestAmplitude")]
pub fn py_test_amplitude(
name: &str,
re: PyParameterLike,
im: PyParameterLike,
) -> PyResult<PyExpression> {
Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
}