Skip to main content

laddu_python/
amplitudes.rs

1use std::{array, collections::HashMap};
2
3use laddu_amplitudes::{
4    angular::{
5        BlattWeisskopf, ClebschGordan, PhotonHelicity, PhotonPolarization, PhotonSDME, PolPhase,
6        Wigner3j, WignerD, Ylm, Zlm,
7    },
8    kmatrix::{
9        KopfKMatrixA0, KopfKMatrixA0Channel, KopfKMatrixA2, KopfKMatrixA2Channel, KopfKMatrixF0,
10        KopfKMatrixF0Channel, KopfKMatrixF2, KopfKMatrixF2Channel, KopfKMatrixPi1,
11        KopfKMatrixPi1Channel, KopfKMatrixRho, KopfKMatrixRhoChannel,
12    },
13    lookup::{LookupAxis, LookupTable},
14    resonance::{BreitWigner, BreitWignerNonRelativistic, Flatte, PhaseSpaceFactor, Voigt},
15    scalar::{ComplexScalar, PolarComplexScalar, Scalar, VariableScalar},
16};
17use laddu_core::{
18    amplitude::{Evaluator, Expression, Parameter, ParameterMap, TestAmplitude},
19    math::{BarrierKind, Sheet, QR_DEFAULT},
20    traits::Variable,
21    CompiledExpression, LadduError, LadduResult, ThreadPoolManager,
22};
23use num::complex::Complex64;
24use numpy::{PyArray1, PyArray2};
25use pyo3::{
26    exceptions::{PyTypeError, PyValueError},
27    prelude::*,
28    types::{PyAny, PyBytes, PyIterator, PyList, PyTuple},
29};
30
31use crate::{
32    data::PyDataset,
33    quantum::angular_momentum::{
34        parse_angular_momentum, parse_orbital_angular_momentum, parse_projection,
35    },
36    variables::{PyAngles, PyDecay, PyMandelstam, PyMass, PyPolarization, PyVariable},
37};
38
/// Boxed lookup variables paired with their lookup axes, as returned by `py_lookup_inputs`.
type LookupInputs = (Vec<Box<dyn Variable>>, Vec<LookupAxis>);
40
// Generates a Python-visible enum that mirrors a Rust K-matrix channel enum.
//
// Arguments:
// - `$py_name`: identifier of the generated wrapper enum.
// - `$python_name`: class name the enum is exposed under in the `laddu` Python module.
// - `$rust_name { ... }`: path of the wrapped Rust enum followed by its variant list.
//
// The macro also emits `From<$py_name> for $rust_name`, mapping every wrapper variant to the
// identically named variant of the wrapped enum.
macro_rules! py_kmatrix_channel {
    ($py_name:ident, $python_name:literal, $rust_name:path { $($variant:ident),+ $(,)? }) => {
        #[pyclass(eq, name = $python_name, module = "laddu", from_py_object)]
        #[derive(Clone, PartialEq)]
        pub enum $py_name {
            $($variant,)+
        }

        impl From<$py_name> for $rust_name {
            fn from(value: $py_name) -> Self {
                // One arm per listed variant, so the conversion is exhaustive by construction.
                match value {
                    $( $py_name::$variant => Self::$variant, )+
                }
            }
        }
    };
}
58
// Python wrapper for `KopfKMatrixA0Channel`.
py_kmatrix_channel!(
    PyKopfKMatrixA0Channel,
    "KopfKMatrixA0Channel",
    KopfKMatrixA0Channel { PiEta, KKbar }
);
// Python wrapper for `KopfKMatrixA2Channel`.
py_kmatrix_channel!(
    PyKopfKMatrixA2Channel,
    "KopfKMatrixA2Channel",
    KopfKMatrixA2Channel {
        PiEta,
        KKbar,
        PiEtaPrime
    }
);
// Python wrapper for `KopfKMatrixF0Channel`.
py_kmatrix_channel!(
    PyKopfKMatrixF0Channel,
    "KopfKMatrixF0Channel",
    KopfKMatrixF0Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta,
        EtaEtaPrime
    }
);
// Python wrapper for `KopfKMatrixF2Channel`.
py_kmatrix_channel!(
    PyKopfKMatrixF2Channel,
    "KopfKMatrixF2Channel",
    KopfKMatrixF2Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta
    }
);
// Python wrapper for `KopfKMatrixPi1Channel`.
py_kmatrix_channel!(
    PyKopfKMatrixPi1Channel,
    "KopfKMatrixPi1Channel",
    KopfKMatrixPi1Channel { PiEta, PiEtaPrime }
);
// Python wrapper for `KopfKMatrixRhoChannel`.
py_kmatrix_channel!(
    PyKopfKMatrixRhoChannel,
    "KopfKMatrixRhoChannel",
    KopfKMatrixRhoChannel {
        PiPi,
        FourPi,
        KKbar
    }
);
108
109fn install_with_threads<R: Send>(
110    threads: Option<usize>,
111    op: impl FnOnce() -> R + Send,
112) -> LadduResult<R> {
113    ThreadPoolManager::shared().install(threads, op)
114}
115
116pub(crate) fn py_tags(tags: &Bound<'_, PyTuple>) -> PyResult<Vec<String>> {
117    tags.iter()
118        .map(|tag| tag.extract::<String>())
119        .collect::<PyResult<Vec<_>>>()
120}
121
122/// A mathematical expression formed from amplitudes.
123///
124#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
125#[derive(Clone)]
126pub struct PyExpression(pub Expression);
127
128impl<'py> FromPyObject<'_, 'py> for PyExpression {
129    type Error = PyErr;
130
131    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
132        if let Ok(obj) = obj.cast::<PyExpression>() {
133            Ok(obj.borrow().clone())
134        } else if let Ok(obj) = obj.extract::<f64>() {
135            Ok(Self(obj.into()))
136        } else if let Ok(obj) = obj.extract::<Complex64>() {
137            Ok(Self(obj.into()))
138        } else {
139            Err(PyTypeError::new_err("Failed to extract Expression"))
140        }
141    }
142}
143
144/// A convenience method to sum sequences of Expressions
145///
146#[pyfunction(name = "expr_sum")]
147pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
148    if terms.is_empty() {
149        return Ok(PyExpression(Expression::zero()));
150    }
151    if terms.len() == 1 {
152        let term = &terms[0];
153        if let Ok(expression) = term.extract::<PyExpression>() {
154            return Ok(expression);
155        }
156        return Err(PyTypeError::new_err("Item is not a PyExpression"));
157    }
158    let mut iter = terms.iter();
159    let Some(first_term) = iter.next() else {
160        return Ok(PyExpression(Expression::zero()));
161    };
162    let PyExpression(mut summation) = first_term
163        .extract::<PyExpression>()
164        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
165    for term in iter {
166        let PyExpression(expr) = term
167            .extract::<PyExpression>()
168            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
169        summation = summation + expr;
170    }
171    Ok(PyExpression(summation))
172}
173
174/// A convenience method to multiply sequences of Expressions
175///
176#[pyfunction(name = "expr_product")]
177pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
178    if terms.is_empty() {
179        return Ok(PyExpression(Expression::one()));
180    }
181    if terms.len() == 1 {
182        let term = &terms[0];
183        if let Ok(expression) = term.extract::<PyExpression>() {
184            return Ok(expression);
185        }
186        return Err(PyTypeError::new_err("Item is not a PyExpression"));
187    }
188    let mut iter = terms.iter();
189    let Some(first_term) = iter.next() else {
190        return Ok(PyExpression(Expression::one()));
191    };
192    let PyExpression(mut product) = first_term
193        .extract::<PyExpression>()
194        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
195    for term in iter {
196        let PyExpression(expr) = term
197            .extract::<PyExpression>()
198            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
199        product = product * expr;
200    }
201    Ok(PyExpression(product))
202}
203
204/// A convenience class representing a zero-valued Expression
205///
206#[pyfunction(name = "Zero")]
207pub fn py_expr_zero() -> PyExpression {
208    PyExpression(Expression::zero())
209}
210
211/// A convenience class representing a unit-valued Expression
212///
213#[pyfunction(name = "One")]
214pub fn py_expr_one() -> PyExpression {
215    PyExpression(Expression::one())
216}
217
218/// Construct a scalar amplitude from one parameter.
219#[pyfunction(name = "Scalar", signature = (*tags, value))]
220pub fn py_scalar(tags: &Bound<'_, PyTuple>, value: PyParameter) -> PyResult<PyExpression> {
221    Ok(PyExpression(Scalar::new(py_tags(tags)?, value.0)?))
222}
223
224/// Construct a real expression from an event variable.
225#[pyfunction(name = "VariableScalar", signature = (*tags, variable))]
226pub fn py_variable_scalar(
227    tags: &Bound<'_, PyTuple>,
228    variable: Bound<'_, PyAny>,
229) -> PyResult<PyExpression> {
230    let variable = variable.extract::<PyVariable>()?;
231    Ok(PyExpression(VariableScalar::new(
232        py_tags(tags)?,
233        &variable,
234    )?))
235}
236
237/// Construct a cartesian complex scalar amplitude.
238#[pyfunction(name = "ComplexScalar", signature = (*tags, re, im))]
239pub fn py_complex_scalar(
240    tags: &Bound<'_, PyTuple>,
241    re: PyParameter,
242    im: PyParameter,
243) -> PyResult<PyExpression> {
244    Ok(PyExpression(ComplexScalar::new(
245        py_tags(tags)?,
246        re.0,
247        im.0,
248    )?))
249}
250
251/// Construct a polar complex scalar amplitude.
252#[pyfunction(name = "PolarComplexScalar", signature = (*tags, r, theta))]
253pub fn py_polar_complex_scalar(
254    tags: &Bound<'_, PyTuple>,
255    r: PyParameter,
256    theta: PyParameter,
257) -> PyResult<PyExpression> {
258    Ok(PyExpression(PolarComplexScalar::new(
259        py_tags(tags)?,
260        r.0,
261        theta.0,
262    )?))
263}
264
265/// Construct a relativistic Breit-Wigner amplitude.
266#[pyfunction(name = "BreitWigner", signature = (*tags, mass, width, l, daughter_1_mass, daughter_2_mass, resonance_mass, barrier_factors=true))]
267#[allow(clippy::too_many_arguments)]
268pub fn py_breit_wigner(
269    tags: &Bound<'_, PyTuple>,
270    mass: PyParameter,
271    width: PyParameter,
272    l: usize,
273    daughter_1_mass: &PyMass,
274    daughter_2_mass: &PyMass,
275    resonance_mass: &PyMass,
276    barrier_factors: bool,
277) -> PyResult<PyExpression> {
278    if barrier_factors {
279        Ok(PyExpression(BreitWigner::new(
280            py_tags(tags)?,
281            mass.0,
282            width.0,
283            l,
284            &daughter_1_mass.0,
285            &daughter_2_mass.0,
286            &resonance_mass.0,
287        )?))
288    } else {
289        Ok(PyExpression(BreitWigner::new_without_barrier_factors(
290            py_tags(tags)?,
291            mass.0,
292            width.0,
293            l,
294            &daughter_1_mass.0,
295            &daughter_2_mass.0,
296            &resonance_mass.0,
297        )?))
298    }
299}
300
301/// Construct a non-relativistic Breit-Wigner amplitude.
302#[pyfunction(name = "BreitWignerNonRelativistic", signature = (*tags, mass, width, resonance_mass))]
303pub fn py_breit_wigner_non_relativistic(
304    tags: &Bound<'_, PyTuple>,
305    mass: PyParameter,
306    width: PyParameter,
307    resonance_mass: &PyMass,
308) -> PyResult<PyExpression> {
309    Ok(PyExpression(BreitWignerNonRelativistic::new(
310        py_tags(tags)?,
311        mass.0,
312        width.0,
313        &resonance_mass.0,
314    )?))
315}
316
317/// Construct a Flatte amplitude.
318#[pyfunction(name = "Flatte", signature = (*tags, mass, observed_channel_coupling, alternate_channel_coupling, observed_channel_daughter_masses, alternate_channel_daughter_masses, resonance_mass))]
319pub fn py_flatte(
320    tags: &Bound<'_, PyTuple>,
321    mass: PyParameter,
322    observed_channel_coupling: PyParameter,
323    alternate_channel_coupling: PyParameter,
324    observed_channel_daughter_masses: (PyMass, PyMass),
325    alternate_channel_daughter_masses: (f64, f64),
326    resonance_mass: &PyMass,
327) -> PyResult<PyExpression> {
328    Ok(PyExpression(Flatte::new(
329        py_tags(tags)?,
330        mass.0,
331        observed_channel_coupling.0,
332        alternate_channel_coupling.0,
333        (
334            &observed_channel_daughter_masses.0 .0,
335            &observed_channel_daughter_masses.1 .0,
336        ),
337        alternate_channel_daughter_masses,
338        &resonance_mass.0,
339    )?))
340}
341
342/// Construct a Voigt amplitude.
343#[pyfunction(name = "Voigt", signature = (*tags, mass, width, sigma, resonance_mass))]
344pub fn py_voigt(
345    tags: &Bound<'_, PyTuple>,
346    mass: PyParameter,
347    width: PyParameter,
348    sigma: PyParameter,
349    resonance_mass: &PyMass,
350) -> PyResult<PyExpression> {
351    Ok(PyExpression(Voigt::new(
352        py_tags(tags)?,
353        mass.0,
354        width.0,
355        sigma.0,
356        &resonance_mass.0,
357    )?))
358}
359
360/// Construct a spherical-harmonic amplitude.
361#[pyfunction(name = "Ylm", signature = (*tags, l, m, angles))]
362pub fn py_ylm(
363    tags: &Bound<'_, PyTuple>,
364    l: usize,
365    m: isize,
366    angles: &PyAngles,
367) -> PyResult<PyExpression> {
368    Ok(PyExpression(Ylm::new(py_tags(tags)?, l, m, &angles.0)?))
369}
370
371/// Construct a polarized spherical-harmonic amplitude.
372#[pyfunction(name = "Zlm", signature = (*tags, l, m, r, angles, polarization))]
373pub fn py_zlm(
374    tags: &Bound<'_, PyTuple>,
375    l: usize,
376    m: isize,
377    r: &str,
378    angles: &PyAngles,
379    polarization: &PyPolarization,
380) -> PyResult<PyExpression> {
381    Ok(PyExpression(Zlm::new(
382        py_tags(tags)?,
383        l,
384        m,
385        r.parse()?,
386        &angles.0,
387        &polarization.0,
388    )?))
389}
390
391/// Construct a polarization phase amplitude.
392#[pyfunction(name = "PolPhase", signature = (*tags, polarization))]
393pub fn py_polphase(
394    tags: &Bound<'_, PyTuple>,
395    polarization: &PyPolarization,
396) -> PyResult<PyExpression> {
397    Ok(PyExpression(PolPhase::new(
398        py_tags(tags)?,
399        &polarization.0,
400    )?))
401}
402
403/// Construct a Wigner-D amplitude.
404#[pyfunction(name = "WignerD", signature = (*tags, spin, row_projection, column_projection, angles))]
405pub fn py_wigner_d(
406    tags: &Bound<'_, PyTuple>,
407    spin: &Bound<'_, PyAny>,
408    row_projection: &Bound<'_, PyAny>,
409    column_projection: &Bound<'_, PyAny>,
410    angles: &PyAngles,
411) -> PyResult<PyExpression> {
412    Ok(PyExpression(WignerD::new(
413        py_tags(tags)?,
414        parse_angular_momentum(spin)?,
415        parse_projection(row_projection)?,
416        parse_projection(column_projection)?,
417        &angles.0,
418    )?))
419}
420
421/// Construct a Blatt-Weisskopf amplitude.
422#[pyfunction(name = "BlattWeisskopf", signature = (*tags, decay, l, reference_mass, q_r = QR_DEFAULT, sheet = "physical", kind = "full"))]
423pub fn py_blatt_weisskopf(
424    tags: &Bound<'_, PyTuple>,
425    decay: &PyDecay,
426    l: &Bound<'_, PyAny>,
427    reference_mass: f64,
428    q_r: f64,
429    sheet: &str,
430    kind: &str,
431) -> PyResult<PyExpression> {
432    let sheet = match sheet.to_ascii_lowercase().as_str() {
433        "physical" => Sheet::Physical,
434        "unphysical" => Sheet::Unphysical,
435        _ => {
436            return Err(PyValueError::new_err(
437                "sheet must be 'physical' or 'unphysical'",
438            ));
439        }
440    };
441    let kind = match kind.to_ascii_lowercase().as_str() {
442        "full" => BarrierKind::Full,
443        "tensor" => BarrierKind::Tensor,
444        _ => {
445            return Err(PyValueError::new_err("kind must be 'full' or 'tensor'"));
446        }
447    };
448    Ok(PyExpression(BlattWeisskopf::new(
449        py_tags(tags)?,
450        &decay.0,
451        parse_orbital_angular_momentum(l)?,
452        reference_mass,
453        q_r,
454        sheet,
455        kind,
456    )?))
457}
458
459/// Construct a Clebsch-Gordan constant expression.
460#[pyfunction(name = "ClebschGordan", signature = (*tags, j1, m1, j2, m2, j, m))]
461pub fn py_clebsch_gordan(
462    tags: &Bound<'_, PyTuple>,
463    j1: &Bound<'_, PyAny>,
464    m1: &Bound<'_, PyAny>,
465    j2: &Bound<'_, PyAny>,
466    m2: &Bound<'_, PyAny>,
467    j: &Bound<'_, PyAny>,
468    m: &Bound<'_, PyAny>,
469) -> PyResult<PyExpression> {
470    Ok(PyExpression(ClebschGordan::new(
471        py_tags(tags)?,
472        parse_angular_momentum(j1)?,
473        parse_projection(m1)?,
474        parse_angular_momentum(j2)?,
475        parse_projection(m2)?,
476        parse_angular_momentum(j)?,
477        parse_projection(m)?,
478    )?))
479}
480
481/// Construct a Wigner-3j constant expression.
482#[pyfunction(name = "Wigner3j", signature = (*tags, j1, m1, j2, m2, j3, m3))]
483pub fn py_wigner_3j(
484    tags: &Bound<'_, PyTuple>,
485    j1: &Bound<'_, PyAny>,
486    m1: &Bound<'_, PyAny>,
487    j2: &Bound<'_, PyAny>,
488    m2: &Bound<'_, PyAny>,
489    j3: &Bound<'_, PyAny>,
490    m3: &Bound<'_, PyAny>,
491) -> PyResult<PyExpression> {
492    Ok(PyExpression(Wigner3j::new(
493        py_tags(tags)?,
494        parse_angular_momentum(j1)?,
495        parse_projection(m1)?,
496        parse_angular_momentum(j2)?,
497        parse_projection(m2)?,
498        parse_angular_momentum(j3)?,
499        parse_projection(m3)?,
500    )?))
501}
502
503/// Construct a photon SDME amplitude.
504#[pyfunction(name = "PhotonSDME", signature = (*tags, helicity, helicity_prime, polarization = None))]
505pub fn py_photon_sdme(
506    tags: &Bound<'_, PyTuple>,
507    helicity: i32,
508    helicity_prime: i32,
509    polarization: Option<&PyPolarization>,
510) -> PyResult<PyExpression> {
511    let polarization = polarization
512        .map(|polarization| PhotonPolarization::Linear(Box::new(polarization.0.clone())))
513        .unwrap_or(PhotonPolarization::Unpolarized);
514    Ok(PyExpression(PhotonSDME::new(
515        py_tags(tags)?,
516        polarization,
517        PhotonHelicity::new(helicity)?,
518        PhotonHelicity::new(helicity_prime)?,
519    )?))
520}
521
522/// Construct a phase-space factor amplitude.
523#[pyfunction(name = "PhaseSpaceFactor", signature = (*tags, recoil_mass, daughter_1_mass, daughter_2_mass, resonance_mass, mandelstam_s))]
524pub fn py_phase_space_factor(
525    tags: &Bound<'_, PyTuple>,
526    recoil_mass: &PyMass,
527    daughter_1_mass: &PyMass,
528    daughter_2_mass: &PyMass,
529    resonance_mass: &PyMass,
530    mandelstam_s: &PyMandelstam,
531) -> PyResult<PyExpression> {
532    Ok(PyExpression(PhaseSpaceFactor::new(
533        py_tags(tags)?,
534        &recoil_mass.0,
535        &daughter_1_mass.0,
536        &daughter_2_mass.0,
537        &resonance_mass.0,
538        &mandelstam_s.0,
539    )?))
540}
541
542fn py_lookup_inputs(
543    variables: Vec<PyVariable>,
544    axis_coordinates: Vec<Vec<f64>>,
545) -> LadduResult<LookupInputs> {
546    let axis_coordinates = axis_coordinates
547        .into_iter()
548        .map(LookupAxis::new)
549        .collect::<LadduResult<Vec<_>>>()?;
550    let variables = variables
551        .into_iter()
552        .map(|variable| Box::new(variable) as Box<dyn Variable>)
553        .collect();
554    Ok((variables, axis_coordinates))
555}
556
557/// Construct a fixed-complex lookup table amplitude.
558#[pyfunction(name = "LookupTable", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
559pub fn py_lookup_table(
560    tags: &Bound<'_, PyTuple>,
561    variables: Vec<PyVariable>,
562    axis_coordinates: Vec<Vec<f64>>,
563    values: Vec<Complex64>,
564    interpolation: &str,
565    boundary_mode: &str,
566) -> PyResult<PyExpression> {
567    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
568    Ok(PyExpression(LookupTable::new(
569        py_tags(tags)?,
570        variables,
571        axis_coordinates,
572        values,
573        interpolation.parse()?,
574        boundary_mode.parse()?,
575    )?))
576}
577
578/// Construct a scalar-parameter lookup table amplitude.
579#[pyfunction(name = "LookupTableScalar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
580pub fn py_lookup_table_scalar(
581    tags: &Bound<'_, PyTuple>,
582    variables: Vec<PyVariable>,
583    axis_coordinates: Vec<Vec<f64>>,
584    values: Vec<PyParameter>,
585    interpolation: &str,
586    boundary_mode: &str,
587) -> PyResult<PyExpression> {
588    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
589    Ok(PyExpression(LookupTable::new_scalar(
590        py_tags(tags)?,
591        variables,
592        axis_coordinates,
593        values.into_iter().map(|value| value.0).collect(),
594        interpolation.parse()?,
595        boundary_mode.parse()?,
596    )?))
597}
598
599/// Construct a cartesian-complex lookup table amplitude.
600#[pyfunction(name = "LookupTableComplex", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
601pub fn py_lookup_table_complex(
602    tags: &Bound<'_, PyTuple>,
603    variables: Vec<PyVariable>,
604    axis_coordinates: Vec<Vec<f64>>,
605    values: Vec<(PyParameter, PyParameter)>,
606    interpolation: &str,
607    boundary_mode: &str,
608) -> PyResult<PyExpression> {
609    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
610    Ok(PyExpression(LookupTable::new_cartesian_complex(
611        py_tags(tags)?,
612        variables,
613        axis_coordinates,
614        values
615            .into_iter()
616            .map(|(value_re, value_im)| (value_re.0, value_im.0))
617            .collect(),
618        interpolation.parse()?,
619        boundary_mode.parse()?,
620    )?))
621}
622
623/// Construct a polar-complex lookup table amplitude.
624#[pyfunction(name = "LookupTablePolar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
625pub fn py_lookup_table_polar(
626    tags: &Bound<'_, PyTuple>,
627    variables: Vec<PyVariable>,
628    axis_coordinates: Vec<Vec<f64>>,
629    values: Vec<(PyParameter, PyParameter)>,
630    interpolation: &str,
631    boundary_mode: &str,
632) -> PyResult<PyExpression> {
633    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
634    Ok(PyExpression(LookupTable::new_polar_complex(
635        py_tags(tags)?,
636        variables,
637        axis_coordinates,
638        values
639            .into_iter()
640            .map(|(value_r, value_theta)| (value_r.0, value_theta.0))
641            .collect(),
642        interpolation.parse()?,
643        boundary_mode.parse()?,
644    )?))
645}
646
647/// Construct the fixed Kopf `a0` K-matrix amplitude.
648#[pyfunction(name = "KopfKMatrixA0", signature = (*tags, couplings, channel, mass, seed = None))]
649pub fn py_kopf_kmatrix_a0(
650    tags: &Bound<'_, PyTuple>,
651    couplings: [[PyParameter; 2]; 2],
652    channel: PyKopfKMatrixA0Channel,
653    mass: PyMass,
654    seed: Option<usize>,
655) -> PyResult<PyExpression> {
656    Ok(PyExpression(KopfKMatrixA0::new(
657        py_tags(tags)?,
658        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
659        channel.into(),
660        &mass.0,
661        seed,
662    )?))
663}
664
665/// Construct the fixed Kopf `a2` K-matrix amplitude.
666#[pyfunction(name = "KopfKMatrixA2", signature = (*tags, couplings, channel, mass, seed = None))]
667pub fn py_kopf_kmatrix_a2(
668    tags: &Bound<'_, PyTuple>,
669    couplings: [[PyParameter; 2]; 2],
670    channel: PyKopfKMatrixA2Channel,
671    mass: PyMass,
672    seed: Option<usize>,
673) -> PyResult<PyExpression> {
674    Ok(PyExpression(KopfKMatrixA2::new(
675        py_tags(tags)?,
676        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
677        channel.into(),
678        &mass.0,
679        seed,
680    )?))
681}
682
683/// Construct the fixed Kopf `f0` K-matrix amplitude.
684#[pyfunction(name = "KopfKMatrixF0", signature = (*tags, couplings, channel, mass, seed = None))]
685pub fn py_kopf_kmatrix_f0(
686    tags: &Bound<'_, PyTuple>,
687    couplings: [[PyParameter; 2]; 5],
688    channel: PyKopfKMatrixF0Channel,
689    mass: PyMass,
690    seed: Option<usize>,
691) -> PyResult<PyExpression> {
692    Ok(PyExpression(KopfKMatrixF0::new(
693        py_tags(tags)?,
694        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
695        channel.into(),
696        &mass.0,
697        seed,
698    )?))
699}
700
701/// Construct the fixed Kopf `f2` K-matrix amplitude.
702#[pyfunction(name = "KopfKMatrixF2", signature = (*tags, couplings, channel, mass, seed = None))]
703pub fn py_kopf_kmatrix_f2(
704    tags: &Bound<'_, PyTuple>,
705    couplings: [[PyParameter; 2]; 4],
706    channel: PyKopfKMatrixF2Channel,
707    mass: PyMass,
708    seed: Option<usize>,
709) -> PyResult<PyExpression> {
710    Ok(PyExpression(KopfKMatrixF2::new(
711        py_tags(tags)?,
712        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
713        channel.into(),
714        &mass.0,
715        seed,
716    )?))
717}
718
719/// Construct the fixed Kopf `pi1` K-matrix amplitude.
720#[pyfunction(name = "KopfKMatrixPi1", signature = (*tags, couplings, channel, mass))]
721pub fn py_kopf_kmatrix_pi1(
722    tags: &Bound<'_, PyTuple>,
723    couplings: [[PyParameter; 2]; 1],
724    channel: PyKopfKMatrixPi1Channel,
725    mass: PyMass,
726) -> PyResult<PyExpression> {
727    Ok(PyExpression(KopfKMatrixPi1::new(
728        py_tags(tags)?,
729        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
730        channel.into(),
731        &mass.0,
732    )?))
733}
734
735/// Construct the fixed Kopf `rho` K-matrix amplitude.
736#[pyfunction(name = "KopfKMatrixRho", signature = (*tags, couplings, channel, mass))]
737pub fn py_kopf_kmatrix_rho(
738    tags: &Bound<'_, PyTuple>,
739    couplings: [[PyParameter; 2]; 2],
740    channel: PyKopfKMatrixRhoChannel,
741    mass: PyMass,
742) -> PyResult<PyExpression> {
743    Ok(PyExpression(KopfKMatrixRho::new(
744        py_tags(tags)?,
745        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
746        channel.into(),
747        &mass.0,
748    )?))
749}
750
751#[pymethods]
752impl PyExpression {
753    /// The parameters used by the Expression.
754    ///
755    /// Returns
756    /// -------
757    /// parameters : ParameterMap
758    ///     The parameter map.
759    #[getter]
760    fn parameters(&self) -> PyParameterMap {
761        PyParameterMap(self.0.parameters())
762    }
763    /// Number of free parameters
764    #[getter]
765    fn n_free(&self) -> usize {
766        self.0.n_free()
767    }
768    /// Number of fixed parameters
769    #[getter]
770    fn n_fixed(&self) -> usize {
771        self.0.n_fixed()
772    }
773    /// Total number of parameters
774    #[getter]
775    fn n_parameters(&self) -> usize {
776        self.0.n_parameters()
777    }
778    /// Load an Expression by precalculating each term over the given Dataset
779    ///
780    /// Parameters
781    /// ----------
782    /// dataset : Dataset
783    ///     The Dataset to use in precalculation
784    ///
785    /// Returns
786    /// -------
787    /// Evaluator
788    ///     An object that can be used to evaluate the `expression` over each event in the
789    ///     `dataset`
790    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
791        Ok(PyEvaluator(self.0.load(&dataset.0)?))
792    }
793    /// The real part of a complex Expression
794    fn real(&self) -> PyExpression {
795        PyExpression(self.0.real())
796    }
797    /// The imaginary part of a complex Expression
798    fn imag(&self) -> PyExpression {
799        PyExpression(self.0.imag())
800    }
801    /// The complex conjugate of a complex Expression
802    fn conj(&self) -> PyExpression {
803        PyExpression(self.0.conj())
804    }
805    /// The norm-squared of a complex Expression
806    fn norm_sqr(&self) -> PyExpression {
807        PyExpression(self.0.norm_sqr())
808    }
809    /// The square root of an Expression
810    fn sqrt(&self) -> PyExpression {
811        PyExpression(self.0.sqrt())
812    }
813    /// Raise an Expression to an int, float, or Expression power
814    fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
815        if let Ok(value) = power.extract::<i32>() {
816            Ok(PyExpression(self.0.powi(value)))
817        } else if let Ok(value) = power.extract::<f64>() {
818            Ok(PyExpression(self.0.powf(value)))
819        } else if let Ok(expression) = power.extract::<PyExpression>() {
820            Ok(PyExpression(self.0.pow(&expression.0)))
821        } else {
822            Err(PyTypeError::new_err(
823                "power must be an int, float, or Expression",
824            ))
825        }
826    }
827    /// The exponential of an Expression
828    fn exp(&self) -> PyExpression {
829        PyExpression(self.0.exp())
830    }
831    /// The sine of an Expression
832    fn sin(&self) -> PyExpression {
833        PyExpression(self.0.sin())
834    }
835    /// The cosine of an Expression
836    fn cos(&self) -> PyExpression {
837        PyExpression(self.0.cos())
838    }
839    /// The natural logarithm of an Expression
840    fn log(&self) -> PyExpression {
841        PyExpression(self.0.log())
842    }
843    /// The complex phase factor exp(i * expression)
844    fn cis(&self) -> PyExpression {
845        PyExpression(self.0.cis())
846    }
847    /// Fix a parameter used by this Expression.
848    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
849        Ok(self.0.fix_parameter(name, value)?)
850    }
851    /// Mark a parameter used by this Expression as free.
852    fn free_parameter(&self, name: &str) -> PyResult<()> {
853        Ok(self.0.free_parameter(name)?)
854    }
855    /// Rename a single parameter used by this Expression.
856    fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
857        Ok(self.0.rename_parameter(old, new)?)
858    }
859    /// Rename several parameters used by this Expression.
860    fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
861        Ok(self.0.rename_parameters(&mapping)?)
862    }
863    /// Return a tree-like diagnostic view of the compiled Expression.
864    #[getter]
865    fn compiled_expression(&self) -> PyCompiledExpression {
866        PyCompiledExpression(self.0.compiled_expression())
867    }
868    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
869        if let Ok(other_expr) = other.extract::<PyExpression>() {
870            Ok(PyExpression(self.0.clone() + other_expr.0))
871        } else {
872            Err(PyTypeError::new_err("Unsupported operand type for +"))
873        }
874    }
875    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
876        if let Ok(other_expr) = other.extract::<PyExpression>() {
877            Ok(PyExpression(other_expr.0 + self.0.clone()))
878        } else {
879            Err(PyTypeError::new_err("Unsupported operand type for +"))
880        }
881    }
882    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
883        if let Ok(other_expr) = other.extract::<PyExpression>() {
884            Ok(PyExpression(self.0.clone() - other_expr.0))
885        } else {
886            Err(PyTypeError::new_err("Unsupported operand type for -"))
887        }
888    }
889    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
890        if let Ok(other_expr) = other.extract::<PyExpression>() {
891            Ok(PyExpression(other_expr.0 - self.0.clone()))
892        } else {
893            Err(PyTypeError::new_err("Unsupported operand type for -"))
894        }
895    }
896    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
897        if let Ok(other_expr) = other.extract::<PyExpression>() {
898            Ok(PyExpression(self.0.clone() * other_expr.0))
899        } else {
900            Err(PyTypeError::new_err("Unsupported operand type for *"))
901        }
902    }
903    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
904        if let Ok(other_expr) = other.extract::<PyExpression>() {
905            Ok(PyExpression(other_expr.0 * self.0.clone()))
906        } else {
907            Err(PyTypeError::new_err("Unsupported operand type for *"))
908        }
909    }
910    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
911        if let Ok(other_expr) = other.extract::<PyExpression>() {
912            Ok(PyExpression(self.0.clone() / other_expr.0))
913        } else {
914            Err(PyTypeError::new_err("Unsupported operand type for /"))
915        }
916    }
917    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
918        if let Ok(other_expr) = other.extract::<PyExpression>() {
919            Ok(PyExpression(other_expr.0 / self.0.clone()))
920        } else {
921            Err(PyTypeError::new_err("Unsupported operand type for /"))
922        }
923    }
924    fn __neg__(&self) -> PyExpression {
925        PyExpression(-self.0.clone())
926    }
927    fn __str__(&self) -> String {
928        format!("{}", self.0)
929    }
930    fn __repr__(&self) -> String {
931        format!("{:?}", self.0)
932    }
933
934    #[new]
935    fn new() -> Self {
936        Self(Expression::default())
937    }
938    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
939        Ok(PyBytes::new(
940            py,
941            serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
942                .map_err(LadduError::PickleError)?
943                .as_slice(),
944        ))
945    }
946    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
947        *self = Self(
948            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
949                .map_err(LadduError::PickleError)?,
950        );
951        Ok(())
952    }
953}
954
955/// A class which can be used to evaluate a stored Expression
956///
957/// See Also
958/// --------
959/// laddu.Expression.load
960///
961#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
962#[derive(Clone)]
963pub struct PyEvaluator(pub Evaluator);
964
965#[pymethods]
966impl PyEvaluator {
967    /// The free parameters used by the Evaluator
968    ///
969    /// Returns
970    /// -------
971    /// parameters : ParameterMap
972    ///     The parameter map.
973    ///
974    #[getter]
975    fn parameters(&self) -> PyParameterMap {
976        PyParameterMap(self.0.parameters())
977    }
978    /// Number of free parameters
979    #[getter]
980    fn n_free(&self) -> usize {
981        self.0.n_free()
982    }
983    /// Number of fixed parameters
984    #[getter]
985    fn n_fixed(&self) -> usize {
986        self.0.n_fixed()
987    }
988    /// Total number of parameters
989    #[getter]
990    fn n_parameters(&self) -> usize {
991        self.0.n_parameters()
992    }
993    /// Fix a parameter used by this Evaluator.
994    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
995        Ok(self.0.fix_parameter(name, value)?)
996    }
997    /// Mark a parameter used by this Evaluator as free.
998    fn free_parameter(&self, name: &str) -> PyResult<()> {
999        Ok(self.0.free_parameter(name)?)
1000    }
1001    /// Rename a single parameter used by this Evaluator.
1002    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
1003        Ok(self.0.rename_parameter(old, new)?)
1004    }
1005    /// Rename several parameters used by this Evaluator.
1006    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
1007        Ok(self.0.rename_parameters(&mapping)?)
1008    }
1009    /// Activates Amplitude use-sites in the Expression by tag or glob selector.
1010    ///
1011    /// Parameters
1012    /// ----------
1013    /// arg : str or list of str
1014    ///     Tags or ``*``/``?`` glob selectors of Amplitudes to be activated
1015    ///
1016    /// Raises
1017    /// ------
1018    /// TypeError
1019    ///     If `arg` is not a str or list of str
1020    /// ValueError
1021    ///     If `arg` or any items of `arg` match no tagged Amplitudes
1022    /// strict : bool, default=True
1023    ///     When ``True``, raise an error if any selector matches no amplitudes. When
1024    ///     ``False``, silently skip selectors with no matches.
1025    #[pyo3(signature = (arg, *, strict=true))]
1026    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1027        if let Ok(string_arg) = arg.extract::<String>() {
1028            if strict {
1029                self.0.activate_strict(&string_arg)?;
1030            } else {
1031                self.0.activate(&string_arg);
1032            }
1033        } else if let Ok(list_arg) = arg.cast::<PyList>() {
1034            let vec: Vec<String> = list_arg.extract()?;
1035            if strict {
1036                self.0.activate_many_strict(&vec)?;
1037            } else {
1038                self.0.activate_many(&vec);
1039            }
1040        } else {
1041            return Err(PyTypeError::new_err(
1042                "Argument must be either a string or a list of strings",
1043            ));
1044        }
1045        Ok(())
1046    }
1047    /// Activates all Amplitude use-sites in the Expression.
1048    ///
1049    fn activate_all(&self) {
1050        self.0.activate_all();
1051    }
1052    /// Deactivates Amplitude use-sites in the Expression by tag or glob selector.
1053    ///
1054    /// Deactivated Amplitudes act as zeros in the Expression
1055    ///
1056    /// Parameters
1057    /// ----------
1058    /// arg : str or list of str
1059    ///     Tags or ``*``/``?`` glob selectors of Amplitudes to be deactivated
1060    ///
1061    /// Raises
1062    /// ------
1063    /// TypeError
1064    ///     If `arg` is not a str or list of str
1065    /// ValueError
1066    ///     If `arg` or any items of `arg` match no tagged Amplitudes
1067    /// strict : bool, default=True
1068    ///     When ``True``, raise an error if any selector matches no amplitudes. When
1069    ///     ``False``, silently skip selectors with no matches.
1070    #[pyo3(signature = (arg, *, strict=true))]
1071    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1072        if let Ok(string_arg) = arg.extract::<String>() {
1073            if strict {
1074                self.0.deactivate_strict(&string_arg)?;
1075            } else {
1076                self.0.deactivate(&string_arg);
1077            }
1078        } else if let Ok(list_arg) = arg.cast::<PyList>() {
1079            let vec: Vec<String> = list_arg.extract()?;
1080            if strict {
1081                self.0.deactivate_many_strict(&vec)?;
1082            } else {
1083                self.0.deactivate_many(&vec);
1084            }
1085        } else {
1086            return Err(PyTypeError::new_err(
1087                "Argument must be either a string or a list of strings",
1088            ));
1089        }
1090        Ok(())
1091    }
1092    /// Deactivates all tagged Amplitude use-sites in the Expression.
1093    ///
1094    fn deactivate_all(&self) {
1095        self.0.deactivate_all();
1096    }
1097    /// Isolates Amplitude use-sites in the Expression by tag or glob selector.
1098    ///
1099    /// Activates the tagged Amplitudes given in `arg` and deactivates the rest.
1100    ///
1101    /// Parameters
1102    /// ----------
1103    /// arg : str or list of str
1104    ///     Tags or ``*``/``?`` glob selectors of Amplitudes to be isolated
1105    ///
1106    /// Raises
1107    /// ------
1108    /// TypeError
1109    ///     If `arg` is not a str or list of str
1110    /// ValueError
1111    ///     If `arg` or any items of `arg` match no tagged Amplitudes
1112    /// strict : bool, default=True
1113    ///     When ``True``, raise an error if any selector matches no amplitudes. When
1114    ///     ``False``, silently skip selectors with no matches.
1115    #[pyo3(signature = (arg, *, strict=true))]
1116    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1117        if let Ok(string_arg) = arg.extract::<String>() {
1118            if strict {
1119                self.0.isolate_strict(&string_arg)?;
1120            } else {
1121                self.0.isolate(&string_arg);
1122            }
1123        } else if let Ok(list_arg) = arg.cast::<PyList>() {
1124            let vec: Vec<String> = list_arg.extract()?;
1125            if strict {
1126                self.0.isolate_many_strict(&vec)?;
1127            } else {
1128                self.0.isolate_many(&vec);
1129            }
1130        } else {
1131            return Err(PyTypeError::new_err(
1132                "Argument must be either a string or a list of strings",
1133            ));
1134        }
1135        Ok(())
1136    }
1137
1138    /// Return the current active-amplitude mask.
1139    #[getter]
1140    fn active_mask(&self) -> Vec<bool> {
1141        self.0.active_mask()
1142    }
1143
1144    /// Apply an active-amplitude mask.
1145    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
1146        self.0.set_active_mask(&mask)?;
1147        Ok(())
1148    }
1149
1150    /// Return a tree-like diagnostic view of the compiled Expression.
1151    #[getter]
1152    fn compiled_expression(&self) -> PyCompiledExpression {
1153        PyCompiledExpression(self.0.compiled_expression())
1154    }
1155
1156    /// Return the Expression represented by this Evaluator.
1157    #[getter]
1158    fn expression(&self) -> PyExpression {
1159        PyExpression(self.0.expression())
1160    }
1161
1162    /// Evaluate the stored Expression over the stored Dataset
1163    ///
1164    /// Parameters
1165    /// ----------
1166    /// parameters : list of float
1167    ///     The values to use for the free parameters
1168    /// threads : int, optional
1169    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1170    ///     global or context-managed default; any positive value overrides that default for
1171    ///     this call only)
1172    ///
1173    /// Returns
1174    /// -------
1175    /// result : array_like
1176    ///     A ``numpy`` array of complex values for each Event in the Dataset
1177    ///
1178    /// Raises
1179    /// ------
1180    /// Exception
1181    ///     If there was an error building the thread pool
1182    ///
1183    #[pyo3(signature = (parameters, *, threads=None))]
1184    fn evaluate<'py>(
1185        &self,
1186        py: Python<'py>,
1187        parameters: Vec<f64>,
1188        threads: Option<usize>,
1189    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1190        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
1191        Ok(PyArray1::from_slice(py, &values?))
1192    }
1193    /// Evaluate the stored Expression over a subset of the stored Dataset
1194    ///
1195    /// Parameters
1196    /// ----------
1197    /// parameters : list of float
1198    ///     The values to use for the free parameters
1199    /// indices : list of int
1200    ///     The indices of events to evaluate
1201    /// threads : int, optional
1202    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1203    ///     global or context-managed default; any positive value overrides that default for
1204    ///     this call only)
1205    ///
1206    /// Returns
1207    /// -------
1208    /// result : array_like
1209    ///     A ``numpy`` array of complex values for each indexed Event in the Dataset
1210    ///
1211    /// Raises
1212    /// ------
1213    /// Exception
1214    ///     If there was an error building the thread pool
1215    ///
1216    #[pyo3(signature = (parameters, indices, *, threads=None))]
1217    fn evaluate_batch<'py>(
1218        &self,
1219        py: Python<'py>,
1220        parameters: Vec<f64>,
1221        indices: Vec<usize>,
1222        threads: Option<usize>,
1223    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1224        let values =
1225            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
1226        Ok(PyArray1::from_slice(py, &values?))
1227    }
1228    /// Evaluate the gradient of the stored Expression over the stored Dataset
1229    ///
1230    /// Parameters
1231    /// ----------
1232    /// parameters : list of float
1233    ///     The values to use for the free parameters
1234    /// threads : int, optional
1235    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1236    ///     global or context-managed default; any positive value overrides that default for
1237    ///     this call only)
1238    ///
1239    /// Returns
1240    /// -------
1241    /// result : array_like
1242    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
1243    ///
1244    /// Raises
1245    /// ------
1246    /// Exception
1247    ///     If there was an error building the thread pool or problem creating the resulting
1248    ///     ``numpy`` array
1249    ///
1250    #[pyo3(signature = (parameters, *, threads=None))]
1251    fn evaluate_gradient<'py>(
1252        &self,
1253        py: Python<'py>,
1254        parameters: Vec<f64>,
1255        threads: Option<usize>,
1256    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1257        let gradients: LadduResult<_> = install_with_threads(threads, || {
1258            Ok(self
1259                .0
1260                .evaluate_gradient(&parameters)?
1261                .iter()
1262                .map(|grad| grad.data.as_vec().to_vec())
1263                .collect::<Vec<Vec<Complex64>>>())
1264        })?;
1265        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1266    }
1267    /// Evaluate the gradient of the stored Expression over a subset of the stored Dataset
1268    ///
1269    /// Parameters
1270    /// ----------
1271    /// parameters : list of float
1272    ///     The values to use for the free parameters
1273    /// indices : list of int
1274    ///     The indices of events to evaluate
1275    /// threads : int, optional
1276    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1277    ///     global or context-managed default; any positive value overrides that default for
1278    ///     this call only)
1279    ///
1280    /// Returns
1281    /// -------
1282    /// result : array_like
1283    ///     A ``numpy`` 2D array of complex values for each indexed Event in the Dataset
1284    ///
1285    /// Raises
1286    /// ------
1287    /// Exception
1288    ///     If there was an error building the thread pool or problem creating the resulting
1289    ///     ``numpy`` array
1290    ///
1291    #[pyo3(signature = (parameters, indices, *, threads=None))]
1292    fn evaluate_gradient_batch<'py>(
1293        &self,
1294        py: Python<'py>,
1295        parameters: Vec<f64>,
1296        indices: Vec<usize>,
1297        threads: Option<usize>,
1298    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1299        let gradients: LadduResult<_> = install_with_threads(threads, || {
1300            Ok(self
1301                .0
1302                .evaluate_gradient_batch(&parameters, &indices)?
1303                .iter()
1304                .map(|grad| grad.data.as_vec().to_vec())
1305                .collect::<Vec<Vec<Complex64>>>())
1306        })?;
1307        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1308    }
1309}
1310
1311/// A class which can be used to display the compiled form of an Expression
1312///
1313/// Notes
1314/// -----
1315/// This should not be used for anything other than diagnostic purposes.
1316///
1317#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
1318#[derive(Clone)]
1319pub struct PyCompiledExpression(pub CompiledExpression);
1320
1321#[pymethods]
1322impl PyCompiledExpression {
1323    fn __str__(&self) -> String {
1324        format!("{}", self.0)
1325    }
1326    fn __repr__(&self) -> String {
1327        format!("{:?}", self.0)
1328    }
1329}
1330
1331#[pyclass(name = "ParameterMap", module = "laddu", from_py_object)]
1332#[derive(Clone)]
1333pub struct PyParameterMap(pub ParameterMap);
1334
1335#[pymethods]
1336impl PyParameterMap {
1337    #[getter]
1338    fn names<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
1339        PyTuple::new(py, self.0.names())
1340    }
1341
1342    #[getter]
1343    fn free(&self) -> PyParameterMap {
1344        PyParameterMap(self.0.free())
1345    }
1346
1347    #[getter]
1348    fn fixed(&self) -> PyParameterMap {
1349        PyParameterMap(self.0.fixed())
1350    }
1351
1352    fn contains(&self, name: &str) -> bool {
1353        self.0.contains_key(name)
1354    }
1355
1356    fn __contains__(&self, name: &str) -> bool {
1357        self.contains(name)
1358    }
1359
1360    fn __len__(&self) -> usize {
1361        self.0.len()
1362    }
1363
1364    fn __bool__(&self) -> bool {
1365        !self.0.is_empty()
1366    }
1367
1368    fn __getitem__(&self, index: &Bound<'_, PyAny>) -> PyResult<PyParameter> {
1369        if let Ok(name) = index.extract::<String>() {
1370            return self
1371                .0
1372                .get(&name)
1373                .cloned()
1374                .map(PyParameter)
1375                .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err(name));
1376        }
1377        if let Ok(index) = index.extract::<isize>() {
1378            let len = self.0.len() as isize;
1379            let normalized = if index < 0 { len + index } else { index };
1380            if normalized < 0 || normalized >= len {
1381                return Err(pyo3::exceptions::PyIndexError::new_err(format!(
1382                    "parameter index out of range: {index}"
1383                )));
1384            }
1385            return Ok(PyParameter(self.0[normalized as usize].clone()));
1386        }
1387        Err(pyo3::exceptions::PyTypeError::new_err(
1388            "ParameterMap indices must be str or int",
1389        ))
1390    }
1391
1392    fn __str__(&self) -> String {
1393        self.0.to_string()
1394    }
1395
1396    fn __repr__(&self) -> String {
1397        format!("{:?}", self.0)
1398    }
1399
1400    fn index(&self, name: &str) -> PyResult<usize> {
1401        self.0
1402            .index(name)
1403            .ok_or_else(|| PyValueError::new_err(format!("parameter not found: {name}")))
1404    }
1405
1406    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
1407        let params: Vec<PyParameter> = self.0.clone().into_iter().map(PyParameter).collect();
1408        let list = PyList::new(py, params)?;
1409        PyIterator::from_object(&list)
1410    }
1411}
1412
1413#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
1414#[derive(Clone)]
1415pub struct PyParameter(pub Parameter);
1416
1417#[pymethods]
1418impl PyParameter {
1419    #[getter]
1420    fn name(&self) -> String {
1421        self.0.name()
1422    }
1423    #[getter]
1424    fn fixed(&self) -> Option<f64> {
1425        self.0.fixed()
1426    }
1427    #[getter]
1428    fn initial(&self) -> Option<f64> {
1429        self.0.initial()
1430    }
1431    #[getter]
1432    fn bounds(&self) -> (Option<f64>, Option<f64>) {
1433        self.0.bounds()
1434    }
1435    #[getter]
1436    fn unit(&self) -> Option<String> {
1437        self.0.unit()
1438    }
1439    #[getter]
1440    fn latex(&self) -> Option<String> {
1441        self.0.latex()
1442    }
1443    #[getter]
1444    fn description(&self) -> Option<String> {
1445        self.0.description()
1446    }
1447}
1448
1449/// A free parameter which floats during an optimization
1450///
1451/// Parameters
1452/// ----------
1453/// name : str
1454///     The name of the free parameter
1455/// fixed : float, optional
1456///     If specified, the parameter will be fixed to this value
1457/// initial : float, optional
1458///     If specified, the parameter will always be initialized to this value
1459/// bounds : tuple of (float or None, float or None)
1460///     Specify the lower and upper bounds for the parameter (None corresponds to no bound)
1461/// unit : str, optional
1462///     Optional unit string which may be used to annotate the parameter
1463/// latex : str, optional
1464///     Optional LaTeX representation of the parameter
1465/// description : str, optional
1466///     Optional description of the parameter
1467///
1468/// Returns
1469/// -------
1470/// laddu.Parameter
1471///     An object that can be used as the input for many Amplitude constructors
1472///
1473/// Notes
1474/// -----
1475/// Two free parameters with the same name are shared in a fit.
1476///
1477/// Attempting to set both the fixed and initial value will result in an overwrite (both will be
1478/// set to the "fixed" value).
1479///
1480#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
1481pub fn py_parameter(
1482    name: &str,
1483    fixed: Option<f64>,
1484    initial: Option<f64>,
1485    bounds: (Option<f64>, Option<f64>),
1486    unit: Option<&str>,
1487    latex: Option<&str>,
1488    description: Option<&str>,
1489) -> PyParameter {
1490    let par = Parameter::new(name);
1491    if let Some(value) = initial {
1492        par.set_initial(value);
1493    }
1494    if let Some(value) = fixed {
1495        par.set_fixed_value(Some(value)); // TODO: make this all consistent
1496    }
1497    par.set_bounds(bounds.0, bounds.1);
1498    if let Some(unit) = unit {
1499        par.set_unit(unit);
1500    }
1501    if let Some(latex) = latex {
1502        par.set_latex(latex);
1503    }
1504    if let Some(description) = description {
1505        par.set_description(description);
1506    }
1507    PyParameter(par)
1508}
1509
1510/// An amplitude used only for internal testing which evaluates `(p0 + i * p1) * event.p4s\[0\].e`.
1511#[pyfunction(name = "TestAmplitude", signature = (*tags, re, im))]
1512pub fn py_test_amplitude(
1513    tags: &Bound<'_, PyTuple>,
1514    re: PyParameter,
1515    im: PyParameter,
1516) -> PyResult<PyExpression> {
1517    Ok(PyExpression(TestAmplitude::new(
1518        py_tags(tags)?,
1519        re.0,
1520        im.0,
1521    )?))
1522}