use crate::data::PyDataset;
use laddu_core::{
amplitudes::{Evaluator, Expression, Parameter, TestAmplitude},
CompiledExpression, LadduError, LadduResult, ReadWrite, ThreadPoolManager,
};
use num::complex::Complex64;
use numpy::{PyArray1, PyArray2};
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{PyBytes, PyList, PyTuple},
};
use std::collections::HashMap;
/// Runs `op` inside the shared laddu thread pool.
///
/// `threads` is forwarded unchanged to [`ThreadPoolManager::install`]; any error that
/// reports is returned, otherwise the closure's result is wrapped in `Ok`.
fn install_with_threads<R: Send>(
    threads: Option<usize>,
    op: impl FnOnce() -> R + Send,
) -> LadduResult<R> {
    let manager = ThreadPoolManager::shared();
    manager.install(threads, op)
}
/// Python-facing wrapper around a laddu [`Expression`].
///
/// `skip_from_py_object` opts out of PyO3's generated extraction so that the manual
/// [`FromPyObject`] implementation below (which also accepts plain numbers) is used.
#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);

impl<'py> FromPyObject<'_, 'py> for PyExpression {
    type Error = PyErr;

    /// Converts a Python object into an expression.
    ///
    /// Conversions are attempted in this order:
    /// 1. an existing `Expression` instance (its inner expression is cloned);
    /// 2. anything extractable as `f64`;
    /// 3. anything extractable as `Complex64`.
    ///
    /// # Errors
    /// Raises `TypeError("Failed to extract Expression")` when none of them succeed.
    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
        if let Ok(wrapped) = obj.cast::<PyExpression>() {
            return Ok(wrapped.borrow().clone());
        }
        if let Ok(real) = obj.extract::<f64>() {
            return Ok(Self(real.into()));
        }
        obj.extract::<Complex64>()
            .map(|complex| Self(complex.into()))
            .map_err(|_| PyTypeError::new_err("Failed to extract Expression"))
    }
}
#[pyfunction(name = "expr_sum")]
pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
if terms.is_empty() {
return Ok(PyExpression(Expression::zero()));
}
if terms.len() == 1 {
let term = &terms[0];
if let Ok(expression) = term.extract::<PyExpression>() {
return Ok(expression);
}
return Err(PyTypeError::new_err("Item is not a PyExpression"));
}
let mut iter = terms.iter();
let Some(first_term) = iter.next() else {
return Ok(PyExpression(Expression::zero()));
};
let PyExpression(mut summation) = first_term
.extract::<PyExpression>()
.map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
for term in iter {
let PyExpression(expr) = term
.extract::<PyExpression>()
.map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
summation = summation + expr;
}
Ok(PyExpression(summation))
}
#[pyfunction(name = "expr_product")]
pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
if terms.is_empty() {
return Ok(PyExpression(Expression::one()));
}
if terms.len() == 1 {
let term = &terms[0];
if let Ok(expression) = term.extract::<PyExpression>() {
return Ok(expression);
}
return Err(PyTypeError::new_err("Item is not a PyExpression"));
}
let mut iter = terms.iter();
let Some(first_term) = iter.next() else {
return Ok(PyExpression(Expression::one()));
};
let PyExpression(mut product) = first_term
.extract::<PyExpression>()
.map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
for term in iter {
let PyExpression(expr) = term
.extract::<PyExpression>()
.map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
product = product * expr;
}
Ok(PyExpression(product))
}
/// Returns `Expression::zero()`, the value `expr_sum` produces for an empty list.
#[pyfunction(name = "Zero")]
pub fn py_expr_zero() -> PyExpression {
    let zero = Expression::zero();
    PyExpression(zero)
}
/// Returns `Expression::one()`, the value `expr_product` produces for an empty list.
#[pyfunction(name = "One")]
pub fn py_expr_one() -> PyExpression {
    let one = Expression::one();
    PyExpression(one)
}
#[pymethods]
impl PyExpression {
    /// All parameters referenced by this expression, as a tuple.
    #[getter]
    fn parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        PyTuple::new(py, self.0.parameters())
    }
    /// The free (not fixed) parameters of this expression, as a tuple.
    #[getter]
    fn free_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        PyTuple::new(py, self.0.free_parameters())
    }
    /// The fixed parameters of this expression, as a tuple.
    #[getter]
    fn fixed_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        PyTuple::new(py, self.0.fixed_parameters())
    }
    /// Number of free parameters.
    #[getter]
    fn n_free(&self) -> usize {
        self.0.n_free()
    }
    /// Number of fixed parameters.
    #[getter]
    fn n_fixed(&self) -> usize {
        self.0.n_fixed()
    }
    /// Total number of parameters.
    #[getter]
    fn n_parameters(&self) -> usize {
        self.0.n_parameters()
    }
    /// Loads this expression against `dataset`, producing an `Evaluator`.
    ///
    /// # Errors
    /// Propagates any error returned by `Expression::load`.
    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
        Ok(PyEvaluator(self.0.load(&dataset.0)?))
    }
    /// New expression built with `Expression::real`.
    fn real(&self) -> PyExpression {
        PyExpression(self.0.real())
    }
    /// New expression built with `Expression::imag`.
    fn imag(&self) -> PyExpression {
        PyExpression(self.0.imag())
    }
    /// New expression built with `Expression::conj`.
    fn conj(&self) -> PyExpression {
        PyExpression(self.0.conj())
    }
    /// New expression built with `Expression::norm_sqr`.
    fn norm_sqr(&self) -> PyExpression {
        PyExpression(self.0.norm_sqr())
    }
    /// New expression built with `Expression::sqrt`.
    fn sqrt(&self) -> PyExpression {
        PyExpression(self.0.sqrt())
    }
    /// Raises this expression to `power`.
    ///
    /// * `int` → `Expression::powi`
    /// * `float` → `Expression::powf`
    /// * anything else convertible to an `Expression` (an expression or a complex number)
    ///   → `Expression::pow`
    ///
    /// Integers are tried first, so `2` uses `powi` rather than `powf`.
    ///
    /// # Errors
    /// Raises `TypeError` for any other argument type.
    fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
        if let Ok(value) = power.extract::<i32>() {
            Ok(PyExpression(self.0.powi(value)))
        } else if let Ok(value) = power.extract::<f64>() {
            Ok(PyExpression(self.0.powf(value)))
        } else if let Ok(expression) = power.extract::<PyExpression>() {
            Ok(PyExpression(self.0.pow(&expression.0)))
        } else {
            Err(PyTypeError::new_err(
                "power must be an int, float, or Expression",
            ))
        }
    }
    /// New expression built with `Expression::exp`.
    fn exp(&self) -> PyExpression {
        PyExpression(self.0.exp())
    }
    /// New expression built with `Expression::sin`.
    fn sin(&self) -> PyExpression {
        PyExpression(self.0.sin())
    }
    /// New expression built with `Expression::cos`.
    fn cos(&self) -> PyExpression {
        PyExpression(self.0.cos())
    }
    /// New expression built with `Expression::log`.
    fn log(&self) -> PyExpression {
        PyExpression(self.0.log())
    }
    /// New expression built with `Expression::cis`.
    fn cis(&self) -> PyExpression {
        PyExpression(self.0.cis())
    }
    /// Fixes the parameter called `name` to `value`.
    ///
    /// # Errors
    /// Propagates any error from `Expression::fix_parameter` (e.g. presumably an unknown name).
    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
        Ok(self.0.fix_parameter(name, value)?)
    }
    /// Frees the parameter called `name`.
    ///
    /// # Errors
    /// Propagates any error from `Expression::free_parameter`.
    fn free_parameter(&self, name: &str) -> PyResult<()> {
        Ok(self.0.free_parameter(name)?)
    }
    /// Renames the parameter `old` to `new`.
    ///
    /// # Errors
    /// Propagates any error from `Expression::rename_parameter`.
    fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
        Ok(self.0.rename_parameter(old, new)?)
    }
    /// Renames several parameters at once; `mapping` is presumably old name → new name.
    ///
    /// # Errors
    /// Propagates any error from `Expression::rename_parameters`.
    fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
        Ok(self.0.rename_parameters(&mapping)?)
    }
    /// The compiled form of this expression.
    #[getter]
    fn compiled_expression(&self) -> PyCompiledExpression {
        PyCompiledExpression(self.0.compiled_expression())
    }
    // Binary operators: `other` is extracted through `PyExpression`'s `FromPyObject`
    // (expressions, real and complex numbers). When that extraction fails, PyO3 returns
    // `NotImplemented` instead of raising, so Python can try the other operand's reflected
    // method and otherwise raises its standard `TypeError`.
    /// `self + other`.
    fn __add__(&self, other: PyExpression) -> PyExpression {
        PyExpression(self.0.clone() + other.0)
    }
    /// `other + self`.
    fn __radd__(&self, other: PyExpression) -> PyExpression {
        PyExpression(other.0 + self.0.clone())
    }
    /// `self - other`.
    fn __sub__(&self, other: PyExpression) -> PyExpression {
        PyExpression(self.0.clone() - other.0)
    }
    /// `other - self`.
    fn __rsub__(&self, other: PyExpression) -> PyExpression {
        PyExpression(other.0 - self.0.clone())
    }
    /// `self * other`.
    fn __mul__(&self, other: PyExpression) -> PyExpression {
        PyExpression(self.0.clone() * other.0)
    }
    /// `other * self`.
    fn __rmul__(&self, other: PyExpression) -> PyExpression {
        PyExpression(other.0 * self.0.clone())
    }
    /// `self / other`.
    fn __truediv__(&self, other: PyExpression) -> PyExpression {
        PyExpression(self.0.clone() / other.0)
    }
    /// `other / self`.
    fn __rtruediv__(&self, other: PyExpression) -> PyExpression {
        PyExpression(other.0 / self.0.clone())
    }
    /// `-self`.
    fn __neg__(&self) -> PyExpression {
        PyExpression(-self.0.clone())
    }
    /// Human-readable form (the expression's `Display` output).
    fn __str__(&self) -> String {
        format!("{}", self.0)
    }
    /// Debug form (the expression's `Debug` output).
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
    /// Creates a placeholder built by `Expression::create_null`.
    ///
    /// Presumably exists so unpickling can construct an instance before `__setstate__`
    /// restores the real expression.
    #[new]
    fn new() -> Self {
        Self(Expression::create_null())
    }
    /// Serializes the wrapped expression with `serde_pickle` for Python pickling.
    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        Ok(PyBytes::new(
            py,
            serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
                .map_err(LadduError::PickleError)?
                .as_slice(),
        ))
    }
    /// Replaces this object's expression with one deserialized from `state`.
    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
        *self = Self(
            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
                .map_err(LadduError::PickleError)?,
        );
        Ok(())
    }
}
/// Python-facing wrapper around a laddu [`Evaluator`] (produced by `Expression.load`).
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);

#[pymethods]
impl PyEvaluator {
    /// All parameters of the evaluator, as a tuple.
    #[getter]
    fn parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        PyTuple::new(py, self.0.parameters())
    }
    /// The free (not fixed) parameters, as a tuple.
    #[getter]
    fn free_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        PyTuple::new(py, self.0.free_parameters())
    }
    /// The fixed parameters, as a tuple.
    #[getter]
    fn fixed_parameters<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        PyTuple::new(py, self.0.fixed_parameters())
    }
    /// Number of free parameters.
    #[getter]
    fn n_free(&self) -> usize {
        self.0.n_free()
    }
    /// Number of fixed parameters.
    #[getter]
    fn n_fixed(&self) -> usize {
        self.0.n_fixed()
    }
    /// Total number of parameters.
    #[getter]
    fn n_parameters(&self) -> usize {
        self.0.n_parameters()
    }
    /// Fixes the parameter called `name` to `value`.
    ///
    /// # Errors
    /// Propagates any error from `Evaluator::fix_parameter`.
    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
        Ok(self.0.fix_parameter(name, value)?)
    }
    /// Frees the parameter called `name`.
    ///
    /// # Errors
    /// Propagates any error from `Evaluator::free_parameter`.
    fn free_parameter(&self, name: &str) -> PyResult<()> {
        Ok(self.0.free_parameter(name)?)
    }
    /// Renames the parameter `old` to `new`.
    ///
    /// # Errors
    /// Propagates any error from `Evaluator::rename_parameter`.
    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
        Ok(self.0.rename_parameter(old, new)?)
    }
    /// Renames several parameters at once; `mapping` is presumably old name → new name.
    ///
    /// # Errors
    /// Propagates any error from `Evaluator::rename_parameters`.
    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
        Ok(self.0.rename_parameters(&mapping)?)
    }
    /// Activates `arg` (a single name or a list of names).
    ///
    /// With `strict=True` (the default) the `activate*_strict` variants are used and their
    /// errors are raised; with `strict=False` the non-strict variants are used.
    ///
    /// # Errors
    /// Raises `TypeError` if `arg` is neither a string nor a list of strings.
    #[pyo3(signature = (arg, *, strict=true))]
    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        match (NameSelection::from_arg(arg)?, strict) {
            (NameSelection::One(name), true) => {
                self.0.activate_strict(&name)?;
            }
            (NameSelection::One(name), false) => {
                self.0.activate(&name);
            }
            (NameSelection::Many(names), true) => {
                self.0.activate_many_strict(&names)?;
            }
            (NameSelection::Many(names), false) => {
                self.0.activate_many(&names);
            }
        }
        Ok(())
    }
    /// Activates everything via `Evaluator::activate_all`.
    fn activate_all(&self) {
        self.0.activate_all();
    }
    /// Deactivates `arg` (a single name or a list of names).
    ///
    /// `strict` selects the `deactivate*_strict` variants exactly as in `activate`.
    ///
    /// # Errors
    /// Raises `TypeError` if `arg` is neither a string nor a list of strings.
    #[pyo3(signature = (arg, *, strict=true))]
    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        match (NameSelection::from_arg(arg)?, strict) {
            (NameSelection::One(name), true) => {
                self.0.deactivate_strict(&name)?;
            }
            (NameSelection::One(name), false) => {
                self.0.deactivate(&name);
            }
            (NameSelection::Many(names), true) => {
                self.0.deactivate_many_strict(&names)?;
            }
            (NameSelection::Many(names), false) => {
                self.0.deactivate_many(&names);
            }
        }
        Ok(())
    }
    /// Deactivates everything via `Evaluator::deactivate_all`.
    fn deactivate_all(&self) {
        self.0.deactivate_all();
    }
    /// Isolates `arg` (a single name or a list of names).
    ///
    /// `strict` selects the `isolate*_strict` variants exactly as in `activate`.
    ///
    /// # Errors
    /// Raises `TypeError` if `arg` is neither a string nor a list of strings.
    #[pyo3(signature = (arg, *, strict=true))]
    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        match (NameSelection::from_arg(arg)?, strict) {
            (NameSelection::One(name), true) => {
                self.0.isolate_strict(&name)?;
            }
            (NameSelection::One(name), false) => {
                self.0.isolate(&name);
            }
            (NameSelection::Many(names), true) => {
                self.0.isolate_many_strict(&names)?;
            }
            (NameSelection::Many(names), false) => {
                self.0.isolate_many(&names);
            }
        }
        Ok(())
    }
    /// Current activation mask as returned by `Evaluator::active_mask`.
    #[getter]
    fn active_mask(&self) -> Vec<bool> {
        self.0.active_mask()
    }
    /// Replaces the activation mask.
    ///
    /// # Errors
    /// Propagates any error from `Evaluator::set_active_mask` (presumably a length mismatch).
    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
        self.0.set_active_mask(&mask)?;
        Ok(())
    }
    /// The compiled form of the underlying expression.
    #[getter]
    fn compiled_expression(&self) -> PyCompiledExpression {
        PyCompiledExpression(self.0.compiled_expression())
    }
    /// The expression returned by `Evaluator::expression`.
    #[getter]
    fn expression(&self) -> PyExpression {
        PyExpression(self.0.expression())
    }
    /// Evaluates at `parameters`, returning a 1-D complex NumPy array.
    ///
    /// `threads` is forwarded to the shared thread-pool manager.
    ///
    /// # Errors
    /// Propagates thread-pool and evaluation errors.
    #[pyo3(signature = (parameters, *, threads=None))]
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
        Ok(PyArray1::from_slice(py, &values?))
    }
    /// Like `evaluate`, restricted to `indices` (presumably event indices).
    ///
    /// # Errors
    /// Propagates thread-pool and evaluation errors.
    #[pyo3(signature = (parameters, indices, *, threads=None))]
    fn evaluate_batch<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        indices: Vec<usize>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
        let values =
            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
        Ok(PyArray1::from_slice(py, &values?))
    }
    /// Evaluates the gradient at `parameters`.
    ///
    /// Returns a 2-D complex array with one row per gradient vector returned by
    /// `Evaluator::evaluate_gradient`.
    ///
    /// # Errors
    /// Propagates thread-pool and evaluation errors, and NumPy errors if rows are ragged.
    #[pyo3(signature = (parameters, *, threads=None))]
    fn evaluate_gradient<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
        let gradients: LadduResult<_> = install_with_threads(threads, || {
            Ok(self
                .0
                .evaluate_gradient(&parameters)?
                .iter()
                .map(|grad| grad.data.as_vec().to_vec())
                .collect::<Vec<Vec<Complex64>>>())
        })?;
        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
    }
    /// Like `evaluate_gradient`, restricted to `indices`.
    ///
    /// # Errors
    /// Propagates thread-pool and evaluation errors, and NumPy errors if rows are ragged.
    #[pyo3(signature = (parameters, indices, *, threads=None))]
    fn evaluate_gradient_batch<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        indices: Vec<usize>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
        let gradients: LadduResult<_> = install_with_threads(threads, || {
            Ok(self
                .0
                .evaluate_gradient_batch(&parameters, &indices)?
                .iter()
                .map(|grad| grad.data.as_vec().to_vec())
                .collect::<Vec<Vec<Complex64>>>())
        })?;
        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
    }
}

/// Argument accepted by `activate`/`deactivate`/`isolate`: one name or a list of names.
enum NameSelection {
    One(String),
    Many(Vec<String>),
}

impl NameSelection {
    /// Interprets `arg` as a `str` first, then as a `list` of `str`.
    ///
    /// # Errors
    /// Propagates the extraction error if `arg` is a list with non-string items; raises
    /// `TypeError` if `arg` is neither a string nor a list.
    fn from_arg(arg: &Bound<'_, PyAny>) -> PyResult<Self> {
        if let Ok(name) = arg.extract::<String>() {
            return Ok(Self::One(name));
        }
        if let Ok(list) = arg.cast::<PyList>() {
            let names: Vec<String> = list.extract()?;
            return Ok(Self::Many(names));
        }
        Err(PyTypeError::new_err(
            "Argument must be either a string or a list of strings",
        ))
    }
}
/// Python-facing wrapper around a laddu [`CompiledExpression`]; only exposes its text forms.
#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyCompiledExpression(pub CompiledExpression);
#[pymethods]
impl PyCompiledExpression {
    /// The compiled expression's `Display` output.
    fn __str__(&self) -> String {
        let Self(compiled) = self;
        format!("{compiled}")
    }
    /// The compiled expression's `Debug` output.
    fn __repr__(&self) -> String {
        let Self(compiled) = self;
        format!("{compiled:?}")
    }
}
/// Python-facing wrapper around a laddu [`Parameter`]; exposes its metadata read-only.
#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameter(pub Parameter);
#[pymethods]
impl PyParameter {
    /// The parameter's name.
    #[getter]
    fn name(&self) -> String {
        let Self(parameter) = self;
        parameter.name()
    }
    /// The fixed value, if any (presumably `None` while the parameter is free).
    #[getter]
    fn fixed(&self) -> Option<f64> {
        let Self(parameter) = self;
        parameter.fixed()
    }
    /// The initial value, if one was set.
    #[getter]
    fn initial(&self) -> Option<f64> {
        let Self(parameter) = self;
        parameter.initial()
    }
    /// The `(lower, upper)` bounds; either side may be `None`.
    #[getter]
    fn bounds(&self) -> (Option<f64>, Option<f64>) {
        let Self(parameter) = self;
        let (lower, upper) = parameter.bounds();
        (lower, upper)
    }
    /// Optional unit label.
    #[getter]
    fn unit(&self) -> Option<String> {
        let Self(parameter) = self;
        parameter.unit()
    }
    /// Optional LaTeX label.
    #[getter]
    fn latex(&self) -> Option<String> {
        let Self(parameter) = self;
        parameter.latex()
    }
    /// Optional free-text description.
    #[getter]
    fn description(&self) -> Option<String> {
        let Self(parameter) = self;
        parameter.description()
    }
}
/// Creates a named `Parameter` and applies the optional metadata.
///
/// Setters are applied in a fixed order: initial value, fixed value, bounds (always set,
/// `(None, None)` by default), then unit, LaTeX label and description.
#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
pub fn py_parameter(
    name: &str,
    fixed: Option<f64>,
    initial: Option<f64>,
    bounds: (Option<f64>, Option<f64>),
    unit: Option<&str>,
    latex: Option<&str>,
    description: Option<&str>,
) -> PyParameter {
    let parameter = Parameter::new(name);
    if let Some(start) = initial {
        parameter.set_initial(start);
    }
    // Only touch the fixed value when the caller supplied one.
    if fixed.is_some() {
        parameter.set_fixed_value(fixed);
    }
    let (lower, upper) = bounds;
    parameter.set_bounds(lower, upper);
    if let Some(unit_label) = unit {
        parameter.set_unit(unit_label);
    }
    if let Some(latex_label) = latex {
        parameter.set_latex(latex_label);
    }
    if let Some(text) = description {
        parameter.set_description(text);
    }
    PyParameter(parameter)
}
#[pyfunction(name = "TestAmplitude")]
pub fn py_test_amplitude(name: &str, re: PyParameter, im: PyParameter) -> PyResult<PyExpression> {
Ok(PyExpression(TestAmplitude::new(name, re.0, im.0)?))
}