Skip to main content

laddu_python/
amplitudes.rs

1use std::{array, collections::HashMap};
2
3use laddu_amplitudes::{
4    angular::{
5        BlattWeisskopf, ClebschGordan, PhotonHelicity, PhotonPolarization, PhotonSDME, PolPhase,
6        Wigner3j, WignerD, Ylm, Zlm,
7    },
8    kmatrix::{
9        KopfKMatrixA0, KopfKMatrixA0Channel, KopfKMatrixA2, KopfKMatrixA2Channel, KopfKMatrixF0,
10        KopfKMatrixF0Channel, KopfKMatrixF2, KopfKMatrixF2Channel, KopfKMatrixPi1,
11        KopfKMatrixPi1Channel, KopfKMatrixRho, KopfKMatrixRhoChannel,
12    },
13    lookup::{LookupAxis, LookupTable},
14    resonance::{BreitWigner, BreitWignerNonRelativistic, Flatte, PhaseSpaceFactor, Voigt},
15    scalar::{ComplexScalar, PolarComplexScalar, Scalar, VariableScalar},
16};
17use laddu_core::{
18    amplitude::{Evaluator, Expression, Parameter, ParameterMap, TestAmplitude},
19    math::{BarrierKind, Sheet, QR_DEFAULT},
20    traits::Variable,
21    CompiledExpression, LadduError, LadduResult, ThreadPoolManager,
22};
23use num::complex::Complex64;
24use numpy::{PyArray1, PyArray2};
25use pyo3::{
26    exceptions::{PyTypeError, PyValueError},
27    prelude::*,
28    types::{PyAny, PyBytes, PyIterator, PyList, PyTuple},
29};
30
31use crate::{
32    data::PyDataset,
33    quantum::angular_momentum::{
34        parse_angular_momentum, parse_orbital_angular_momentum, parse_projection,
35    },
36    variables::{PyAngles, PyDecay, PyMandelstam, PyMass, PyPolarization, PyVariable},
37};
38
/// Boxed event variables paired with the lookup axes built from the per-axis coordinate
/// lists; this is the value returned by `py_lookup_inputs`.
type LookupInputs = (Vec<Box<dyn Variable>>, Vec<LookupAxis>);
40
/// Generate a Python-facing mirror of one of the Kopf K-matrix channel enums.
///
/// Invocation shape: `py_kmatrix_channel!(RustEnumName, "PythonClassName", CoreEnum { A, B, ... })`.
/// The macro emits
/// * a `#[pyclass]` enum named by the first argument, exposed to Python (module `laddu`)
///   under the class name given by the second argument, with one variant per listed name;
/// * a `From` conversion into the core enum that maps each variant onto the identically
///   named core variant.
macro_rules! py_kmatrix_channel {
    ($py_enum:ident, $class_name:literal, $core_enum:path { $($channel:ident),+ $(,)? }) => {
        #[pyclass(eq, name = $class_name, module = "laddu", from_py_object)]
        #[derive(Clone, PartialEq)]
        pub enum $py_enum {
            $($channel),+
        }

        impl From<$py_enum> for $core_enum {
            fn from(channel: $py_enum) -> Self {
                // One arm per listed variant; the names are shared by both enums.
                match channel {
                    $($py_enum::$channel => Self::$channel,)+
                }
            }
        }
    };
}
58
59py_kmatrix_channel!(
60    PyKopfKMatrixA0Channel,
61    "KopfKMatrixA0Channel",
62    KopfKMatrixA0Channel { PiEta, KKbar }
63);
64py_kmatrix_channel!(
65    PyKopfKMatrixA2Channel,
66    "KopfKMatrixA2Channel",
67    KopfKMatrixA2Channel {
68        PiEta,
69        KKbar,
70        PiEtaPrime
71    }
72);
73py_kmatrix_channel!(
74    PyKopfKMatrixF0Channel,
75    "KopfKMatrixF0Channel",
76    KopfKMatrixF0Channel {
77        PiPi,
78        FourPi,
79        KKbar,
80        EtaEta,
81        EtaEtaPrime
82    }
83);
84py_kmatrix_channel!(
85    PyKopfKMatrixF2Channel,
86    "KopfKMatrixF2Channel",
87    KopfKMatrixF2Channel {
88        PiPi,
89        FourPi,
90        KKbar,
91        EtaEta
92    }
93);
94py_kmatrix_channel!(
95    PyKopfKMatrixPi1Channel,
96    "KopfKMatrixPi1Channel",
97    KopfKMatrixPi1Channel { PiEta, PiEtaPrime }
98);
99py_kmatrix_channel!(
100    PyKopfKMatrixRhoChannel,
101    "KopfKMatrixRhoChannel",
102    KopfKMatrixRhoChannel {
103        PiPi,
104        FourPi,
105        KKbar
106    }
107);
108
109fn install_with_threads<R: Send>(
110    threads: Option<usize>,
111    op: impl FnOnce() -> R + Send,
112) -> LadduResult<R> {
113    ThreadPoolManager::shared().install(threads, op)
114}
115
116pub(crate) fn py_tags(tags: &Bound<'_, PyTuple>) -> PyResult<Vec<String>> {
117    tags.iter()
118        .map(|tag| tag.extract::<String>())
119        .collect::<PyResult<Vec<_>>>()
120}
121
122/// A mathematical expression formed from amplitudes.
123///
124#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
125#[derive(Clone)]
126pub struct PyExpression(pub Expression);
127
128impl<'py> FromPyObject<'_, 'py> for PyExpression {
129    type Error = PyErr;
130
131    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
132        if let Ok(obj) = obj.cast::<PyExpression>() {
133            Ok(obj.borrow().clone())
134        } else if let Ok(obj) = obj.extract::<f64>() {
135            Ok(Self(obj.into()))
136        } else if let Ok(obj) = obj.extract::<Complex64>() {
137            Ok(Self(obj.into()))
138        } else {
139            Err(PyTypeError::new_err("Failed to extract Expression"))
140        }
141    }
142}
143
144/// A convenience method to sum sequences of Expressions
145///
146#[pyfunction(name = "expr_sum")]
147pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
148    if terms.is_empty() {
149        return Ok(PyExpression(Expression::zero()));
150    }
151    if terms.len() == 1 {
152        let term = &terms[0];
153        if let Ok(expression) = term.extract::<PyExpression>() {
154            return Ok(expression);
155        }
156        return Err(PyTypeError::new_err("Item is not a PyExpression"));
157    }
158    let mut iter = terms.iter();
159    let Some(first_term) = iter.next() else {
160        return Ok(PyExpression(Expression::zero()));
161    };
162    let PyExpression(mut summation) = first_term
163        .extract::<PyExpression>()
164        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
165    for term in iter {
166        let PyExpression(expr) = term
167            .extract::<PyExpression>()
168            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
169        summation = summation + expr;
170    }
171    Ok(PyExpression(summation))
172}
173
174/// A convenience method to multiply sequences of Expressions
175///
176#[pyfunction(name = "expr_product")]
177pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
178    if terms.is_empty() {
179        return Ok(PyExpression(Expression::one()));
180    }
181    if terms.len() == 1 {
182        let term = &terms[0];
183        if let Ok(expression) = term.extract::<PyExpression>() {
184            return Ok(expression);
185        }
186        return Err(PyTypeError::new_err("Item is not a PyExpression"));
187    }
188    let mut iter = terms.iter();
189    let Some(first_term) = iter.next() else {
190        return Ok(PyExpression(Expression::one()));
191    };
192    let PyExpression(mut product) = first_term
193        .extract::<PyExpression>()
194        .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
195    for term in iter {
196        let PyExpression(expr) = term
197            .extract::<PyExpression>()
198            .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
199        product = product * expr;
200    }
201    Ok(PyExpression(product))
202}
203
204/// A convenience class representing a zero-valued Expression
205///
206#[pyfunction(name = "Zero")]
207pub fn py_expr_zero() -> PyExpression {
208    PyExpression(Expression::zero())
209}
210
211/// A convenience class representing a unit-valued Expression
212///
213#[pyfunction(name = "One")]
214pub fn py_expr_one() -> PyExpression {
215    PyExpression(Expression::one())
216}
217
218/// Construct a scalar amplitude from one parameter.
219#[pyfunction(name = "Scalar", signature = (*tags, value))]
220pub fn py_scalar(tags: &Bound<'_, PyTuple>, value: PyParameter) -> PyResult<PyExpression> {
221    Ok(PyExpression(Scalar::new(py_tags(tags)?, value.0)?))
222}
223
224/// Construct a real expression from an event variable.
225#[pyfunction(name = "VariableScalar", signature = (*tags, variable))]
226pub fn py_variable_scalar(
227    tags: &Bound<'_, PyTuple>,
228    variable: Bound<'_, PyAny>,
229) -> PyResult<PyExpression> {
230    let variable = variable.extract::<PyVariable>()?;
231    Ok(PyExpression(VariableScalar::new(
232        py_tags(tags)?,
233        &variable,
234    )?))
235}
236
237/// Construct a cartesian complex scalar amplitude.
238#[pyfunction(name = "ComplexScalar", signature = (*tags, re, im))]
239pub fn py_complex_scalar(
240    tags: &Bound<'_, PyTuple>,
241    re: PyParameter,
242    im: PyParameter,
243) -> PyResult<PyExpression> {
244    Ok(PyExpression(ComplexScalar::new(
245        py_tags(tags)?,
246        re.0,
247        im.0,
248    )?))
249}
250
251/// Construct a polar complex scalar amplitude.
252#[pyfunction(name = "PolarComplexScalar", signature = (*tags, r, theta))]
253pub fn py_polar_complex_scalar(
254    tags: &Bound<'_, PyTuple>,
255    r: PyParameter,
256    theta: PyParameter,
257) -> PyResult<PyExpression> {
258    Ok(PyExpression(PolarComplexScalar::new(
259        py_tags(tags)?,
260        r.0,
261        theta.0,
262    )?))
263}
264
265/// Construct a relativistic Breit-Wigner amplitude.
266#[pyfunction(name = "BreitWigner", signature = (*tags, mass, width, l, daughter_1_mass, daughter_2_mass, resonance_mass, barrier_factors=true))]
267#[allow(clippy::too_many_arguments)]
268pub fn py_breit_wigner(
269    tags: &Bound<'_, PyTuple>,
270    mass: PyParameter,
271    width: PyParameter,
272    l: &Bound<'_, PyAny>,
273    daughter_1_mass: &PyMass,
274    daughter_2_mass: &PyMass,
275    resonance_mass: &PyMass,
276    barrier_factors: bool,
277) -> PyResult<PyExpression> {
278    let l = parse_orbital_angular_momentum(l)?.value() as usize;
279    if barrier_factors {
280        Ok(PyExpression(BreitWigner::new(
281            py_tags(tags)?,
282            mass.0,
283            width.0,
284            l,
285            &daughter_1_mass.0,
286            &daughter_2_mass.0,
287            &resonance_mass.0,
288        )?))
289    } else {
290        Ok(PyExpression(BreitWigner::new_without_barrier_factors(
291            py_tags(tags)?,
292            mass.0,
293            width.0,
294            l,
295            &daughter_1_mass.0,
296            &daughter_2_mass.0,
297            &resonance_mass.0,
298        )?))
299    }
300}
301
302/// Construct a non-relativistic Breit-Wigner amplitude.
303#[pyfunction(name = "BreitWignerNonRelativistic", signature = (*tags, mass, width, resonance_mass))]
304pub fn py_breit_wigner_non_relativistic(
305    tags: &Bound<'_, PyTuple>,
306    mass: PyParameter,
307    width: PyParameter,
308    resonance_mass: &PyMass,
309) -> PyResult<PyExpression> {
310    Ok(PyExpression(BreitWignerNonRelativistic::new(
311        py_tags(tags)?,
312        mass.0,
313        width.0,
314        &resonance_mass.0,
315    )?))
316}
317
318/// Construct a Flatte amplitude.
319#[pyfunction(name = "Flatte", signature = (*tags, mass, observed_channel_coupling, alternate_channel_coupling, observed_channel_daughter_masses, alternate_channel_daughter_masses, resonance_mass))]
320pub fn py_flatte(
321    tags: &Bound<'_, PyTuple>,
322    mass: PyParameter,
323    observed_channel_coupling: PyParameter,
324    alternate_channel_coupling: PyParameter,
325    observed_channel_daughter_masses: (PyMass, PyMass),
326    alternate_channel_daughter_masses: (f64, f64),
327    resonance_mass: &PyMass,
328) -> PyResult<PyExpression> {
329    Ok(PyExpression(Flatte::new(
330        py_tags(tags)?,
331        mass.0,
332        observed_channel_coupling.0,
333        alternate_channel_coupling.0,
334        (
335            &observed_channel_daughter_masses.0 .0,
336            &observed_channel_daughter_masses.1 .0,
337        ),
338        alternate_channel_daughter_masses,
339        &resonance_mass.0,
340    )?))
341}
342
343/// Construct a Voigt amplitude.
344#[pyfunction(name = "Voigt", signature = (*tags, mass, width, sigma, resonance_mass))]
345pub fn py_voigt(
346    tags: &Bound<'_, PyTuple>,
347    mass: PyParameter,
348    width: PyParameter,
349    sigma: PyParameter,
350    resonance_mass: &PyMass,
351) -> PyResult<PyExpression> {
352    Ok(PyExpression(Voigt::new(
353        py_tags(tags)?,
354        mass.0,
355        width.0,
356        sigma.0,
357        &resonance_mass.0,
358    )?))
359}
360
361/// Construct a spherical-harmonic amplitude.
362#[pyfunction(name = "Ylm", signature = (*tags, l, m, angles))]
363pub fn py_ylm(
364    tags: &Bound<'_, PyTuple>,
365    l: usize,
366    m: isize,
367    angles: &PyAngles,
368) -> PyResult<PyExpression> {
369    Ok(PyExpression(Ylm::new(py_tags(tags)?, l, m, &angles.0)?))
370}
371
372/// Construct a polarized spherical-harmonic amplitude.
373#[pyfunction(name = "Zlm", signature = (*tags, l, m, r, angles, polarization))]
374pub fn py_zlm(
375    tags: &Bound<'_, PyTuple>,
376    l: usize,
377    m: isize,
378    r: &str,
379    angles: &PyAngles,
380    polarization: &PyPolarization,
381) -> PyResult<PyExpression> {
382    Ok(PyExpression(Zlm::new(
383        py_tags(tags)?,
384        l,
385        m,
386        r.parse()?,
387        &angles.0,
388        &polarization.0,
389    )?))
390}
391
392/// Construct a polarization phase amplitude.
393#[pyfunction(name = "PolPhase", signature = (*tags, polarization))]
394pub fn py_polphase(
395    tags: &Bound<'_, PyTuple>,
396    polarization: &PyPolarization,
397) -> PyResult<PyExpression> {
398    Ok(PyExpression(PolPhase::new(
399        py_tags(tags)?,
400        &polarization.0,
401    )?))
402}
403
404/// Construct a Wigner-D amplitude.
405#[pyfunction(name = "WignerD", signature = (*tags, spin, row_projection, column_projection, angles))]
406pub fn py_wigner_d(
407    tags: &Bound<'_, PyTuple>,
408    spin: &Bound<'_, PyAny>,
409    row_projection: &Bound<'_, PyAny>,
410    column_projection: &Bound<'_, PyAny>,
411    angles: &PyAngles,
412) -> PyResult<PyExpression> {
413    Ok(PyExpression(WignerD::new(
414        py_tags(tags)?,
415        parse_angular_momentum(spin)?,
416        parse_projection(row_projection)?,
417        parse_projection(column_projection)?,
418        &angles.0,
419    )?))
420}
421
422/// Construct a Blatt-Weisskopf amplitude.
423#[pyfunction(name = "BlattWeisskopf", signature = (*tags, decay, l, reference_mass, q_r = QR_DEFAULT, sheet = "physical", kind = "full"))]
424pub fn py_blatt_weisskopf(
425    tags: &Bound<'_, PyTuple>,
426    decay: &PyDecay,
427    l: &Bound<'_, PyAny>,
428    reference_mass: f64,
429    q_r: f64,
430    sheet: &str,
431    kind: &str,
432) -> PyResult<PyExpression> {
433    let sheet = match sheet.to_ascii_lowercase().as_str() {
434        "physical" => Sheet::Physical,
435        "unphysical" => Sheet::Unphysical,
436        _ => {
437            return Err(PyValueError::new_err(
438                "sheet must be 'physical' or 'unphysical'",
439            ));
440        }
441    };
442    let kind = match kind.to_ascii_lowercase().as_str() {
443        "full" => BarrierKind::Full,
444        "tensor" => BarrierKind::Tensor,
445        _ => {
446            return Err(PyValueError::new_err("kind must be 'full' or 'tensor'"));
447        }
448    };
449    Ok(PyExpression(BlattWeisskopf::new(
450        py_tags(tags)?,
451        &decay.0,
452        parse_orbital_angular_momentum(l)?,
453        reference_mass,
454        q_r,
455        sheet,
456        kind,
457    )?))
458}
459
460/// Construct a Clebsch-Gordan constant expression.
461#[pyfunction(name = "ClebschGordan", signature = (*tags, j1, m1, j2, m2, j, m))]
462pub fn py_clebsch_gordan(
463    tags: &Bound<'_, PyTuple>,
464    j1: &Bound<'_, PyAny>,
465    m1: &Bound<'_, PyAny>,
466    j2: &Bound<'_, PyAny>,
467    m2: &Bound<'_, PyAny>,
468    j: &Bound<'_, PyAny>,
469    m: &Bound<'_, PyAny>,
470) -> PyResult<PyExpression> {
471    Ok(PyExpression(ClebschGordan::new(
472        py_tags(tags)?,
473        parse_angular_momentum(j1)?,
474        parse_projection(m1)?,
475        parse_angular_momentum(j2)?,
476        parse_projection(m2)?,
477        parse_angular_momentum(j)?,
478        parse_projection(m)?,
479    )?))
480}
481
482/// Construct a Wigner-3j constant expression.
483#[pyfunction(name = "Wigner3j", signature = (*tags, j1, m1, j2, m2, j3, m3))]
484pub fn py_wigner_3j(
485    tags: &Bound<'_, PyTuple>,
486    j1: &Bound<'_, PyAny>,
487    m1: &Bound<'_, PyAny>,
488    j2: &Bound<'_, PyAny>,
489    m2: &Bound<'_, PyAny>,
490    j3: &Bound<'_, PyAny>,
491    m3: &Bound<'_, PyAny>,
492) -> PyResult<PyExpression> {
493    Ok(PyExpression(Wigner3j::new(
494        py_tags(tags)?,
495        parse_angular_momentum(j1)?,
496        parse_projection(m1)?,
497        parse_angular_momentum(j2)?,
498        parse_projection(m2)?,
499        parse_angular_momentum(j3)?,
500        parse_projection(m3)?,
501    )?))
502}
503
504/// Construct a photon SDME amplitude.
505#[pyfunction(name = "PhotonSDME", signature = (*tags, helicity, helicity_prime, polarization = None))]
506pub fn py_photon_sdme(
507    tags: &Bound<'_, PyTuple>,
508    helicity: i32,
509    helicity_prime: i32,
510    polarization: Option<&PyPolarization>,
511) -> PyResult<PyExpression> {
512    let polarization = polarization
513        .map(|polarization| PhotonPolarization::Linear(Box::new(polarization.0.clone())))
514        .unwrap_or(PhotonPolarization::Unpolarized);
515    Ok(PyExpression(PhotonSDME::new(
516        py_tags(tags)?,
517        polarization,
518        PhotonHelicity::new(helicity)?,
519        PhotonHelicity::new(helicity_prime)?,
520    )?))
521}
522
523/// Construct a phase-space factor amplitude.
524#[pyfunction(name = "PhaseSpaceFactor", signature = (*tags, recoil_mass, daughter_1_mass, daughter_2_mass, resonance_mass, mandelstam_s))]
525pub fn py_phase_space_factor(
526    tags: &Bound<'_, PyTuple>,
527    recoil_mass: &PyMass,
528    daughter_1_mass: &PyMass,
529    daughter_2_mass: &PyMass,
530    resonance_mass: &PyMass,
531    mandelstam_s: &PyMandelstam,
532) -> PyResult<PyExpression> {
533    Ok(PyExpression(PhaseSpaceFactor::new(
534        py_tags(tags)?,
535        &recoil_mass.0,
536        &daughter_1_mass.0,
537        &daughter_2_mass.0,
538        &resonance_mass.0,
539        &mandelstam_s.0,
540    )?))
541}
542
543fn py_lookup_inputs(
544    variables: Vec<PyVariable>,
545    axis_coordinates: Vec<Vec<f64>>,
546) -> LadduResult<LookupInputs> {
547    let axis_coordinates = axis_coordinates
548        .into_iter()
549        .map(LookupAxis::new)
550        .collect::<LadduResult<Vec<_>>>()?;
551    let variables = variables
552        .into_iter()
553        .map(|variable| Box::new(variable) as Box<dyn Variable>)
554        .collect();
555    Ok((variables, axis_coordinates))
556}
557
558/// Construct a fixed-complex lookup table amplitude.
559#[pyfunction(name = "LookupTable", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
560pub fn py_lookup_table(
561    tags: &Bound<'_, PyTuple>,
562    variables: Vec<PyVariable>,
563    axis_coordinates: Vec<Vec<f64>>,
564    values: Vec<Complex64>,
565    interpolation: &str,
566    boundary_mode: &str,
567) -> PyResult<PyExpression> {
568    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
569    Ok(PyExpression(LookupTable::new(
570        py_tags(tags)?,
571        variables,
572        axis_coordinates,
573        values,
574        interpolation.parse()?,
575        boundary_mode.parse()?,
576    )?))
577}
578
579/// Construct a scalar-parameter lookup table amplitude.
580#[pyfunction(name = "LookupTableScalar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
581pub fn py_lookup_table_scalar(
582    tags: &Bound<'_, PyTuple>,
583    variables: Vec<PyVariable>,
584    axis_coordinates: Vec<Vec<f64>>,
585    values: Vec<PyParameter>,
586    interpolation: &str,
587    boundary_mode: &str,
588) -> PyResult<PyExpression> {
589    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
590    Ok(PyExpression(LookupTable::new_scalar(
591        py_tags(tags)?,
592        variables,
593        axis_coordinates,
594        values.into_iter().map(|value| value.0).collect(),
595        interpolation.parse()?,
596        boundary_mode.parse()?,
597    )?))
598}
599
600/// Construct a cartesian-complex lookup table amplitude.
601#[pyfunction(name = "LookupTableComplex", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
602pub fn py_lookup_table_complex(
603    tags: &Bound<'_, PyTuple>,
604    variables: Vec<PyVariable>,
605    axis_coordinates: Vec<Vec<f64>>,
606    values: Vec<(PyParameter, PyParameter)>,
607    interpolation: &str,
608    boundary_mode: &str,
609) -> PyResult<PyExpression> {
610    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
611    Ok(PyExpression(LookupTable::new_cartesian_complex(
612        py_tags(tags)?,
613        variables,
614        axis_coordinates,
615        values
616            .into_iter()
617            .map(|(value_re, value_im)| (value_re.0, value_im.0))
618            .collect(),
619        interpolation.parse()?,
620        boundary_mode.parse()?,
621    )?))
622}
623
624/// Construct a polar-complex lookup table amplitude.
625#[pyfunction(name = "LookupTablePolar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
626pub fn py_lookup_table_polar(
627    tags: &Bound<'_, PyTuple>,
628    variables: Vec<PyVariable>,
629    axis_coordinates: Vec<Vec<f64>>,
630    values: Vec<(PyParameter, PyParameter)>,
631    interpolation: &str,
632    boundary_mode: &str,
633) -> PyResult<PyExpression> {
634    let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
635    Ok(PyExpression(LookupTable::new_polar_complex(
636        py_tags(tags)?,
637        variables,
638        axis_coordinates,
639        values
640            .into_iter()
641            .map(|(value_r, value_theta)| (value_r.0, value_theta.0))
642            .collect(),
643        interpolation.parse()?,
644        boundary_mode.parse()?,
645    )?))
646}
647
648/// Construct the fixed Kopf `a0` K-matrix amplitude.
649#[pyfunction(name = "KopfKMatrixA0", signature = (*tags, couplings, channel, mass, seed = None))]
650pub fn py_kopf_kmatrix_a0(
651    tags: &Bound<'_, PyTuple>,
652    couplings: [[PyParameter; 2]; 2],
653    channel: PyKopfKMatrixA0Channel,
654    mass: PyMass,
655    seed: Option<usize>,
656) -> PyResult<PyExpression> {
657    Ok(PyExpression(KopfKMatrixA0::new(
658        py_tags(tags)?,
659        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
660        channel.into(),
661        &mass.0,
662        seed,
663    )?))
664}
665
666/// Construct the fixed Kopf `a2` K-matrix amplitude.
667#[pyfunction(name = "KopfKMatrixA2", signature = (*tags, couplings, channel, mass, seed = None))]
668pub fn py_kopf_kmatrix_a2(
669    tags: &Bound<'_, PyTuple>,
670    couplings: [[PyParameter; 2]; 2],
671    channel: PyKopfKMatrixA2Channel,
672    mass: PyMass,
673    seed: Option<usize>,
674) -> PyResult<PyExpression> {
675    Ok(PyExpression(KopfKMatrixA2::new(
676        py_tags(tags)?,
677        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
678        channel.into(),
679        &mass.0,
680        seed,
681    )?))
682}
683
684/// Construct the fixed Kopf `f0` K-matrix amplitude.
685#[pyfunction(name = "KopfKMatrixF0", signature = (*tags, couplings, channel, mass, seed = None))]
686pub fn py_kopf_kmatrix_f0(
687    tags: &Bound<'_, PyTuple>,
688    couplings: [[PyParameter; 2]; 5],
689    channel: PyKopfKMatrixF0Channel,
690    mass: PyMass,
691    seed: Option<usize>,
692) -> PyResult<PyExpression> {
693    Ok(PyExpression(KopfKMatrixF0::new(
694        py_tags(tags)?,
695        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
696        channel.into(),
697        &mass.0,
698        seed,
699    )?))
700}
701
702/// Construct the fixed Kopf `f2` K-matrix amplitude.
703#[pyfunction(name = "KopfKMatrixF2", signature = (*tags, couplings, channel, mass, seed = None))]
704pub fn py_kopf_kmatrix_f2(
705    tags: &Bound<'_, PyTuple>,
706    couplings: [[PyParameter; 2]; 4],
707    channel: PyKopfKMatrixF2Channel,
708    mass: PyMass,
709    seed: Option<usize>,
710) -> PyResult<PyExpression> {
711    Ok(PyExpression(KopfKMatrixF2::new(
712        py_tags(tags)?,
713        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
714        channel.into(),
715        &mass.0,
716        seed,
717    )?))
718}
719
720/// Construct the fixed Kopf `pi1` K-matrix amplitude.
721#[pyfunction(name = "KopfKMatrixPi1", signature = (*tags, couplings, channel, mass))]
722pub fn py_kopf_kmatrix_pi1(
723    tags: &Bound<'_, PyTuple>,
724    couplings: [[PyParameter; 2]; 1],
725    channel: PyKopfKMatrixPi1Channel,
726    mass: PyMass,
727) -> PyResult<PyExpression> {
728    Ok(PyExpression(KopfKMatrixPi1::new(
729        py_tags(tags)?,
730        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
731        channel.into(),
732        &mass.0,
733    )?))
734}
735
736/// Construct the fixed Kopf `rho` K-matrix amplitude.
737#[pyfunction(name = "KopfKMatrixRho", signature = (*tags, couplings, channel, mass))]
738pub fn py_kopf_kmatrix_rho(
739    tags: &Bound<'_, PyTuple>,
740    couplings: [[PyParameter; 2]; 2],
741    channel: PyKopfKMatrixRhoChannel,
742    mass: PyMass,
743) -> PyResult<PyExpression> {
744    Ok(PyExpression(KopfKMatrixRho::new(
745        py_tags(tags)?,
746        array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
747        channel.into(),
748        &mass.0,
749    )?))
750}
751
752#[pymethods]
753impl PyExpression {
754    /// The parameters used by the Expression.
755    ///
756    /// Returns
757    /// -------
758    /// parameters : ParameterMap
759    ///     The parameter map.
760    #[getter]
761    fn parameters(&self) -> PyParameterMap {
762        PyParameterMap(self.0.parameters())
763    }
764    /// Number of free parameters
765    #[getter]
766    fn n_free(&self) -> usize {
767        self.0.n_free()
768    }
769    /// Number of fixed parameters
770    #[getter]
771    fn n_fixed(&self) -> usize {
772        self.0.n_fixed()
773    }
774    /// Total number of parameters
775    #[getter]
776    fn n_parameters(&self) -> usize {
777        self.0.n_parameters()
778    }
779    /// Load an Expression by precalculating each term over the given Dataset
780    ///
781    /// Parameters
782    /// ----------
783    /// dataset : Dataset
784    ///     The Dataset to use in precalculation
785    ///
786    /// Returns
787    /// -------
788    /// Evaluator
789    ///     An object that can be used to evaluate the `expression` over each event in the
790    ///     `dataset`
791    fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
792        Ok(PyEvaluator(self.0.load(&dataset.0)?))
793    }
794    /// The real part of a complex Expression
795    fn real(&self) -> PyExpression {
796        PyExpression(self.0.real())
797    }
798    /// The imaginary part of a complex Expression
799    fn imag(&self) -> PyExpression {
800        PyExpression(self.0.imag())
801    }
802    /// The complex conjugate of a complex Expression
803    fn conj(&self) -> PyExpression {
804        PyExpression(self.0.conj())
805    }
806    /// The norm-squared of a complex Expression
807    fn norm_sqr(&self) -> PyExpression {
808        PyExpression(self.0.norm_sqr())
809    }
810    /// The square root of an Expression
811    fn sqrt(&self) -> PyExpression {
812        PyExpression(self.0.sqrt())
813    }
814    /// Raise an Expression to an int, float, or Expression power
815    fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
816        if let Ok(value) = power.extract::<i32>() {
817            Ok(PyExpression(self.0.powi(value)))
818        } else if let Ok(value) = power.extract::<f64>() {
819            Ok(PyExpression(self.0.powf(value)))
820        } else if let Ok(expression) = power.extract::<PyExpression>() {
821            Ok(PyExpression(self.0.pow(&expression.0)))
822        } else {
823            Err(PyTypeError::new_err(
824                "power must be an int, float, or Expression",
825            ))
826        }
827    }
828    /// The exponential of an Expression
829    fn exp(&self) -> PyExpression {
830        PyExpression(self.0.exp())
831    }
832    /// The sine of an Expression
833    fn sin(&self) -> PyExpression {
834        PyExpression(self.0.sin())
835    }
836    /// The cosine of an Expression
837    fn cos(&self) -> PyExpression {
838        PyExpression(self.0.cos())
839    }
840    /// The natural logarithm of an Expression
841    fn log(&self) -> PyExpression {
842        PyExpression(self.0.log())
843    }
844    /// The complex phase factor exp(i * expression)
845    fn cis(&self) -> PyExpression {
846        PyExpression(self.0.cis())
847    }
848    /// Fix a parameter used by this Expression.
849    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
850        Ok(self.0.fix_parameter(name, value)?)
851    }
852    /// Mark a parameter used by this Expression as free.
853    fn free_parameter(&self, name: &str) -> PyResult<()> {
854        Ok(self.0.free_parameter(name)?)
855    }
856    /// Rename a single parameter used by this Expression.
857    fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
858        Ok(self.0.rename_parameter(old, new)?)
859    }
860    /// Rename several parameters used by this Expression.
861    fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
862        Ok(self.0.rename_parameters(&mapping)?)
863    }
864    /// Return a tree-like diagnostic view of the compiled Expression.
865    #[getter]
866    fn compiled_expression(&self) -> PyCompiledExpression {
867        PyCompiledExpression(self.0.compiled_expression())
868    }
869    fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
870        if let Ok(other_expr) = other.extract::<PyExpression>() {
871            Ok(PyExpression(self.0.clone() + other_expr.0))
872        } else {
873            Err(PyTypeError::new_err("Unsupported operand type for +"))
874        }
875    }
876    fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
877        if let Ok(other_expr) = other.extract::<PyExpression>() {
878            Ok(PyExpression(other_expr.0 + self.0.clone()))
879        } else {
880            Err(PyTypeError::new_err("Unsupported operand type for +"))
881        }
882    }
883    fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
884        if let Ok(other_expr) = other.extract::<PyExpression>() {
885            Ok(PyExpression(self.0.clone() - other_expr.0))
886        } else {
887            Err(PyTypeError::new_err("Unsupported operand type for -"))
888        }
889    }
890    fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
891        if let Ok(other_expr) = other.extract::<PyExpression>() {
892            Ok(PyExpression(other_expr.0 - self.0.clone()))
893        } else {
894            Err(PyTypeError::new_err("Unsupported operand type for -"))
895        }
896    }
897    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
898        if let Ok(other_expr) = other.extract::<PyExpression>() {
899            Ok(PyExpression(self.0.clone() * other_expr.0))
900        } else {
901            Err(PyTypeError::new_err("Unsupported operand type for *"))
902        }
903    }
904    fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
905        if let Ok(other_expr) = other.extract::<PyExpression>() {
906            Ok(PyExpression(other_expr.0 * self.0.clone()))
907        } else {
908            Err(PyTypeError::new_err("Unsupported operand type for *"))
909        }
910    }
911    fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
912        if let Ok(other_expr) = other.extract::<PyExpression>() {
913            Ok(PyExpression(self.0.clone() / other_expr.0))
914        } else {
915            Err(PyTypeError::new_err("Unsupported operand type for /"))
916        }
917    }
918    fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
919        if let Ok(other_expr) = other.extract::<PyExpression>() {
920            Ok(PyExpression(other_expr.0 / self.0.clone()))
921        } else {
922            Err(PyTypeError::new_err("Unsupported operand type for /"))
923        }
924    }
925    fn __neg__(&self) -> PyExpression {
926        PyExpression(-self.0.clone())
927    }
928    fn __str__(&self) -> String {
929        format!("{}", self.0)
930    }
931    fn __repr__(&self) -> String {
932        format!("{:?}", self.0)
933    }
934
935    #[new]
936    fn new() -> Self {
937        Self(Expression::default())
938    }
939    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
940        Ok(PyBytes::new(
941            py,
942            serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
943                .map_err(LadduError::PickleError)?
944                .as_slice(),
945        ))
946    }
947    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
948        *self = Self(
949            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
950                .map_err(LadduError::PickleError)?,
951        );
952        Ok(())
953    }
954}
955
956/// A class which can be used to evaluate a stored Expression
957///
958/// See Also
959/// --------
960/// laddu.Expression.load
961///
962#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
963#[derive(Clone)]
964pub struct PyEvaluator(pub Evaluator);
965
966#[pymethods]
967impl PyEvaluator {
968    /// The free parameters used by the Evaluator
969    ///
970    /// Returns
971    /// -------
972    /// parameters : ParameterMap
973    ///     The parameter map.
974    ///
975    #[getter]
976    fn parameters(&self) -> PyParameterMap {
977        PyParameterMap(self.0.parameters())
978    }
979    /// Number of free parameters
980    #[getter]
981    fn n_free(&self) -> usize {
982        self.0.n_free()
983    }
984    /// Number of fixed parameters
985    #[getter]
986    fn n_fixed(&self) -> usize {
987        self.0.n_fixed()
988    }
989    /// Total number of parameters
990    #[getter]
991    fn n_parameters(&self) -> usize {
992        self.0.n_parameters()
993    }
994    /// Fix a parameter used by this Evaluator.
995    fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
996        Ok(self.0.fix_parameter(name, value)?)
997    }
998    /// Mark a parameter used by this Evaluator as free.
999    fn free_parameter(&self, name: &str) -> PyResult<()> {
1000        Ok(self.0.free_parameter(name)?)
1001    }
1002    /// Rename a single parameter used by this Evaluator.
1003    fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
1004        Ok(self.0.rename_parameter(old, new)?)
1005    }
1006    /// Rename several parameters used by this Evaluator.
1007    fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
1008        Ok(self.0.rename_parameters(&mapping)?)
1009    }
1010    /// Activates Amplitude use-sites in the Expression by tag or glob selector.
1011    ///
1012    /// Parameters
1013    /// ----------
1014    /// arg : str or list of str
1015    ///     Tags or ``*``/``?`` glob selectors of Amplitudes to be activated
1016    ///
1017    /// Raises
1018    /// ------
1019    /// TypeError
1020    ///     If `arg` is not a str or list of str
1021    /// ValueError
1022    ///     If `arg` or any items of `arg` match no tagged Amplitudes
1023    /// strict : bool, default=True
1024    ///     When ``True``, raise an error if any selector matches no amplitudes. When
1025    ///     ``False``, silently skip selectors with no matches.
1026    #[pyo3(signature = (arg, *, strict=true))]
1027    fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1028        if let Ok(string_arg) = arg.extract::<String>() {
1029            if strict {
1030                self.0.activate_strict(&string_arg)?;
1031            } else {
1032                self.0.activate(&string_arg);
1033            }
1034        } else if let Ok(list_arg) = arg.cast::<PyList>() {
1035            let vec: Vec<String> = list_arg.extract()?;
1036            if strict {
1037                self.0.activate_many_strict(&vec)?;
1038            } else {
1039                self.0.activate_many(&vec);
1040            }
1041        } else {
1042            return Err(PyTypeError::new_err(
1043                "Argument must be either a string or a list of strings",
1044            ));
1045        }
1046        Ok(())
1047    }
1048    /// Activates all Amplitude use-sites in the Expression.
1049    ///
1050    fn activate_all(&self) {
1051        self.0.activate_all();
1052    }
1053    /// Deactivates Amplitude use-sites in the Expression by tag or glob selector.
1054    ///
1055    /// Deactivated Amplitudes act as zeros in the Expression
1056    ///
1057    /// Parameters
1058    /// ----------
1059    /// arg : str or list of str
1060    ///     Tags or ``*``/``?`` glob selectors of Amplitudes to be deactivated
1061    ///
1062    /// Raises
1063    /// ------
1064    /// TypeError
1065    ///     If `arg` is not a str or list of str
1066    /// ValueError
1067    ///     If `arg` or any items of `arg` match no tagged Amplitudes
1068    /// strict : bool, default=True
1069    ///     When ``True``, raise an error if any selector matches no amplitudes. When
1070    ///     ``False``, silently skip selectors with no matches.
1071    #[pyo3(signature = (arg, *, strict=true))]
1072    fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1073        if let Ok(string_arg) = arg.extract::<String>() {
1074            if strict {
1075                self.0.deactivate_strict(&string_arg)?;
1076            } else {
1077                self.0.deactivate(&string_arg);
1078            }
1079        } else if let Ok(list_arg) = arg.cast::<PyList>() {
1080            let vec: Vec<String> = list_arg.extract()?;
1081            if strict {
1082                self.0.deactivate_many_strict(&vec)?;
1083            } else {
1084                self.0.deactivate_many(&vec);
1085            }
1086        } else {
1087            return Err(PyTypeError::new_err(
1088                "Argument must be either a string or a list of strings",
1089            ));
1090        }
1091        Ok(())
1092    }
1093    /// Deactivates all tagged Amplitude use-sites in the Expression.
1094    ///
1095    fn deactivate_all(&self) {
1096        self.0.deactivate_all();
1097    }
1098    /// Isolates Amplitude use-sites in the Expression by tag or glob selector.
1099    ///
1100    /// Activates the tagged Amplitudes given in `arg` and deactivates the rest.
1101    ///
1102    /// Parameters
1103    /// ----------
1104    /// arg : str or list of str
1105    ///     Tags or ``*``/``?`` glob selectors of Amplitudes to be isolated
1106    ///
1107    /// Raises
1108    /// ------
1109    /// TypeError
1110    ///     If `arg` is not a str or list of str
1111    /// ValueError
1112    ///     If `arg` or any items of `arg` match no tagged Amplitudes
1113    /// strict : bool, default=True
1114    ///     When ``True``, raise an error if any selector matches no amplitudes. When
1115    ///     ``False``, silently skip selectors with no matches.
1116    #[pyo3(signature = (arg, *, strict=true))]
1117    fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1118        if let Ok(string_arg) = arg.extract::<String>() {
1119            if strict {
1120                self.0.isolate_strict(&string_arg)?;
1121            } else {
1122                self.0.isolate(&string_arg);
1123            }
1124        } else if let Ok(list_arg) = arg.cast::<PyList>() {
1125            let vec: Vec<String> = list_arg.extract()?;
1126            if strict {
1127                self.0.isolate_many_strict(&vec)?;
1128            } else {
1129                self.0.isolate_many(&vec);
1130            }
1131        } else {
1132            return Err(PyTypeError::new_err(
1133                "Argument must be either a string or a list of strings",
1134            ));
1135        }
1136        Ok(())
1137    }
1138
1139    /// Return the current active-amplitude mask.
1140    #[getter]
1141    fn active_mask(&self) -> Vec<bool> {
1142        self.0.active_mask()
1143    }
1144
1145    /// Apply an active-amplitude mask.
1146    fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
1147        self.0.set_active_mask(&mask)?;
1148        Ok(())
1149    }
1150
1151    /// Return a tree-like diagnostic view of the compiled Expression.
1152    #[getter]
1153    fn compiled_expression(&self) -> PyCompiledExpression {
1154        PyCompiledExpression(self.0.compiled_expression())
1155    }
1156
1157    /// Return the Expression represented by this Evaluator.
1158    #[getter]
1159    fn expression(&self) -> PyExpression {
1160        PyExpression(self.0.expression())
1161    }
1162
1163    /// Evaluate the stored Expression over the stored Dataset
1164    ///
1165    /// Parameters
1166    /// ----------
1167    /// parameters : list of float
1168    ///     The values to use for the free parameters
1169    /// threads : int, optional
1170    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1171    ///     global or context-managed default; any positive value overrides that default for
1172    ///     this call only)
1173    ///
1174    /// Returns
1175    /// -------
1176    /// result : array_like
1177    ///     A ``numpy`` array of complex values for each Event in the Dataset
1178    ///
1179    /// Raises
1180    /// ------
1181    /// Exception
1182    ///     If there was an error building the thread pool
1183    ///
1184    #[pyo3(signature = (parameters, *, threads=None))]
1185    fn evaluate<'py>(
1186        &self,
1187        py: Python<'py>,
1188        parameters: Vec<f64>,
1189        threads: Option<usize>,
1190    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1191        let values = install_with_threads(threads, || self.0.evaluate(&parameters))?;
1192        Ok(PyArray1::from_slice(py, &values?))
1193    }
1194    /// Evaluate the stored Expression over a subset of the stored Dataset
1195    ///
1196    /// Parameters
1197    /// ----------
1198    /// parameters : list of float
1199    ///     The values to use for the free parameters
1200    /// indices : list of int
1201    ///     The indices of events to evaluate
1202    /// threads : int, optional
1203    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1204    ///     global or context-managed default; any positive value overrides that default for
1205    ///     this call only)
1206    ///
1207    /// Returns
1208    /// -------
1209    /// result : array_like
1210    ///     A ``numpy`` array of complex values for each indexed Event in the Dataset
1211    ///
1212    /// Raises
1213    /// ------
1214    /// Exception
1215    ///     If there was an error building the thread pool
1216    ///
1217    #[pyo3(signature = (parameters, indices, *, threads=None))]
1218    fn evaluate_batch<'py>(
1219        &self,
1220        py: Python<'py>,
1221        parameters: Vec<f64>,
1222        indices: Vec<usize>,
1223        threads: Option<usize>,
1224    ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1225        let values =
1226            install_with_threads(threads, || self.0.evaluate_batch(&parameters, &indices))?;
1227        Ok(PyArray1::from_slice(py, &values?))
1228    }
1229    /// Evaluate the gradient of the stored Expression over the stored Dataset
1230    ///
1231    /// Parameters
1232    /// ----------
1233    /// parameters : list of float
1234    ///     The values to use for the free parameters
1235    /// threads : int, optional
1236    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1237    ///     global or context-managed default; any positive value overrides that default for
1238    ///     this call only)
1239    ///
1240    /// Returns
1241    /// -------
1242    /// result : array_like
1243    ///     A ``numpy`` 2D array of complex values for each Event in the Dataset
1244    ///
1245    /// Raises
1246    /// ------
1247    /// Exception
1248    ///     If there was an error building the thread pool or problem creating the resulting
1249    ///     ``numpy`` array
1250    ///
1251    #[pyo3(signature = (parameters, *, threads=None))]
1252    fn evaluate_gradient<'py>(
1253        &self,
1254        py: Python<'py>,
1255        parameters: Vec<f64>,
1256        threads: Option<usize>,
1257    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1258        let gradients: LadduResult<_> = install_with_threads(threads, || {
1259            Ok(self
1260                .0
1261                .evaluate_gradient(&parameters)?
1262                .iter()
1263                .map(|grad| grad.data.as_vec().to_vec())
1264                .collect::<Vec<Vec<Complex64>>>())
1265        })?;
1266        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1267    }
1268    /// Evaluate the gradient of the stored Expression over a subset of the stored Dataset
1269    ///
1270    /// Parameters
1271    /// ----------
1272    /// parameters : list of float
1273    ///     The values to use for the free parameters
1274    /// indices : list of int
1275    ///     The indices of events to evaluate
1276    /// threads : int, optional
1277    ///     The number of threads to use (setting this to ``None`` or ``0`` uses the current
1278    ///     global or context-managed default; any positive value overrides that default for
1279    ///     this call only)
1280    ///
1281    /// Returns
1282    /// -------
1283    /// result : array_like
1284    ///     A ``numpy`` 2D array of complex values for each indexed Event in the Dataset
1285    ///
1286    /// Raises
1287    /// ------
1288    /// Exception
1289    ///     If there was an error building the thread pool or problem creating the resulting
1290    ///     ``numpy`` array
1291    ///
1292    #[pyo3(signature = (parameters, indices, *, threads=None))]
1293    fn evaluate_gradient_batch<'py>(
1294        &self,
1295        py: Python<'py>,
1296        parameters: Vec<f64>,
1297        indices: Vec<usize>,
1298        threads: Option<usize>,
1299    ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1300        let gradients: LadduResult<_> = install_with_threads(threads, || {
1301            Ok(self
1302                .0
1303                .evaluate_gradient_batch(&parameters, &indices)?
1304                .iter()
1305                .map(|grad| grad.data.as_vec().to_vec())
1306                .collect::<Vec<Vec<Complex64>>>())
1307        })?;
1308        Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1309    }
1310}
1311
1312/// A class which can be used to display the compiled form of an Expression
1313///
1314/// Notes
1315/// -----
1316/// This should not be used for anything other than diagnostic purposes.
1317///
1318#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
1319#[derive(Clone)]
1320pub struct PyCompiledExpression(pub CompiledExpression);
1321
1322#[pymethods]
1323impl PyCompiledExpression {
1324    fn __str__(&self) -> String {
1325        format!("{}", self.0)
1326    }
1327    fn __repr__(&self) -> String {
1328        format!("{:?}", self.0)
1329    }
1330}
1331
1332#[pyclass(name = "ParameterMap", module = "laddu", from_py_object)]
1333#[derive(Clone)]
1334pub struct PyParameterMap(pub ParameterMap);
1335
1336#[pymethods]
1337impl PyParameterMap {
1338    #[getter]
1339    fn names<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
1340        PyTuple::new(py, self.0.names())
1341    }
1342
1343    #[getter]
1344    fn free(&self) -> PyParameterMap {
1345        PyParameterMap(self.0.free())
1346    }
1347
1348    #[getter]
1349    fn fixed(&self) -> PyParameterMap {
1350        PyParameterMap(self.0.fixed())
1351    }
1352
1353    fn contains(&self, name: &str) -> bool {
1354        self.0.contains_key(name)
1355    }
1356
1357    fn __contains__(&self, name: &str) -> bool {
1358        self.contains(name)
1359    }
1360
1361    fn __len__(&self) -> usize {
1362        self.0.len()
1363    }
1364
1365    fn __bool__(&self) -> bool {
1366        !self.0.is_empty()
1367    }
1368
1369    fn __getitem__(&self, index: &Bound<'_, PyAny>) -> PyResult<PyParameter> {
1370        if let Ok(name) = index.extract::<String>() {
1371            return self
1372                .0
1373                .get(&name)
1374                .cloned()
1375                .map(PyParameter)
1376                .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err(name));
1377        }
1378        if let Ok(index) = index.extract::<isize>() {
1379            let len = self.0.len() as isize;
1380            let normalized = if index < 0 { len + index } else { index };
1381            if normalized < 0 || normalized >= len {
1382                return Err(pyo3::exceptions::PyIndexError::new_err(format!(
1383                    "parameter index out of range: {index}"
1384                )));
1385            }
1386            return Ok(PyParameter(self.0[normalized as usize].clone()));
1387        }
1388        Err(pyo3::exceptions::PyTypeError::new_err(
1389            "ParameterMap indices must be str or int",
1390        ))
1391    }
1392
1393    fn __str__(&self) -> String {
1394        self.0.to_string()
1395    }
1396
1397    fn __repr__(&self) -> String {
1398        format!("{:?}", self.0)
1399    }
1400
1401    fn index(&self, name: &str) -> PyResult<usize> {
1402        self.0
1403            .index(name)
1404            .ok_or_else(|| PyValueError::new_err(format!("parameter not found: {name}")))
1405    }
1406
1407    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
1408        let params: Vec<PyParameter> = self.0.clone().into_iter().map(PyParameter).collect();
1409        let list = PyList::new(py, params)?;
1410        PyIterator::from_object(&list)
1411    }
1412}
1413
1414#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
1415#[derive(Clone)]
1416pub struct PyParameter(pub Parameter);
1417
1418#[pymethods]
1419impl PyParameter {
1420    #[getter]
1421    fn name(&self) -> String {
1422        self.0.name()
1423    }
1424    #[getter]
1425    fn fixed(&self) -> Option<f64> {
1426        self.0.fixed()
1427    }
1428    #[getter]
1429    fn initial(&self) -> Option<f64> {
1430        self.0.initial()
1431    }
1432    #[getter]
1433    fn bounds(&self) -> (Option<f64>, Option<f64>) {
1434        self.0.bounds()
1435    }
1436    #[getter]
1437    fn unit(&self) -> Option<String> {
1438        self.0.unit()
1439    }
1440    #[getter]
1441    fn latex(&self) -> Option<String> {
1442        self.0.latex()
1443    }
1444    #[getter]
1445    fn description(&self) -> Option<String> {
1446        self.0.description()
1447    }
1448}
1449
1450/// A free parameter which floats during an optimization
1451///
1452/// Parameters
1453/// ----------
1454/// name : str
1455///     The name of the free parameter
1456/// fixed : float, optional
1457///     If specified, the parameter will be fixed to this value
1458/// initial : float, optional
1459///     If specified, the parameter will always be initialized to this value
1460/// bounds : tuple of (float or None, float or None)
1461///     Specify the lower and upper bounds for the parameter (None corresponds to no bound)
1462/// unit : str, optional
1463///     Optional unit string which may be used to annotate the parameter
1464/// latex : str, optional
1465///     Optional LaTeX representation of the parameter
1466/// description : str, optional
1467///     Optional description of the parameter
1468///
1469/// Returns
1470/// -------
1471/// laddu.Parameter
1472///     An object that can be used as the input for many Amplitude constructors
1473///
1474/// Notes
1475/// -----
1476/// Two free parameters with the same name are shared in a fit.
1477///
1478/// Attempting to set both the fixed and initial value will result in an overwrite (both will be
1479/// set to the "fixed" value).
1480///
1481#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
1482pub fn py_parameter(
1483    name: &str,
1484    fixed: Option<f64>,
1485    initial: Option<f64>,
1486    bounds: (Option<f64>, Option<f64>),
1487    unit: Option<&str>,
1488    latex: Option<&str>,
1489    description: Option<&str>,
1490) -> PyParameter {
1491    let par = Parameter::new(name);
1492    if let Some(value) = initial {
1493        par.set_initial(value);
1494    }
1495    if let Some(value) = fixed {
1496        par.set_fixed_value(Some(value)); // TODO: make this all consistent
1497    }
1498    par.set_bounds(bounds.0, bounds.1);
1499    if let Some(unit) = unit {
1500        par.set_unit(unit);
1501    }
1502    if let Some(latex) = latex {
1503        par.set_latex(latex);
1504    }
1505    if let Some(description) = description {
1506        par.set_description(description);
1507    }
1508    PyParameter(par)
1509}
1510
1511/// An amplitude used only for internal testing which evaluates `(p0 + i * p1) * event.p4s\[0\].e`.
1512#[pyfunction(name = "TestAmplitude", signature = (*tags, re, im))]
1513pub fn py_test_amplitude(
1514    tags: &Bound<'_, PyTuple>,
1515    re: PyParameter,
1516    im: PyParameter,
1517) -> PyResult<PyExpression> {
1518    Ok(PyExpression(TestAmplitude::new(
1519        py_tags(tags)?,
1520        re.0,
1521        im.0,
1522    )?))
1523}