use std::{array, collections::HashMap};
use laddu_amplitudes::{
angular::{
BlattWeisskopf, ClebschGordan, PhotonHelicity, PhotonPolarization, PhotonSDME, PolPhase,
Wigner3j, WignerD, Ylm, Zlm,
},
kmatrix::{
KopfKMatrixA0, KopfKMatrixA0Channel, KopfKMatrixA2, KopfKMatrixA2Channel, KopfKMatrixF0,
KopfKMatrixF0Channel, KopfKMatrixF2, KopfKMatrixF2Channel, KopfKMatrixPi1,
KopfKMatrixPi1Channel, KopfKMatrixRho, KopfKMatrixRhoChannel,
},
lookup::{LookupAxis, LookupTable},
resonance::{BreitWigner, BreitWignerNonRelativistic, Flatte, PhaseSpaceFactor, Voigt},
scalar::{ComplexScalar, PolarComplexScalar, Scalar, VariableScalar},
};
use laddu_core::{
amplitude::{Evaluator, Expression, Parameter, ParameterMap, TestAmplitude},
math::{BarrierKind, Sheet, QR_DEFAULT},
traits::Variable,
CompiledExpression, LadduError, LadduResult, ThreadPoolManager,
};
use num::complex::Complex64;
use numpy::{PyArray1, PyArray2};
use pyo3::{
exceptions::{PyTypeError, PyValueError},
prelude::*,
types::{PyAny, PyBytes, PyIterator, PyList, PyTuple},
};
use crate::{
data::PyDataset,
quantum::angular_momentum::{
parse_angular_momentum, parse_orbital_angular_momentum, parse_projection,
},
variables::{PyAngles, PyDecay, PyMandelstam, PyMass, PyPolarization, PyVariable},
};
/// Boxed input variables paired with validated lookup axes, as produced by
/// `py_lookup_inputs` and consumed by the `LookupTable*` constructors.
type LookupInputs = (Vec<Box<dyn Variable>>, Vec<LookupAxis>);
/// Declares a Python-visible enum `$py_name` (exported to Python as `$python_name`) whose
/// variants mirror the listed variants of the Rust channel enum `$rust_name`, together with
/// a `From<$py_name> for $rust_name` conversion mapping each variant to its namesake.
macro_rules! py_kmatrix_channel {
    ($py_name:ident, $python_name:literal, $rust_name:path { $($variant:ident),+ $(,)? }) => {
        // `eq` asks PyO3 to expose Python equality, which relies on the `PartialEq` derive.
        #[pyclass(eq, name = $python_name, module = "laddu", from_py_object)]
        #[derive(Clone, PartialEq)]
        pub enum $py_name {
            $($variant,)+
        }
        // One-to-one variant mapping; adding a variant to the invocation keeps both sides in sync.
        impl From<$py_name> for $rust_name {
            fn from(value: $py_name) -> Self {
                match value {
                    $( $py_name::$variant => Self::$variant, )+
                }
            }
        }
    };
}
// Python mirrors of the Kopf K-matrix channel selectors. Each invocation lists the variants
// of the corresponding Rust enum in the same order.
py_kmatrix_channel!(
    PyKopfKMatrixA0Channel,
    "KopfKMatrixA0Channel",
    KopfKMatrixA0Channel { PiEta, KKbar }
);
py_kmatrix_channel!(
    PyKopfKMatrixA2Channel,
    "KopfKMatrixA2Channel",
    KopfKMatrixA2Channel {
        PiEta,
        KKbar,
        PiEtaPrime
    }
);
py_kmatrix_channel!(
    PyKopfKMatrixF0Channel,
    "KopfKMatrixF0Channel",
    KopfKMatrixF0Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta,
        EtaEtaPrime
    }
);
py_kmatrix_channel!(
    PyKopfKMatrixF2Channel,
    "KopfKMatrixF2Channel",
    KopfKMatrixF2Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta
    }
);
py_kmatrix_channel!(
    PyKopfKMatrixPi1Channel,
    "KopfKMatrixPi1Channel",
    KopfKMatrixPi1Channel { PiEta, PiEtaPrime }
);
py_kmatrix_channel!(
    PyKopfKMatrixRhoChannel,
    "KopfKMatrixRhoChannel",
    KopfKMatrixRhoChannel {
        PiPi,
        FourPi,
        KKbar
    }
);
/// Run `op` inside the process-wide shared thread pool.
///
/// `threads` is forwarded unchanged to `ThreadPoolManager::install`. The outer
/// `LadduResult` reports pool failures, while `R` carries whatever `op` itself returned.
fn install_with_threads<R: Send>(
    threads: Option<usize>,
    op: impl FnOnce() -> R + Send,
) -> LadduResult<R> {
    let pool = ThreadPoolManager::shared();
    pool.install(threads, op)
}
/// Convert the positional `*tags` tuple of a Python call into owned strings.
///
/// # Errors
/// Propagates the extraction error for the first element that is not a `str`.
pub(crate) fn py_tags(tags: &Bound<'_, PyTuple>) -> PyResult<Vec<String>> {
    let mut names = Vec::with_capacity(tags.len());
    for tag in tags.iter() {
        names.push(tag.extract::<String>()?);
    }
    Ok(names)
}
/// Python handle to an amplitude [`Expression`].
///
/// `skip_from_py_object` opts out of PyO3's generated extraction so that the hand-written
/// `FromPyObject` impl can also accept plain real and complex numbers.
#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
/// Lenient extraction: an existing `Expression` is cloned, and otherwise a Python number
/// (first tried as a real `float`, then as a `complex`) is lifted into a constant expression.
impl<'py> FromPyObject<'_, 'py> for PyExpression {
    type Error = PyErr;
    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
        if let Ok(existing) = obj.cast::<PyExpression>() {
            return Ok(existing.borrow().clone());
        }
        // Real numbers are checked before complex ones so that ints/floats stay real.
        if let Ok(real) = obj.extract::<f64>() {
            return Ok(Self(real.into()));
        }
        match obj.extract::<Complex64>() {
            Ok(complex) => Ok(Self(complex.into())),
            Err(_) => Err(PyTypeError::new_err("Failed to extract Expression")),
        }
    }
}
/// Sum a sequence of terms into a single `Expression`.
///
/// Each term is converted with `PyExpression`'s extraction rules, so `Expression`s and
/// real or complex numbers are all accepted. An empty sequence yields `Zero`. Otherwise the
/// terms are folded left-to-right starting from the first term, so a single term comes back
/// unchanged rather than as `0 + term`.
///
/// # Errors
/// Raises `TypeError` ("Elements must be PyExpression") for the first term that cannot be
/// converted. This is now the same message whether one term or many are supplied.
#[pyfunction(name = "expr_sum")]
pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
    let mut expressions = terms.iter().map(|term| {
        term.extract::<PyExpression>()
            .map(|PyExpression(expression)| expression)
            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))
    });
    let Some(first) = expressions.next() else {
        return Ok(PyExpression(Expression::zero()));
    };
    // Short-circuits on the first conversion error, like the original explicit loop.
    let total =
        expressions.try_fold(first?, |total, term| term.map(|expression| total + expression))?;
    Ok(PyExpression(total))
}
/// Multiply a sequence of terms into a single `Expression`.
///
/// Each term is converted with `PyExpression`'s extraction rules, so `Expression`s and
/// real or complex numbers are all accepted. An empty sequence yields `One`. Otherwise the
/// terms are folded left-to-right starting from the first term, so a single term comes back
/// unchanged rather than as `1 * term`.
///
/// # Errors
/// Raises `TypeError` ("Elements must be PyExpression") for the first term that cannot be
/// converted. This is now the same message whether one term or many are supplied.
#[pyfunction(name = "expr_product")]
pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
    let mut expressions = terms.iter().map(|term| {
        term.extract::<PyExpression>()
            .map(|PyExpression(expression)| expression)
            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))
    });
    let Some(first) = expressions.next() else {
        return Ok(PyExpression(Expression::one()));
    };
    // Short-circuits on the first conversion error, like the original explicit loop.
    let product = expressions
        .try_fold(first?, |product, term| term.map(|expression| product * expression))?;
    Ok(PyExpression(product))
}
/// The additive identity expression.
#[pyfunction(name = "Zero")]
pub fn py_expr_zero() -> PyExpression {
    let zero = Expression::zero();
    PyExpression(zero)
}
/// The multiplicative identity expression.
#[pyfunction(name = "One")]
pub fn py_expr_one() -> PyExpression {
    let one = Expression::one();
    PyExpression(one)
}
/// Build a `Scalar` amplitude whose value is the single parameter `value`.
#[pyfunction(name = "Scalar", signature = (*tags, value))]
pub fn py_scalar(tags: &Bound<'_, PyTuple>, value: PyParameter) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = Scalar::new(tags, value.0)?;
    Ok(PyExpression(expression))
}
/// Build a `VariableScalar` amplitude driven by a kinematic `variable`.
///
/// `variable` is extracted as a `PyVariable` before the tags are converted.
#[pyfunction(name = "VariableScalar", signature = (*tags, variable))]
pub fn py_variable_scalar(
    tags: &Bound<'_, PyTuple>,
    variable: Bound<'_, PyAny>,
) -> PyResult<PyExpression> {
    let source: PyVariable = variable.extract()?;
    let tags = py_tags(tags)?;
    let expression = VariableScalar::new(tags, &source)?;
    Ok(PyExpression(expression))
}
/// Build a `ComplexScalar` amplitude from real (`re`) and imaginary (`im`) parameters.
#[pyfunction(name = "ComplexScalar", signature = (*tags, re, im))]
pub fn py_complex_scalar(
    tags: &Bound<'_, PyTuple>,
    re: PyParameter,
    im: PyParameter,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let (real, imaginary) = (re.0, im.0);
    let expression = ComplexScalar::new(tags, real, imaginary)?;
    Ok(PyExpression(expression))
}
/// Build a `PolarComplexScalar` amplitude from magnitude (`r`) and phase (`theta`)
/// parameters.
#[pyfunction(name = "PolarComplexScalar", signature = (*tags, r, theta))]
pub fn py_polar_complex_scalar(
    tags: &Bound<'_, PyTuple>,
    r: PyParameter,
    theta: PyParameter,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let (magnitude, phase) = (r.0, theta.0);
    let expression = PolarComplexScalar::new(tags, magnitude, phase)?;
    Ok(PyExpression(expression))
}
/// Build a relativistic `BreitWigner` amplitude.
///
/// `l` is parsed as an orbital angular momentum. When `barrier_factors` is false, the
/// `new_without_barrier_factors` constructor is used instead of `new`.
#[pyfunction(name = "BreitWigner", signature = (*tags, mass, width, l, daughter_1_mass, daughter_2_mass, resonance_mass, barrier_factors=true))]
#[allow(clippy::too_many_arguments)]
pub fn py_breit_wigner(
    tags: &Bound<'_, PyTuple>,
    mass: PyParameter,
    width: PyParameter,
    l: &Bound<'_, PyAny>,
    daughter_1_mass: &PyMass,
    daughter_2_mass: &PyMass,
    resonance_mass: &PyMass,
    barrier_factors: bool,
) -> PyResult<PyExpression> {
    let orbital = parse_orbital_angular_momentum(l)?.value() as usize;
    let tags = py_tags(tags)?;
    let expression = if barrier_factors {
        BreitWigner::new(
            tags,
            mass.0,
            width.0,
            orbital,
            &daughter_1_mass.0,
            &daughter_2_mass.0,
            &resonance_mass.0,
        )?
    } else {
        BreitWigner::new_without_barrier_factors(
            tags,
            mass.0,
            width.0,
            orbital,
            &daughter_1_mass.0,
            &daughter_2_mass.0,
            &resonance_mass.0,
        )?
    };
    Ok(PyExpression(expression))
}
/// Build a `BreitWignerNonRelativistic` amplitude evaluated at `resonance_mass`.
#[pyfunction(name = "BreitWignerNonRelativistic", signature = (*tags, mass, width, resonance_mass))]
pub fn py_breit_wigner_non_relativistic(
    tags: &Bound<'_, PyTuple>,
    mass: PyParameter,
    width: PyParameter,
    resonance_mass: &PyMass,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression =
        BreitWignerNonRelativistic::new(tags, mass.0, width.0, &resonance_mass.0)?;
    Ok(PyExpression(expression))
}
/// Build a `Flatte` amplitude.
///
/// The observed channel's daughter masses are per-event `Mass` variables. The alternate
/// channel's daughter masses are plain numbers, forwarded as-is.
#[pyfunction(name = "Flatte", signature = (*tags, mass, observed_channel_coupling, alternate_channel_coupling, observed_channel_daughter_masses, alternate_channel_daughter_masses, resonance_mass))]
pub fn py_flatte(
    tags: &Bound<'_, PyTuple>,
    mass: PyParameter,
    observed_channel_coupling: PyParameter,
    alternate_channel_coupling: PyParameter,
    observed_channel_daughter_masses: (PyMass, PyMass),
    alternate_channel_daughter_masses: (f64, f64),
    resonance_mass: &PyMass,
) -> PyResult<PyExpression> {
    let (first_daughter, second_daughter) = &observed_channel_daughter_masses;
    let tags = py_tags(tags)?;
    let expression = Flatte::new(
        tags,
        mass.0,
        observed_channel_coupling.0,
        alternate_channel_coupling.0,
        (&first_daughter.0, &second_daughter.0),
        alternate_channel_daughter_masses,
        &resonance_mass.0,
    )?;
    Ok(PyExpression(expression))
}
/// Build a `Voigt` amplitude from `mass`, `width`, and `sigma` parameters, evaluated at
/// `resonance_mass`.
#[pyfunction(name = "Voigt", signature = (*tags, mass, width, sigma, resonance_mass))]
pub fn py_voigt(
    tags: &Bound<'_, PyTuple>,
    mass: PyParameter,
    width: PyParameter,
    sigma: PyParameter,
    resonance_mass: &PyMass,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = Voigt::new(tags, mass.0, width.0, sigma.0, &resonance_mass.0)?;
    Ok(PyExpression(expression))
}
/// Build a `Ylm` amplitude for integer `l`/`m` evaluated at `angles`.
#[pyfunction(name = "Ylm", signature = (*tags, l, m, angles))]
pub fn py_ylm(
    tags: &Bound<'_, PyTuple>,
    l: usize,
    m: isize,
    angles: &PyAngles,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = Ylm::new(tags, l, m, &angles.0)?;
    Ok(PyExpression(expression))
}
/// Build a `Zlm` amplitude.
///
/// `r` is parsed with `str::parse` into whatever reflectivity type `Zlm::new` expects. A
/// parse failure is raised after the tags are converted.
#[pyfunction(name = "Zlm", signature = (*tags, l, m, r, angles, polarization))]
pub fn py_zlm(
    tags: &Bound<'_, PyTuple>,
    l: usize,
    m: isize,
    r: &str,
    angles: &PyAngles,
    polarization: &PyPolarization,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = Zlm::new(tags, l, m, r.parse()?, &angles.0, &polarization.0)?;
    Ok(PyExpression(expression))
}
/// Build a `PolPhase` amplitude from the given beam `polarization` variable.
#[pyfunction(name = "PolPhase", signature = (*tags, polarization))]
pub fn py_polphase(
    tags: &Bound<'_, PyTuple>,
    polarization: &PyPolarization,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = PolPhase::new(tags, &polarization.0)?;
    Ok(PyExpression(expression))
}
/// Build a `WignerD` amplitude.
///
/// `spin` is parsed as an angular momentum and both projections with `parse_projection`.
/// Conversions happen in argument order, so the first invalid argument determines the error.
#[pyfunction(name = "WignerD", signature = (*tags, spin, row_projection, column_projection, angles))]
pub fn py_wigner_d(
    tags: &Bound<'_, PyTuple>,
    spin: &Bound<'_, PyAny>,
    row_projection: &Bound<'_, PyAny>,
    column_projection: &Bound<'_, PyAny>,
    angles: &PyAngles,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let j = parse_angular_momentum(spin)?;
    let row = parse_projection(row_projection)?;
    let column = parse_projection(column_projection)?;
    let expression = WignerD::new(tags, j, row, column, &angles.0)?;
    Ok(PyExpression(expression))
}
/// Build a `BlattWeisskopf` barrier-factor amplitude for `decay`.
///
/// `sheet` (`"physical"`/`"unphysical"`) and `kind` (`"full"`/`"tensor"`) are matched
/// ASCII case-insensitively and validated before anything else is converted.
///
/// # Errors
/// Raises `ValueError` for an unrecognised `sheet` or `kind`.
#[pyfunction(name = "BlattWeisskopf", signature = (*tags, decay, l, reference_mass, q_r = QR_DEFAULT, sheet = "physical", kind = "full"))]
pub fn py_blatt_weisskopf(
    tags: &Bound<'_, PyTuple>,
    decay: &PyDecay,
    l: &Bound<'_, PyAny>,
    reference_mass: f64,
    q_r: f64,
    sheet: &str,
    kind: &str,
) -> PyResult<PyExpression> {
    let sheet = if sheet.eq_ignore_ascii_case("physical") {
        Sheet::Physical
    } else if sheet.eq_ignore_ascii_case("unphysical") {
        Sheet::Unphysical
    } else {
        return Err(PyValueError::new_err(
            "sheet must be 'physical' or 'unphysical'",
        ));
    };
    let kind = if kind.eq_ignore_ascii_case("full") {
        BarrierKind::Full
    } else if kind.eq_ignore_ascii_case("tensor") {
        BarrierKind::Tensor
    } else {
        return Err(PyValueError::new_err("kind must be 'full' or 'tensor'"));
    };
    let tags = py_tags(tags)?;
    let orbital = parse_orbital_angular_momentum(l)?;
    let expression = BlattWeisskopf::new(
        tags,
        &decay.0,
        orbital,
        reference_mass,
        q_r,
        sheet,
        kind,
    )?;
    Ok(PyExpression(expression))
}
/// Build a constant `ClebschGordan` coefficient expression `<j1 m1; j2 m2 | j m>`.
///
/// The `j*` arguments are parsed as angular momenta and the `m*` arguments as projections,
/// in argument order.
#[pyfunction(name = "ClebschGordan", signature = (*tags, j1, m1, j2, m2, j, m))]
pub fn py_clebsch_gordan(
    tags: &Bound<'_, PyTuple>,
    j1: &Bound<'_, PyAny>,
    m1: &Bound<'_, PyAny>,
    j2: &Bound<'_, PyAny>,
    m2: &Bound<'_, PyAny>,
    j: &Bound<'_, PyAny>,
    m: &Bound<'_, PyAny>,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let first_j = parse_angular_momentum(j1)?;
    let first_m = parse_projection(m1)?;
    let second_j = parse_angular_momentum(j2)?;
    let second_m = parse_projection(m2)?;
    let total_j = parse_angular_momentum(j)?;
    let total_m = parse_projection(m)?;
    let expression = ClebschGordan::new(
        tags, first_j, first_m, second_j, second_m, total_j, total_m,
    )?;
    Ok(PyExpression(expression))
}
/// Build a constant `Wigner3j` symbol expression.
///
/// The `j*` arguments are parsed as angular momenta and the `m*` arguments as projections,
/// in argument order.
#[pyfunction(name = "Wigner3j", signature = (*tags, j1, m1, j2, m2, j3, m3))]
pub fn py_wigner_3j(
    tags: &Bound<'_, PyTuple>,
    j1: &Bound<'_, PyAny>,
    m1: &Bound<'_, PyAny>,
    j2: &Bound<'_, PyAny>,
    m2: &Bound<'_, PyAny>,
    j3: &Bound<'_, PyAny>,
    m3: &Bound<'_, PyAny>,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let first_j = parse_angular_momentum(j1)?;
    let first_m = parse_projection(m1)?;
    let second_j = parse_angular_momentum(j2)?;
    let second_m = parse_projection(m2)?;
    let third_j = parse_angular_momentum(j3)?;
    let third_m = parse_projection(m3)?;
    let expression = Wigner3j::new(
        tags, first_j, first_m, second_j, second_m, third_j, third_m,
    )?;
    Ok(PyExpression(expression))
}
/// Build a `PhotonSDME` amplitude for the helicity pair (`helicity`, `helicity_prime`).
///
/// A supplied `polarization` variable selects linear polarization; omitting it selects the
/// unpolarized case.
#[pyfunction(name = "PhotonSDME", signature = (*tags, helicity, helicity_prime, polarization = None))]
pub fn py_photon_sdme(
    tags: &Bound<'_, PyTuple>,
    helicity: i32,
    helicity_prime: i32,
    polarization: Option<&PyPolarization>,
) -> PyResult<PyExpression> {
    let beam = match polarization {
        Some(linear) => PhotonPolarization::Linear(Box::new(linear.0.clone())),
        None => PhotonPolarization::Unpolarized,
    };
    let tags = py_tags(tags)?;
    let lambda = PhotonHelicity::new(helicity)?;
    let lambda_prime = PhotonHelicity::new(helicity_prime)?;
    let expression = PhotonSDME::new(tags, beam, lambda, lambda_prime)?;
    Ok(PyExpression(expression))
}
/// Build a `PhaseSpaceFactor` amplitude from the listed mass and Mandelstam-s variables.
#[pyfunction(name = "PhaseSpaceFactor", signature = (*tags, recoil_mass, daughter_1_mass, daughter_2_mass, resonance_mass, mandelstam_s))]
pub fn py_phase_space_factor(
    tags: &Bound<'_, PyTuple>,
    recoil_mass: &PyMass,
    daughter_1_mass: &PyMass,
    daughter_2_mass: &PyMass,
    resonance_mass: &PyMass,
    mandelstam_s: &PyMandelstam,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = PhaseSpaceFactor::new(
        tags,
        &recoil_mass.0,
        &daughter_1_mass.0,
        &daughter_2_mass.0,
        &resonance_mass.0,
        &mandelstam_s.0,
    )?;
    Ok(PyExpression(expression))
}
/// Shared preprocessing for the `LookupTable*` constructors.
///
/// Every coordinate list is validated into a `LookupAxis` first, stopping at the first
/// invalid one. Then each Python variable is boxed as a `dyn Variable`.
fn py_lookup_inputs(
    variables: Vec<PyVariable>,
    axis_coordinates: Vec<Vec<f64>>,
) -> LadduResult<LookupInputs> {
    let mut axes = Vec::with_capacity(axis_coordinates.len());
    for coordinates in axis_coordinates {
        axes.push(LookupAxis::new(coordinates)?);
    }
    let boxed: Vec<Box<dyn Variable>> = variables
        .into_iter()
        .map(|variable| -> Box<dyn Variable> { Box::new(variable) })
        .collect();
    Ok((boxed, axes))
}
/// Build a `LookupTable` amplitude with fixed complex `values` on the given grid.
///
/// `interpolation` and `boundary_mode` are parsed with `str::parse` into the types that
/// `LookupTable::new` expects.
#[pyfunction(name = "LookupTable", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
pub fn py_lookup_table(
    tags: &Bound<'_, PyTuple>,
    variables: Vec<PyVariable>,
    axis_coordinates: Vec<Vec<f64>>,
    values: Vec<Complex64>,
    interpolation: &str,
    boundary_mode: &str,
) -> PyResult<PyExpression> {
    let (inputs, axes) = py_lookup_inputs(variables, axis_coordinates)?;
    let tags = py_tags(tags)?;
    let expression = LookupTable::new(
        tags,
        inputs,
        axes,
        values,
        interpolation.parse()?,
        boundary_mode.parse()?,
    )?;
    Ok(PyExpression(expression))
}
#[pyfunction(name = "LookupTableScalar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
pub fn py_lookup_table_scalar(
tags: &Bound<'_, PyTuple>,
variables: Vec<PyVariable>,
axis_coordinates: Vec<Vec<f64>>,
values: Vec<PyParameter>,
interpolation: &str,
boundary_mode: &str,
) -> PyResult<PyExpression> {
let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
Ok(PyExpression(LookupTable::new_scalar(
py_tags(tags)?,
variables,
axis_coordinates,
values.into_iter().map(|value| value.0).collect(),
interpolation.parse()?,
boundary_mode.parse()?,
)?))
}
#[pyfunction(name = "LookupTableComplex", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
pub fn py_lookup_table_complex(
tags: &Bound<'_, PyTuple>,
variables: Vec<PyVariable>,
axis_coordinates: Vec<Vec<f64>>,
values: Vec<(PyParameter, PyParameter)>,
interpolation: &str,
boundary_mode: &str,
) -> PyResult<PyExpression> {
let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
Ok(PyExpression(LookupTable::new_cartesian_complex(
py_tags(tags)?,
variables,
axis_coordinates,
values
.into_iter()
.map(|(value_re, value_im)| (value_re.0, value_im.0))
.collect(),
interpolation.parse()?,
boundary_mode.parse()?,
)?))
}
#[pyfunction(name = "LookupTablePolar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
pub fn py_lookup_table_polar(
tags: &Bound<'_, PyTuple>,
variables: Vec<PyVariable>,
axis_coordinates: Vec<Vec<f64>>,
values: Vec<(PyParameter, PyParameter)>,
interpolation: &str,
boundary_mode: &str,
) -> PyResult<PyExpression> {
let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
Ok(PyExpression(LookupTable::new_polar_complex(
py_tags(tags)?,
variables,
axis_coordinates,
values
.into_iter()
.map(|(value_r, value_theta)| (value_r.0, value_theta.0))
.collect(),
interpolation.parse()?,
boundary_mode.parse()?,
)?))
}
/// Build the Kopf `a0` K-matrix amplitude for `channel`.
///
/// `couplings` holds two parameters per row (presumably the real/imaginary parts of each
/// coupling — confirm against `KopfKMatrixA0::new`). `seed` is forwarded unchanged.
#[pyfunction(name = "KopfKMatrixA0", signature = (*tags, couplings, channel, mass, seed = None))]
pub fn py_kopf_kmatrix_a0(
    tags: &Bound<'_, PyTuple>,
    couplings: [[PyParameter; 2]; 2],
    channel: PyKopfKMatrixA0Channel,
    mass: PyMass,
    seed: Option<usize>,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = KopfKMatrixA0::new(
        tags,
        array::from_fn(|row| array::from_fn(|col| couplings[row][col].0.clone())),
        channel.into(),
        &mass.0,
        seed,
    )?;
    Ok(PyExpression(expression))
}
/// Build the Kopf `a2` K-matrix amplitude for `channel`.
///
/// `couplings` holds two parameters per row (presumably the real/imaginary parts of each
/// coupling — confirm against `KopfKMatrixA2::new`). `seed` is forwarded unchanged.
#[pyfunction(name = "KopfKMatrixA2", signature = (*tags, couplings, channel, mass, seed = None))]
pub fn py_kopf_kmatrix_a2(
    tags: &Bound<'_, PyTuple>,
    couplings: [[PyParameter; 2]; 2],
    channel: PyKopfKMatrixA2Channel,
    mass: PyMass,
    seed: Option<usize>,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = KopfKMatrixA2::new(
        tags,
        array::from_fn(|row| array::from_fn(|col| couplings[row][col].0.clone())),
        channel.into(),
        &mass.0,
        seed,
    )?;
    Ok(PyExpression(expression))
}
/// Build the Kopf `f0` K-matrix amplitude for `channel`.
///
/// `couplings` holds two parameters per row (presumably the real/imaginary parts of each
/// coupling — confirm against `KopfKMatrixF0::new`). `seed` is forwarded unchanged.
#[pyfunction(name = "KopfKMatrixF0", signature = (*tags, couplings, channel, mass, seed = None))]
pub fn py_kopf_kmatrix_f0(
    tags: &Bound<'_, PyTuple>,
    couplings: [[PyParameter; 2]; 5],
    channel: PyKopfKMatrixF0Channel,
    mass: PyMass,
    seed: Option<usize>,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = KopfKMatrixF0::new(
        tags,
        array::from_fn(|row| array::from_fn(|col| couplings[row][col].0.clone())),
        channel.into(),
        &mass.0,
        seed,
    )?;
    Ok(PyExpression(expression))
}
/// Build the Kopf `f2` K-matrix amplitude for `channel`.
///
/// `couplings` holds two parameters per row (presumably the real/imaginary parts of each
/// coupling — confirm against `KopfKMatrixF2::new`). `seed` is forwarded unchanged.
#[pyfunction(name = "KopfKMatrixF2", signature = (*tags, couplings, channel, mass, seed = None))]
pub fn py_kopf_kmatrix_f2(
    tags: &Bound<'_, PyTuple>,
    couplings: [[PyParameter; 2]; 4],
    channel: PyKopfKMatrixF2Channel,
    mass: PyMass,
    seed: Option<usize>,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = KopfKMatrixF2::new(
        tags,
        array::from_fn(|row| array::from_fn(|col| couplings[row][col].0.clone())),
        channel.into(),
        &mass.0,
        seed,
    )?;
    Ok(PyExpression(expression))
}
/// Build the Kopf `pi1` K-matrix amplitude for `channel`.
///
/// `couplings` holds a single row of two parameters (presumably real/imaginary parts —
/// confirm against `KopfKMatrixPi1::new`). Unlike some siblings, this constructor takes no
/// `seed`.
#[pyfunction(name = "KopfKMatrixPi1", signature = (*tags, couplings, channel, mass))]
pub fn py_kopf_kmatrix_pi1(
    tags: &Bound<'_, PyTuple>,
    couplings: [[PyParameter; 2]; 1],
    channel: PyKopfKMatrixPi1Channel,
    mass: PyMass,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = KopfKMatrixPi1::new(
        tags,
        array::from_fn(|row| array::from_fn(|col| couplings[row][col].0.clone())),
        channel.into(),
        &mass.0,
    )?;
    Ok(PyExpression(expression))
}
/// Build the Kopf `rho` K-matrix amplitude for `channel`.
///
/// `couplings` holds two parameters per row (presumably real/imaginary parts — confirm
/// against `KopfKMatrixRho::new`). Unlike some siblings, this constructor takes no `seed`.
#[pyfunction(name = "KopfKMatrixRho", signature = (*tags, couplings, channel, mass))]
pub fn py_kopf_kmatrix_rho(
    tags: &Bound<'_, PyTuple>,
    couplings: [[PyParameter; 2]; 2],
    channel: PyKopfKMatrixRhoChannel,
    mass: PyMass,
) -> PyResult<PyExpression> {
    let tags = py_tags(tags)?;
    let expression = KopfKMatrixRho::new(
        tags,
        array::from_fn(|row| array::from_fn(|col| couplings[row][col].0.clone())),
        channel.into(),
        &mass.0,
    )?;
    Ok(PyExpression(expression))
}
#[pymethods]
impl PyExpression {
#[getter]
fn parameters(&self) -> PyParameterMap {
PyParameterMap(self.0.parameters())
}
#[getter]
fn n_free(&self) -> usize {
self.0.n_free()
}
#[getter]
fn n_fixed(&self) -> usize {
self.0.n_fixed()
}
#[getter]
fn n_parameters(&self) -> usize {
self.0.n_parameters()
}
fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
Ok(PyEvaluator(self.0.load(&dataset.0)?))
}
fn real(&self) -> PyExpression {
PyExpression(self.0.real())
}
fn imag(&self) -> PyExpression {
PyExpression(self.0.imag())
}
fn conj(&self) -> PyExpression {
PyExpression(self.0.conj())
}
fn norm_sqr(&self) -> PyExpression {
PyExpression(self.0.norm_sqr())
}
fn sqrt(&self) -> PyExpression {
PyExpression(self.0.sqrt())
}
fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(value) = power.extract::<i32>() {
Ok(PyExpression(self.0.powi(value)))
} else if let Ok(value) = power.extract::<f64>() {
Ok(PyExpression(self.0.powf(value)))
} else if let Ok(expression) = power.extract::<PyExpression>() {
Ok(PyExpression(self.0.pow(&expression.0)))
} else {
Err(PyTypeError::new_err(
"power must be an int, float, or Expression",
))
}
}
fn exp(&self) -> PyExpression {
PyExpression(self.0.exp())
}
fn sin(&self) -> PyExpression {
PyExpression(self.0.sin())
}
fn cos(&self) -> PyExpression {
PyExpression(self.0.cos())
}
fn log(&self) -> PyExpression {
PyExpression(self.0.log())
}
fn cis(&self) -> PyExpression {
PyExpression(self.0.cis())
}
fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
Ok(self.0.fix_parameter(name, value)?)
}
fn free_parameter(&self, name: &str) -> PyResult<()> {
Ok(self.0.free_parameter(name)?)
}
fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
Ok(self.0.rename_parameter(old, new)?)
}
fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
Ok(self.0.rename_parameters(&mapping)?)
}
#[getter]
fn compiled_expression(&self) -> PyCompiledExpression {
PyCompiledExpression(self.0.compiled_expression())
}
fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(self.0.clone() + other_expr.0))
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
}
fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(other_expr.0 + self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for +"))
}
}
fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(self.0.clone() - other_expr.0))
} else {
Err(PyTypeError::new_err("Unsupported operand type for -"))
}
}
fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(other_expr.0 - self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for -"))
}
}
fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(self.0.clone() * other_expr.0))
} else {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(other_expr.0 * self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for *"))
}
}
fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(self.0.clone() / other_expr.0))
} else {
Err(PyTypeError::new_err("Unsupported operand type for /"))
}
}
fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
if let Ok(other_expr) = other.extract::<PyExpression>() {
Ok(PyExpression(other_expr.0 / self.0.clone()))
} else {
Err(PyTypeError::new_err("Unsupported operand type for /"))
}
}
fn __neg__(&self) -> PyExpression {
PyExpression(-self.0.clone())
}
fn __str__(&self) -> String {
format!("{}", self.0)
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
#[new]
fn new() -> Self {
Self(Expression::default())
}
fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
Ok(PyBytes::new(
py,
serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
.map_err(LadduError::PickleError)?
.as_slice(),
))
}
fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
*self = Self(
serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
.map_err(LadduError::PickleError)?,
);
Ok(())
}
}
/// Tag selection accepted by `Evaluator.activate`/`deactivate`/`isolate`: either a single
/// tag name or a Python `list` of tag names.
enum PyTagSelection {
    Single(String),
    Many(Vec<String>),
}

impl PyTagSelection {
    /// Parse `arg` as a `str` (tried first) or a `list[str]`.
    ///
    /// # Errors
    /// Raises `TypeError` when `arg` is neither a string nor a list. A list containing a
    /// non-string element propagates the element extraction error.
    fn from_arg(arg: &Bound<'_, PyAny>) -> PyResult<Self> {
        if let Ok(name) = arg.extract::<String>() {
            Ok(Self::Single(name))
        } else if let Ok(list) = arg.cast::<PyList>() {
            Ok(Self::Many(list.extract()?))
        } else {
            Err(PyTypeError::new_err(
                "Argument must be either a string or a list of strings",
            ))
        }
    }
}

/// Python handle to an [`Evaluator`], i.e. an expression loaded against a dataset
/// (see `Expression.load`).
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
#[pymethods]
impl PyEvaluator {
    /// All parameters referenced by the loaded expression.
    #[getter]
    fn parameters(&self) -> PyParameterMap {
        PyParameterMap(self.0.parameters())
    }
    /// Number of free (unfixed) parameters.
    #[getter]
    fn n_free(&self) -> usize {
        self.0.n_free()
    }
    /// Number of fixed parameters.
    #[getter]
    fn n_fixed(&self) -> usize {
        self.0.n_fixed()
    }
    /// Total number of parameters.
    #[getter]
    fn n_parameters(&self) -> usize {
        self.0.n_parameters()
    }
    /// Fix the parameter `name` to `value`.
    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
        Ok(self.0.fix_parameter(name, value)?)
    }
    /// Release a previously fixed parameter `name`.
    fn free_parameter(&self, name: &str) -> PyResult<()> {
        Ok(self.0.free_parameter(name)?)
    }
    /// Rename the parameter `old` to `new`.
    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
        Ok(self.0.rename_parameter(old, new)?)
    }
    /// Rename several parameters using an `old -> new` mapping.
    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
        Ok(self.0.rename_parameters(&mapping)?)
    }
    /// Activate the tag(s) in `arg` (a `str` or `list[str]`).
    ///
    /// With `strict=True` the `*_strict` variants are used and their errors are raised.
    #[pyo3(signature = (arg, *, strict=true))]
    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        match (PyTagSelection::from_arg(arg)?, strict) {
            (PyTagSelection::Single(name), true) => {
                self.0.activate_strict(&name)?;
            }
            (PyTagSelection::Single(name), false) => {
                self.0.activate(&name);
            }
            (PyTagSelection::Many(names), true) => {
                self.0.activate_many_strict(&names)?;
            }
            (PyTagSelection::Many(names), false) => {
                self.0.activate_many(&names);
            }
        }
        Ok(())
    }
    /// Activate every tag.
    fn activate_all(&self) {
        self.0.activate_all();
    }
    /// Deactivate the tag(s) in `arg` (a `str` or `list[str]`).
    ///
    /// With `strict=True` the `*_strict` variants are used and their errors are raised.
    #[pyo3(signature = (arg, *, strict=true))]
    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        match (PyTagSelection::from_arg(arg)?, strict) {
            (PyTagSelection::Single(name), true) => {
                self.0.deactivate_strict(&name)?;
            }
            (PyTagSelection::Single(name), false) => {
                self.0.deactivate(&name);
            }
            (PyTagSelection::Many(names), true) => {
                self.0.deactivate_many_strict(&names)?;
            }
            (PyTagSelection::Many(names), false) => {
                self.0.deactivate_many(&names);
            }
        }
        Ok(())
    }
    /// Deactivate every tag.
    fn deactivate_all(&self) {
        self.0.deactivate_all();
    }
    /// Isolate the tag(s) in `arg` (a `str` or `list[str]`).
    ///
    /// With `strict=True` the `*_strict` variants are used and their errors are raised.
    #[pyo3(signature = (arg, *, strict=true))]
    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
        match (PyTagSelection::from_arg(arg)?, strict) {
            (PyTagSelection::Single(name), true) => {
                self.0.isolate_strict(&name)?;
            }
            (PyTagSelection::Single(name), false) => {
                self.0.isolate(&name);
            }
            (PyTagSelection::Many(names), true) => {
                self.0.isolate_many_strict(&names)?;
            }
            (PyTagSelection::Many(names), false) => {
                self.0.isolate_many(&names);
            }
        }
        Ok(())
    }
    /// Current activation mask.
    #[getter]
    fn active_mask(&self) -> Vec<bool> {
        self.0.active_mask()
    }
    /// Replace the activation mask (validated by `Evaluator::set_active_mask`).
    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
        self.0.set_active_mask(&mask)?;
        Ok(())
    }
    /// The compiled form of the loaded expression.
    #[getter]
    fn compiled_expression(&self) -> PyCompiledExpression {
        PyCompiledExpression(self.0.compiled_expression())
    }
    /// The expression this evaluator was loaded from.
    #[getter]
    fn expression(&self) -> PyExpression {
        PyExpression(self.0.expression())
    }
    /// Evaluate at `parameters` inside the shared thread pool.
    ///
    /// `threads` is forwarded to the pool manager. Pool errors and evaluation errors are
    /// both raised.
    #[pyo3(signature = (parameters, *, threads=None))]
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
        Ok(PyArray1::from_slice(py, &values?))
    }
    /// Like `evaluate`, restricted to the given `indices`.
    #[pyo3(signature = (parameters, indices, *, threads=None))]
    fn evaluate_batch<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        indices: Vec<usize>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
        let values =
            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
        Ok(PyArray1::from_slice(py, &values?))
    }
    /// Gradient at `parameters` as a 2-D array with one row per gradient vector returned
    /// by `Evaluator::evaluate_gradient`.
    #[pyo3(signature = (parameters, *, threads=None))]
    fn evaluate_gradient<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
        let gradients: LadduResult<_> = install_with_threads(threads, || {
            Ok(self
                .0
                .evaluate_gradient(&parameters)?
                .iter()
                .map(|grad| grad.data.as_vec().to_vec())
                .collect::<Vec<Vec<Complex64>>>())
        })?;
        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
    }
    /// Like `evaluate_gradient`, restricted to the given `indices`.
    #[pyo3(signature = (parameters, indices, *, threads=None))]
    fn evaluate_gradient_batch<'py>(
        &self,
        py: Python<'py>,
        parameters: Vec<f64>,
        indices: Vec<usize>,
        threads: Option<usize>,
    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
        let gradients: LadduResult<_> = install_with_threads(threads, || {
            Ok(self
                .0
                .evaluate_gradient_batch(&parameters, &indices)?
                .iter()
                .map(|grad| grad.data.as_vec().to_vec())
                .collect::<Vec<Vec<Complex64>>>())
        })?;
        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
    }
}
/// Python wrapper around a [`CompiledExpression`]; it only exposes `str`/`repr`.
#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyCompiledExpression(pub CompiledExpression);
#[pymethods]
impl PyCompiledExpression {
    /// Human-readable (`Display`) rendering.
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    /// Debug rendering.
    fn __repr__(&self) -> String {
        let debug = format!("{:?}", self.0);
        debug
    }
}
/// Collection of [`Parameter`]s addressable by name or by (possibly negative) position.
#[pyclass(name = "ParameterMap", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameterMap(pub ParameterMap);
#[pymethods]
impl PyParameterMap {
    /// Parameter names as a tuple.
    #[getter]
    fn names<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        let names = self.0.names();
        PyTuple::new(py, names)
    }
    /// Sub-map of the free parameters.
    #[getter]
    fn free(&self) -> PyParameterMap {
        let free = self.0.free();
        PyParameterMap(free)
    }
    /// Sub-map of the fixed parameters.
    #[getter]
    fn fixed(&self) -> PyParameterMap {
        let fixed = self.0.fixed();
        PyParameterMap(fixed)
    }
    /// Whether a parameter called `name` exists.
    fn contains(&self, name: &str) -> bool {
        self.0.contains_key(name)
    }
    fn __contains__(&self, name: &str) -> bool {
        self.0.contains_key(name)
    }
    fn __len__(&self) -> usize {
        self.0.len()
    }
    fn __bool__(&self) -> bool {
        !self.0.is_empty()
    }
    /// Look up by name (`KeyError` if missing) or by integer position with Python-style
    /// negative indexing (`IndexError` if out of range). Any other key raises `TypeError`.
    fn __getitem__(&self, index: &Bound<'_, PyAny>) -> PyResult<PyParameter> {
        if let Ok(name) = index.extract::<String>() {
            return match self.0.get(&name) {
                Some(parameter) => Ok(PyParameter(parameter.clone())),
                None => Err(pyo3::exceptions::PyKeyError::new_err(name)),
            };
        }
        let Ok(position) = index.extract::<isize>() else {
            return Err(pyo3::exceptions::PyTypeError::new_err(
                "ParameterMap indices must be str or int",
            ));
        };
        let len = self.0.len() as isize;
        let resolved = if position < 0 { position + len } else { position };
        if !(0..len).contains(&resolved) {
            return Err(pyo3::exceptions::PyIndexError::new_err(format!(
                "parameter index out of range: {position}"
            )));
        }
        Ok(PyParameter(self.0[resolved as usize].clone()))
    }
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
    /// Position of the parameter called `name`.
    ///
    /// # Errors
    /// Raises `ValueError` if no such parameter exists.
    fn index(&self, name: &str) -> PyResult<usize> {
        match self.0.index(name) {
            Some(position) => Ok(position),
            None => Err(PyValueError::new_err(format!("parameter not found: {name}"))),
        }
    }
    /// Iterate over a snapshot of the parameters, taken when iteration starts.
    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
        let snapshot = self
            .0
            .clone()
            .into_iter()
            .map(PyParameter)
            .collect::<Vec<PyParameter>>();
        let list = PyList::new(py, snapshot)?;
        PyIterator::from_object(&list)
    }
}
/// Python view of a single [`Parameter`]; every attribute is a read-only getter.
#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyParameter(pub Parameter);
#[pymethods]
impl PyParameter {
    /// Parameter name.
    #[getter]
    fn name(&self) -> String {
        let Self(parameter) = self;
        parameter.name()
    }
    /// Fixed value, or `None` if the parameter is free.
    #[getter]
    fn fixed(&self) -> Option<f64> {
        let Self(parameter) = self;
        parameter.fixed()
    }
    /// Initial value, if one was set.
    #[getter]
    fn initial(&self) -> Option<f64> {
        let Self(parameter) = self;
        parameter.initial()
    }
    /// `(lower, upper)` bounds; `None` means unbounded on that side.
    #[getter]
    fn bounds(&self) -> (Option<f64>, Option<f64>) {
        let Self(parameter) = self;
        parameter.bounds()
    }
    /// Unit label, if any.
    #[getter]
    fn unit(&self) -> Option<String> {
        let Self(parameter) = self;
        parameter.unit()
    }
    /// LaTeX label, if any.
    #[getter]
    fn latex(&self) -> Option<String> {
        let Self(parameter) = self;
        parameter.latex()
    }
    /// Free-text description, if any.
    #[getter]
    fn description(&self) -> Option<String> {
        let Self(parameter) = self;
        parameter.description()
    }
}
/// Create a named `Parameter`, optionally fixed, with optional initial value, bounds, and
/// metadata.
///
/// Setters are applied in this order: initial value, fixed value, bounds, unit, LaTeX label,
/// description. Bounds are always applied; `(None, None)` is passed through as given.
#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
pub fn py_parameter(
    name: &str,
    fixed: Option<f64>,
    initial: Option<f64>,
    bounds: (Option<f64>, Option<f64>),
    unit: Option<&str>,
    latex: Option<&str>,
    description: Option<&str>,
) -> PyParameter {
    let parameter = Parameter::new(name);
    if let Some(start) = initial {
        parameter.set_initial(start);
    }
    if let Some(pinned) = fixed {
        parameter.set_fixed_value(Some(pinned));
    }
    let (lower, upper) = bounds;
    parameter.set_bounds(lower, upper);
    if let Some(label) = unit {
        parameter.set_unit(label);
    }
    if let Some(label) = latex {
        parameter.set_latex(label);
    }
    if let Some(text) = description {
        parameter.set_description(text);
    }
    PyParameter(parameter)
}
#[pyfunction(name = "TestAmplitude", signature = (*tags, re, im))]
pub fn py_test_amplitude(
tags: &Bound<'_, PyTuple>,
re: PyParameter,
im: PyParameter,
) -> PyResult<PyExpression> {
Ok(PyExpression(TestAmplitude::new(
py_tags(tags)?,
re.0,
im.0,
)?))
}