1use std::{array, collections::HashMap};
2
3use laddu_amplitudes::{
4 angular::{
5 BlattWeisskopf, ClebschGordan, PhotonHelicity, PhotonPolarization, PhotonSDME, PolPhase,
6 Wigner3j, WignerD, Ylm, Zlm,
7 },
8 kmatrix::{
9 KopfKMatrixA0, KopfKMatrixA0Channel, KopfKMatrixA2, KopfKMatrixA2Channel, KopfKMatrixF0,
10 KopfKMatrixF0Channel, KopfKMatrixF2, KopfKMatrixF2Channel, KopfKMatrixPi1,
11 KopfKMatrixPi1Channel, KopfKMatrixRho, KopfKMatrixRhoChannel,
12 },
13 lookup::{LookupAxis, LookupTable},
14 resonance::{BreitWigner, BreitWignerNonRelativistic, Flatte, PhaseSpaceFactor, Voigt},
15 scalar::{ComplexScalar, PolarComplexScalar, Scalar, VariableScalar},
16};
17use laddu_core::{
18 amplitude::{Evaluator, Expression, Parameter, ParameterMap, TestAmplitude},
19 math::{BarrierKind, Sheet, QR_DEFAULT},
20 traits::Variable,
21 CompiledExpression, LadduError, LadduResult, ThreadPoolManager,
22};
23use num::complex::Complex64;
24use numpy::{PyArray1, PyArray2};
25use pyo3::{
26 exceptions::{PyTypeError, PyValueError},
27 prelude::*,
28 types::{PyAny, PyBytes, PyIterator, PyList, PyTuple},
29};
30
31use crate::{
32 data::PyDataset,
33 quantum::angular_momentum::{
34 parse_angular_momentum, parse_orbital_angular_momentum, parse_projection,
35 },
36 variables::{PyAngles, PyDecay, PyMandelstam, PyMass, PyPolarization, PyVariable},
37};
38
/// Lookup-table inputs converted from their Python forms: the boxed input variables paired
/// with the validated axes built from the per-axis coordinate lists (see `py_lookup_inputs`).
type LookupInputs = (Vec<Box<dyn Variable>>, Vec<LookupAxis>);
40
/// Declares a Python-visible mirror of a Rust K-matrix channel enum.
///
/// Arguments:
/// * the identifier of the generated Python-facing enum,
/// * the class name exposed to Python (in the `laddu` module),
/// * the path of the Rust enum followed by its variant list in braces.
///
/// The generated enum derives `Clone`/`PartialEq` (PyO3 `eq` comparison) and converts
/// variant-for-variant into the Rust enum through `From`.
macro_rules! py_kmatrix_channel {
    ($py_enum:ident, $class_name:literal, $core_enum:path { $($channel:ident),+ $(,)? }) => {
        #[pyclass(eq, name = $class_name, module = "laddu", from_py_object)]
        #[derive(Clone, PartialEq)]
        pub enum $py_enum {
            $($channel,)+
        }

        impl From<$py_enum> for $core_enum {
            fn from(selected: $py_enum) -> Self {
                // One arm per declared variant; the names are shared by both enums.
                match selected {
                    $( $py_enum::$channel => Self::$channel, )+
                }
            }
        }
    };
}
58
// Python-facing channel selectors, one per Kopf K-matrix amplitude. Each invocation mirrors
// the listed variants of the corresponding `laddu_amplitudes::kmatrix` channel enum.
py_kmatrix_channel!(
    PyKopfKMatrixA0Channel,
    "KopfKMatrixA0Channel",
    KopfKMatrixA0Channel { PiEta, KKbar }
);
py_kmatrix_channel!(
    PyKopfKMatrixA2Channel,
    "KopfKMatrixA2Channel",
    KopfKMatrixA2Channel {
        PiEta,
        KKbar,
        PiEtaPrime
    }
);
py_kmatrix_channel!(
    PyKopfKMatrixF0Channel,
    "KopfKMatrixF0Channel",
    KopfKMatrixF0Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta,
        EtaEtaPrime
    }
);
py_kmatrix_channel!(
    PyKopfKMatrixF2Channel,
    "KopfKMatrixF2Channel",
    KopfKMatrixF2Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta
    }
);
py_kmatrix_channel!(
    PyKopfKMatrixPi1Channel,
    "KopfKMatrixPi1Channel",
    KopfKMatrixPi1Channel { PiEta, PiEtaPrime }
);
py_kmatrix_channel!(
    PyKopfKMatrixRhoChannel,
    "KopfKMatrixRhoChannel",
    KopfKMatrixRhoChannel {
        PiPi,
        FourPi,
        KKbar
    }
);
108
109fn install_with_threads<R: Send>(
110 threads: Option<usize>,
111 op: impl FnOnce() -> R + Send,
112) -> LadduResult<R> {
113 ThreadPoolManager::shared().install(threads, op)
114}
115
116pub(crate) fn py_tags(tags: &Bound<'_, PyTuple>) -> PyResult<Vec<String>> {
117 tags.iter()
118 .map(|tag| tag.extract::<String>())
119 .collect::<PyResult<Vec<_>>>()
120}
121
#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
/// Python wrapper around a laddu `Expression`.
///
/// Arithmetic operators accept other Expressions as well as Python float and complex values,
/// which are converted through this type's `FromPyObject` implementation.
#[derive(Clone)]
pub struct PyExpression(pub Expression);
127
128impl<'py> FromPyObject<'_, 'py> for PyExpression {
129 type Error = PyErr;
130
131 fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
132 if let Ok(obj) = obj.cast::<PyExpression>() {
133 Ok(obj.borrow().clone())
134 } else if let Ok(obj) = obj.extract::<f64>() {
135 Ok(Self(obj.into()))
136 } else if let Ok(obj) = obj.extract::<Complex64>() {
137 Ok(Self(obj.into()))
138 } else {
139 Err(PyTypeError::new_err("Failed to extract Expression"))
140 }
141 }
142}
143
144#[pyfunction(name = "expr_sum")]
147pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
148 if terms.is_empty() {
149 return Ok(PyExpression(Expression::zero()));
150 }
151 if terms.len() == 1 {
152 let term = &terms[0];
153 if let Ok(expression) = term.extract::<PyExpression>() {
154 return Ok(expression);
155 }
156 return Err(PyTypeError::new_err("Item is not a PyExpression"));
157 }
158 let mut iter = terms.iter();
159 let Some(first_term) = iter.next() else {
160 return Ok(PyExpression(Expression::zero()));
161 };
162 let PyExpression(mut summation) = first_term
163 .extract::<PyExpression>()
164 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
165 for term in iter {
166 let PyExpression(expr) = term
167 .extract::<PyExpression>()
168 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
169 summation = summation + expr;
170 }
171 Ok(PyExpression(summation))
172}
173
174#[pyfunction(name = "expr_product")]
177pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
178 if terms.is_empty() {
179 return Ok(PyExpression(Expression::one()));
180 }
181 if terms.len() == 1 {
182 let term = &terms[0];
183 if let Ok(expression) = term.extract::<PyExpression>() {
184 return Ok(expression);
185 }
186 return Err(PyTypeError::new_err("Item is not a PyExpression"));
187 }
188 let mut iter = terms.iter();
189 let Some(first_term) = iter.next() else {
190 return Ok(PyExpression(Expression::one()));
191 };
192 let PyExpression(mut product) = first_term
193 .extract::<PyExpression>()
194 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
195 for term in iter {
196 let PyExpression(expr) = term
197 .extract::<PyExpression>()
198 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
199 product = product * expr;
200 }
201 Ok(PyExpression(product))
202}
203
204#[pyfunction(name = "Zero")]
207pub fn py_expr_zero() -> PyExpression {
208 PyExpression(Expression::zero())
209}
210
211#[pyfunction(name = "One")]
214pub fn py_expr_one() -> PyExpression {
215 PyExpression(Expression::one())
216}
217
218#[pyfunction(name = "Scalar", signature = (*tags, value))]
220pub fn py_scalar(tags: &Bound<'_, PyTuple>, value: PyParameter) -> PyResult<PyExpression> {
221 Ok(PyExpression(Scalar::new(py_tags(tags)?, value.0)?))
222}
223
224#[pyfunction(name = "VariableScalar", signature = (*tags, variable))]
226pub fn py_variable_scalar(
227 tags: &Bound<'_, PyTuple>,
228 variable: Bound<'_, PyAny>,
229) -> PyResult<PyExpression> {
230 let variable = variable.extract::<PyVariable>()?;
231 Ok(PyExpression(VariableScalar::new(
232 py_tags(tags)?,
233 &variable,
234 )?))
235}
236
237#[pyfunction(name = "ComplexScalar", signature = (*tags, re, im))]
239pub fn py_complex_scalar(
240 tags: &Bound<'_, PyTuple>,
241 re: PyParameter,
242 im: PyParameter,
243) -> PyResult<PyExpression> {
244 Ok(PyExpression(ComplexScalar::new(
245 py_tags(tags)?,
246 re.0,
247 im.0,
248 )?))
249}
250
251#[pyfunction(name = "PolarComplexScalar", signature = (*tags, r, theta))]
253pub fn py_polar_complex_scalar(
254 tags: &Bound<'_, PyTuple>,
255 r: PyParameter,
256 theta: PyParameter,
257) -> PyResult<PyExpression> {
258 Ok(PyExpression(PolarComplexScalar::new(
259 py_tags(tags)?,
260 r.0,
261 theta.0,
262 )?))
263}
264
265#[pyfunction(name = "BreitWigner", signature = (*tags, mass, width, l, daughter_1_mass, daughter_2_mass, resonance_mass, barrier_factors=true))]
267#[allow(clippy::too_many_arguments)]
268pub fn py_breit_wigner(
269 tags: &Bound<'_, PyTuple>,
270 mass: PyParameter,
271 width: PyParameter,
272 l: usize,
273 daughter_1_mass: &PyMass,
274 daughter_2_mass: &PyMass,
275 resonance_mass: &PyMass,
276 barrier_factors: bool,
277) -> PyResult<PyExpression> {
278 if barrier_factors {
279 Ok(PyExpression(BreitWigner::new(
280 py_tags(tags)?,
281 mass.0,
282 width.0,
283 l,
284 &daughter_1_mass.0,
285 &daughter_2_mass.0,
286 &resonance_mass.0,
287 )?))
288 } else {
289 Ok(PyExpression(BreitWigner::new_without_barrier_factors(
290 py_tags(tags)?,
291 mass.0,
292 width.0,
293 l,
294 &daughter_1_mass.0,
295 &daughter_2_mass.0,
296 &resonance_mass.0,
297 )?))
298 }
299}
300
301#[pyfunction(name = "BreitWignerNonRelativistic", signature = (*tags, mass, width, resonance_mass))]
303pub fn py_breit_wigner_non_relativistic(
304 tags: &Bound<'_, PyTuple>,
305 mass: PyParameter,
306 width: PyParameter,
307 resonance_mass: &PyMass,
308) -> PyResult<PyExpression> {
309 Ok(PyExpression(BreitWignerNonRelativistic::new(
310 py_tags(tags)?,
311 mass.0,
312 width.0,
313 &resonance_mass.0,
314 )?))
315}
316
317#[pyfunction(name = "Flatte", signature = (*tags, mass, observed_channel_coupling, alternate_channel_coupling, observed_channel_daughter_masses, alternate_channel_daughter_masses, resonance_mass))]
319pub fn py_flatte(
320 tags: &Bound<'_, PyTuple>,
321 mass: PyParameter,
322 observed_channel_coupling: PyParameter,
323 alternate_channel_coupling: PyParameter,
324 observed_channel_daughter_masses: (PyMass, PyMass),
325 alternate_channel_daughter_masses: (f64, f64),
326 resonance_mass: &PyMass,
327) -> PyResult<PyExpression> {
328 Ok(PyExpression(Flatte::new(
329 py_tags(tags)?,
330 mass.0,
331 observed_channel_coupling.0,
332 alternate_channel_coupling.0,
333 (
334 &observed_channel_daughter_masses.0 .0,
335 &observed_channel_daughter_masses.1 .0,
336 ),
337 alternate_channel_daughter_masses,
338 &resonance_mass.0,
339 )?))
340}
341
342#[pyfunction(name = "Voigt", signature = (*tags, mass, width, sigma, resonance_mass))]
344pub fn py_voigt(
345 tags: &Bound<'_, PyTuple>,
346 mass: PyParameter,
347 width: PyParameter,
348 sigma: PyParameter,
349 resonance_mass: &PyMass,
350) -> PyResult<PyExpression> {
351 Ok(PyExpression(Voigt::new(
352 py_tags(tags)?,
353 mass.0,
354 width.0,
355 sigma.0,
356 &resonance_mass.0,
357 )?))
358}
359
360#[pyfunction(name = "Ylm", signature = (*tags, l, m, angles))]
362pub fn py_ylm(
363 tags: &Bound<'_, PyTuple>,
364 l: usize,
365 m: isize,
366 angles: &PyAngles,
367) -> PyResult<PyExpression> {
368 Ok(PyExpression(Ylm::new(py_tags(tags)?, l, m, &angles.0)?))
369}
370
371#[pyfunction(name = "Zlm", signature = (*tags, l, m, r, angles, polarization))]
373pub fn py_zlm(
374 tags: &Bound<'_, PyTuple>,
375 l: usize,
376 m: isize,
377 r: &str,
378 angles: &PyAngles,
379 polarization: &PyPolarization,
380) -> PyResult<PyExpression> {
381 Ok(PyExpression(Zlm::new(
382 py_tags(tags)?,
383 l,
384 m,
385 r.parse()?,
386 &angles.0,
387 &polarization.0,
388 )?))
389}
390
391#[pyfunction(name = "PolPhase", signature = (*tags, polarization))]
393pub fn py_polphase(
394 tags: &Bound<'_, PyTuple>,
395 polarization: &PyPolarization,
396) -> PyResult<PyExpression> {
397 Ok(PyExpression(PolPhase::new(
398 py_tags(tags)?,
399 &polarization.0,
400 )?))
401}
402
403#[pyfunction(name = "WignerD", signature = (*tags, spin, row_projection, column_projection, angles))]
405pub fn py_wigner_d(
406 tags: &Bound<'_, PyTuple>,
407 spin: &Bound<'_, PyAny>,
408 row_projection: &Bound<'_, PyAny>,
409 column_projection: &Bound<'_, PyAny>,
410 angles: &PyAngles,
411) -> PyResult<PyExpression> {
412 Ok(PyExpression(WignerD::new(
413 py_tags(tags)?,
414 parse_angular_momentum(spin)?,
415 parse_projection(row_projection)?,
416 parse_projection(column_projection)?,
417 &angles.0,
418 )?))
419}
420
421#[pyfunction(name = "BlattWeisskopf", signature = (*tags, decay, l, reference_mass, q_r = QR_DEFAULT, sheet = "physical", kind = "full"))]
423pub fn py_blatt_weisskopf(
424 tags: &Bound<'_, PyTuple>,
425 decay: &PyDecay,
426 l: &Bound<'_, PyAny>,
427 reference_mass: f64,
428 q_r: f64,
429 sheet: &str,
430 kind: &str,
431) -> PyResult<PyExpression> {
432 let sheet = match sheet.to_ascii_lowercase().as_str() {
433 "physical" => Sheet::Physical,
434 "unphysical" => Sheet::Unphysical,
435 _ => {
436 return Err(PyValueError::new_err(
437 "sheet must be 'physical' or 'unphysical'",
438 ));
439 }
440 };
441 let kind = match kind.to_ascii_lowercase().as_str() {
442 "full" => BarrierKind::Full,
443 "tensor" => BarrierKind::Tensor,
444 _ => {
445 return Err(PyValueError::new_err("kind must be 'full' or 'tensor'"));
446 }
447 };
448 Ok(PyExpression(BlattWeisskopf::new(
449 py_tags(tags)?,
450 &decay.0,
451 parse_orbital_angular_momentum(l)?,
452 reference_mass,
453 q_r,
454 sheet,
455 kind,
456 )?))
457}
458
459#[pyfunction(name = "ClebschGordan", signature = (*tags, j1, m1, j2, m2, j, m))]
461pub fn py_clebsch_gordan(
462 tags: &Bound<'_, PyTuple>,
463 j1: &Bound<'_, PyAny>,
464 m1: &Bound<'_, PyAny>,
465 j2: &Bound<'_, PyAny>,
466 m2: &Bound<'_, PyAny>,
467 j: &Bound<'_, PyAny>,
468 m: &Bound<'_, PyAny>,
469) -> PyResult<PyExpression> {
470 Ok(PyExpression(ClebschGordan::new(
471 py_tags(tags)?,
472 parse_angular_momentum(j1)?,
473 parse_projection(m1)?,
474 parse_angular_momentum(j2)?,
475 parse_projection(m2)?,
476 parse_angular_momentum(j)?,
477 parse_projection(m)?,
478 )?))
479}
480
481#[pyfunction(name = "Wigner3j", signature = (*tags, j1, m1, j2, m2, j3, m3))]
483pub fn py_wigner_3j(
484 tags: &Bound<'_, PyTuple>,
485 j1: &Bound<'_, PyAny>,
486 m1: &Bound<'_, PyAny>,
487 j2: &Bound<'_, PyAny>,
488 m2: &Bound<'_, PyAny>,
489 j3: &Bound<'_, PyAny>,
490 m3: &Bound<'_, PyAny>,
491) -> PyResult<PyExpression> {
492 Ok(PyExpression(Wigner3j::new(
493 py_tags(tags)?,
494 parse_angular_momentum(j1)?,
495 parse_projection(m1)?,
496 parse_angular_momentum(j2)?,
497 parse_projection(m2)?,
498 parse_angular_momentum(j3)?,
499 parse_projection(m3)?,
500 )?))
501}
502
503#[pyfunction(name = "PhotonSDME", signature = (*tags, helicity, helicity_prime, polarization = None))]
505pub fn py_photon_sdme(
506 tags: &Bound<'_, PyTuple>,
507 helicity: i32,
508 helicity_prime: i32,
509 polarization: Option<&PyPolarization>,
510) -> PyResult<PyExpression> {
511 let polarization = polarization
512 .map(|polarization| PhotonPolarization::Linear(Box::new(polarization.0.clone())))
513 .unwrap_or(PhotonPolarization::Unpolarized);
514 Ok(PyExpression(PhotonSDME::new(
515 py_tags(tags)?,
516 polarization,
517 PhotonHelicity::new(helicity)?,
518 PhotonHelicity::new(helicity_prime)?,
519 )?))
520}
521
522#[pyfunction(name = "PhaseSpaceFactor", signature = (*tags, recoil_mass, daughter_1_mass, daughter_2_mass, resonance_mass, mandelstam_s))]
524pub fn py_phase_space_factor(
525 tags: &Bound<'_, PyTuple>,
526 recoil_mass: &PyMass,
527 daughter_1_mass: &PyMass,
528 daughter_2_mass: &PyMass,
529 resonance_mass: &PyMass,
530 mandelstam_s: &PyMandelstam,
531) -> PyResult<PyExpression> {
532 Ok(PyExpression(PhaseSpaceFactor::new(
533 py_tags(tags)?,
534 &recoil_mass.0,
535 &daughter_1_mass.0,
536 &daughter_2_mass.0,
537 &resonance_mass.0,
538 &mandelstam_s.0,
539 )?))
540}
541
542fn py_lookup_inputs(
543 variables: Vec<PyVariable>,
544 axis_coordinates: Vec<Vec<f64>>,
545) -> LadduResult<LookupInputs> {
546 let axis_coordinates = axis_coordinates
547 .into_iter()
548 .map(LookupAxis::new)
549 .collect::<LadduResult<Vec<_>>>()?;
550 let variables = variables
551 .into_iter()
552 .map(|variable| Box::new(variable) as Box<dyn Variable>)
553 .collect();
554 Ok((variables, axis_coordinates))
555}
556
557#[pyfunction(name = "LookupTable", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
559pub fn py_lookup_table(
560 tags: &Bound<'_, PyTuple>,
561 variables: Vec<PyVariable>,
562 axis_coordinates: Vec<Vec<f64>>,
563 values: Vec<Complex64>,
564 interpolation: &str,
565 boundary_mode: &str,
566) -> PyResult<PyExpression> {
567 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
568 Ok(PyExpression(LookupTable::new(
569 py_tags(tags)?,
570 variables,
571 axis_coordinates,
572 values,
573 interpolation.parse()?,
574 boundary_mode.parse()?,
575 )?))
576}
577
578#[pyfunction(name = "LookupTableScalar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
580pub fn py_lookup_table_scalar(
581 tags: &Bound<'_, PyTuple>,
582 variables: Vec<PyVariable>,
583 axis_coordinates: Vec<Vec<f64>>,
584 values: Vec<PyParameter>,
585 interpolation: &str,
586 boundary_mode: &str,
587) -> PyResult<PyExpression> {
588 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
589 Ok(PyExpression(LookupTable::new_scalar(
590 py_tags(tags)?,
591 variables,
592 axis_coordinates,
593 values.into_iter().map(|value| value.0).collect(),
594 interpolation.parse()?,
595 boundary_mode.parse()?,
596 )?))
597}
598
599#[pyfunction(name = "LookupTableComplex", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
601pub fn py_lookup_table_complex(
602 tags: &Bound<'_, PyTuple>,
603 variables: Vec<PyVariable>,
604 axis_coordinates: Vec<Vec<f64>>,
605 values: Vec<(PyParameter, PyParameter)>,
606 interpolation: &str,
607 boundary_mode: &str,
608) -> PyResult<PyExpression> {
609 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
610 Ok(PyExpression(LookupTable::new_cartesian_complex(
611 py_tags(tags)?,
612 variables,
613 axis_coordinates,
614 values
615 .into_iter()
616 .map(|(value_re, value_im)| (value_re.0, value_im.0))
617 .collect(),
618 interpolation.parse()?,
619 boundary_mode.parse()?,
620 )?))
621}
622
623#[pyfunction(name = "LookupTablePolar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
625pub fn py_lookup_table_polar(
626 tags: &Bound<'_, PyTuple>,
627 variables: Vec<PyVariable>,
628 axis_coordinates: Vec<Vec<f64>>,
629 values: Vec<(PyParameter, PyParameter)>,
630 interpolation: &str,
631 boundary_mode: &str,
632) -> PyResult<PyExpression> {
633 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
634 Ok(PyExpression(LookupTable::new_polar_complex(
635 py_tags(tags)?,
636 variables,
637 axis_coordinates,
638 values
639 .into_iter()
640 .map(|(value_r, value_theta)| (value_r.0, value_theta.0))
641 .collect(),
642 interpolation.parse()?,
643 boundary_mode.parse()?,
644 )?))
645}
646
647#[pyfunction(name = "KopfKMatrixA0", signature = (*tags, couplings, channel, mass, seed = None))]
649pub fn py_kopf_kmatrix_a0(
650 tags: &Bound<'_, PyTuple>,
651 couplings: [[PyParameter; 2]; 2],
652 channel: PyKopfKMatrixA0Channel,
653 mass: PyMass,
654 seed: Option<usize>,
655) -> PyResult<PyExpression> {
656 Ok(PyExpression(KopfKMatrixA0::new(
657 py_tags(tags)?,
658 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
659 channel.into(),
660 &mass.0,
661 seed,
662 )?))
663}
664
665#[pyfunction(name = "KopfKMatrixA2", signature = (*tags, couplings, channel, mass, seed = None))]
667pub fn py_kopf_kmatrix_a2(
668 tags: &Bound<'_, PyTuple>,
669 couplings: [[PyParameter; 2]; 2],
670 channel: PyKopfKMatrixA2Channel,
671 mass: PyMass,
672 seed: Option<usize>,
673) -> PyResult<PyExpression> {
674 Ok(PyExpression(KopfKMatrixA2::new(
675 py_tags(tags)?,
676 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
677 channel.into(),
678 &mass.0,
679 seed,
680 )?))
681}
682
683#[pyfunction(name = "KopfKMatrixF0", signature = (*tags, couplings, channel, mass, seed = None))]
685pub fn py_kopf_kmatrix_f0(
686 tags: &Bound<'_, PyTuple>,
687 couplings: [[PyParameter; 2]; 5],
688 channel: PyKopfKMatrixF0Channel,
689 mass: PyMass,
690 seed: Option<usize>,
691) -> PyResult<PyExpression> {
692 Ok(PyExpression(KopfKMatrixF0::new(
693 py_tags(tags)?,
694 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
695 channel.into(),
696 &mass.0,
697 seed,
698 )?))
699}
700
701#[pyfunction(name = "KopfKMatrixF2", signature = (*tags, couplings, channel, mass, seed = None))]
703pub fn py_kopf_kmatrix_f2(
704 tags: &Bound<'_, PyTuple>,
705 couplings: [[PyParameter; 2]; 4],
706 channel: PyKopfKMatrixF2Channel,
707 mass: PyMass,
708 seed: Option<usize>,
709) -> PyResult<PyExpression> {
710 Ok(PyExpression(KopfKMatrixF2::new(
711 py_tags(tags)?,
712 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
713 channel.into(),
714 &mass.0,
715 seed,
716 )?))
717}
718
719#[pyfunction(name = "KopfKMatrixPi1", signature = (*tags, couplings, channel, mass))]
721pub fn py_kopf_kmatrix_pi1(
722 tags: &Bound<'_, PyTuple>,
723 couplings: [[PyParameter; 2]; 1],
724 channel: PyKopfKMatrixPi1Channel,
725 mass: PyMass,
726) -> PyResult<PyExpression> {
727 Ok(PyExpression(KopfKMatrixPi1::new(
728 py_tags(tags)?,
729 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
730 channel.into(),
731 &mass.0,
732 )?))
733}
734
735#[pyfunction(name = "KopfKMatrixRho", signature = (*tags, couplings, channel, mass))]
737pub fn py_kopf_kmatrix_rho(
738 tags: &Bound<'_, PyTuple>,
739 couplings: [[PyParameter; 2]; 2],
740 channel: PyKopfKMatrixRhoChannel,
741 mass: PyMass,
742) -> PyResult<PyExpression> {
743 Ok(PyExpression(KopfKMatrixRho::new(
744 py_tags(tags)?,
745 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
746 channel.into(),
747 &mass.0,
748 )?))
749}
750
751#[pymethods]
752impl PyExpression {
753 #[getter]
760 fn parameters(&self) -> PyParameterMap {
761 PyParameterMap(self.0.parameters())
762 }
763 #[getter]
765 fn n_free(&self) -> usize {
766 self.0.n_free()
767 }
768 #[getter]
770 fn n_fixed(&self) -> usize {
771 self.0.n_fixed()
772 }
773 #[getter]
775 fn n_parameters(&self) -> usize {
776 self.0.n_parameters()
777 }
778 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
791 Ok(PyEvaluator(self.0.load(&dataset.0)?))
792 }
793 fn real(&self) -> PyExpression {
795 PyExpression(self.0.real())
796 }
797 fn imag(&self) -> PyExpression {
799 PyExpression(self.0.imag())
800 }
801 fn conj(&self) -> PyExpression {
803 PyExpression(self.0.conj())
804 }
805 fn norm_sqr(&self) -> PyExpression {
807 PyExpression(self.0.norm_sqr())
808 }
809 fn sqrt(&self) -> PyExpression {
811 PyExpression(self.0.sqrt())
812 }
813 fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
815 if let Ok(value) = power.extract::<i32>() {
816 Ok(PyExpression(self.0.powi(value)))
817 } else if let Ok(value) = power.extract::<f64>() {
818 Ok(PyExpression(self.0.powf(value)))
819 } else if let Ok(expression) = power.extract::<PyExpression>() {
820 Ok(PyExpression(self.0.pow(&expression.0)))
821 } else {
822 Err(PyTypeError::new_err(
823 "power must be an int, float, or Expression",
824 ))
825 }
826 }
827 fn exp(&self) -> PyExpression {
829 PyExpression(self.0.exp())
830 }
831 fn sin(&self) -> PyExpression {
833 PyExpression(self.0.sin())
834 }
835 fn cos(&self) -> PyExpression {
837 PyExpression(self.0.cos())
838 }
839 fn log(&self) -> PyExpression {
841 PyExpression(self.0.log())
842 }
843 fn cis(&self) -> PyExpression {
845 PyExpression(self.0.cis())
846 }
847 fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
849 Ok(self.0.fix_parameter(name, value)?)
850 }
851 fn free_parameter(&self, name: &str) -> PyResult<()> {
853 Ok(self.0.free_parameter(name)?)
854 }
855 fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
857 Ok(self.0.rename_parameter(old, new)?)
858 }
859 fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
861 Ok(self.0.rename_parameters(&mapping)?)
862 }
863 #[getter]
865 fn compiled_expression(&self) -> PyCompiledExpression {
866 PyCompiledExpression(self.0.compiled_expression())
867 }
868 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
869 if let Ok(other_expr) = other.extract::<PyExpression>() {
870 Ok(PyExpression(self.0.clone() + other_expr.0))
871 } else {
872 Err(PyTypeError::new_err("Unsupported operand type for +"))
873 }
874 }
875 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
876 if let Ok(other_expr) = other.extract::<PyExpression>() {
877 Ok(PyExpression(other_expr.0 + self.0.clone()))
878 } else {
879 Err(PyTypeError::new_err("Unsupported operand type for +"))
880 }
881 }
882 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
883 if let Ok(other_expr) = other.extract::<PyExpression>() {
884 Ok(PyExpression(self.0.clone() - other_expr.0))
885 } else {
886 Err(PyTypeError::new_err("Unsupported operand type for -"))
887 }
888 }
889 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
890 if let Ok(other_expr) = other.extract::<PyExpression>() {
891 Ok(PyExpression(other_expr.0 - self.0.clone()))
892 } else {
893 Err(PyTypeError::new_err("Unsupported operand type for -"))
894 }
895 }
896 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
897 if let Ok(other_expr) = other.extract::<PyExpression>() {
898 Ok(PyExpression(self.0.clone() * other_expr.0))
899 } else {
900 Err(PyTypeError::new_err("Unsupported operand type for *"))
901 }
902 }
903 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
904 if let Ok(other_expr) = other.extract::<PyExpression>() {
905 Ok(PyExpression(other_expr.0 * self.0.clone()))
906 } else {
907 Err(PyTypeError::new_err("Unsupported operand type for *"))
908 }
909 }
910 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
911 if let Ok(other_expr) = other.extract::<PyExpression>() {
912 Ok(PyExpression(self.0.clone() / other_expr.0))
913 } else {
914 Err(PyTypeError::new_err("Unsupported operand type for /"))
915 }
916 }
917 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
918 if let Ok(other_expr) = other.extract::<PyExpression>() {
919 Ok(PyExpression(other_expr.0 / self.0.clone()))
920 } else {
921 Err(PyTypeError::new_err("Unsupported operand type for /"))
922 }
923 }
924 fn __neg__(&self) -> PyExpression {
925 PyExpression(-self.0.clone())
926 }
927 fn __str__(&self) -> String {
928 format!("{}", self.0)
929 }
930 fn __repr__(&self) -> String {
931 format!("{:?}", self.0)
932 }
933
934 #[new]
935 fn new() -> Self {
936 Self(Expression::default())
937 }
938 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
939 Ok(PyBytes::new(
940 py,
941 serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
942 .map_err(LadduError::PickleError)?
943 .as_slice(),
944 ))
945 }
946 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
947 *self = Self(
948 serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
949 .map_err(LadduError::PickleError)?,
950 );
951 Ok(())
952 }
953}
954
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
/// Python wrapper around a laddu `Evaluator`, as produced by `Expression.load(dataset)`.
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
964
965#[pymethods]
966impl PyEvaluator {
967 #[getter]
975 fn parameters(&self) -> PyParameterMap {
976 PyParameterMap(self.0.parameters())
977 }
978 #[getter]
980 fn n_free(&self) -> usize {
981 self.0.n_free()
982 }
983 #[getter]
985 fn n_fixed(&self) -> usize {
986 self.0.n_fixed()
987 }
988 #[getter]
990 fn n_parameters(&self) -> usize {
991 self.0.n_parameters()
992 }
993 fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
995 Ok(self.0.fix_parameter(name, value)?)
996 }
997 fn free_parameter(&self, name: &str) -> PyResult<()> {
999 Ok(self.0.free_parameter(name)?)
1000 }
1001 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
1003 Ok(self.0.rename_parameter(old, new)?)
1004 }
1005 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
1007 Ok(self.0.rename_parameters(&mapping)?)
1008 }
1009 #[pyo3(signature = (arg, *, strict=true))]
1026 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1027 if let Ok(string_arg) = arg.extract::<String>() {
1028 if strict {
1029 self.0.activate_strict(&string_arg)?;
1030 } else {
1031 self.0.activate(&string_arg);
1032 }
1033 } else if let Ok(list_arg) = arg.cast::<PyList>() {
1034 let vec: Vec<String> = list_arg.extract()?;
1035 if strict {
1036 self.0.activate_many_strict(&vec)?;
1037 } else {
1038 self.0.activate_many(&vec);
1039 }
1040 } else {
1041 return Err(PyTypeError::new_err(
1042 "Argument must be either a string or a list of strings",
1043 ));
1044 }
1045 Ok(())
1046 }
1047 fn activate_all(&self) {
1050 self.0.activate_all();
1051 }
1052 #[pyo3(signature = (arg, *, strict=true))]
1071 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1072 if let Ok(string_arg) = arg.extract::<String>() {
1073 if strict {
1074 self.0.deactivate_strict(&string_arg)?;
1075 } else {
1076 self.0.deactivate(&string_arg);
1077 }
1078 } else if let Ok(list_arg) = arg.cast::<PyList>() {
1079 let vec: Vec<String> = list_arg.extract()?;
1080 if strict {
1081 self.0.deactivate_many_strict(&vec)?;
1082 } else {
1083 self.0.deactivate_many(&vec);
1084 }
1085 } else {
1086 return Err(PyTypeError::new_err(
1087 "Argument must be either a string or a list of strings",
1088 ));
1089 }
1090 Ok(())
1091 }
1092 fn deactivate_all(&self) {
1095 self.0.deactivate_all();
1096 }
1097 #[pyo3(signature = (arg, *, strict=true))]
1116 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1117 if let Ok(string_arg) = arg.extract::<String>() {
1118 if strict {
1119 self.0.isolate_strict(&string_arg)?;
1120 } else {
1121 self.0.isolate(&string_arg);
1122 }
1123 } else if let Ok(list_arg) = arg.cast::<PyList>() {
1124 let vec: Vec<String> = list_arg.extract()?;
1125 if strict {
1126 self.0.isolate_many_strict(&vec)?;
1127 } else {
1128 self.0.isolate_many(&vec);
1129 }
1130 } else {
1131 return Err(PyTypeError::new_err(
1132 "Argument must be either a string or a list of strings",
1133 ));
1134 }
1135 Ok(())
1136 }
1137
1138 #[getter]
1140 fn active_mask(&self) -> Vec<bool> {
1141 self.0.active_mask()
1142 }
1143
1144 fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
1146 self.0.set_active_mask(&mask)?;
1147 Ok(())
1148 }
1149
1150 #[getter]
1152 fn compiled_expression(&self) -> PyCompiledExpression {
1153 PyCompiledExpression(self.0.compiled_expression())
1154 }
1155
1156 #[getter]
1158 fn expression(&self) -> PyExpression {
1159 PyExpression(self.0.expression())
1160 }
1161
1162 #[pyo3(signature = (parameters, *, threads=None))]
1184 fn evaluate<'py>(
1185 &self,
1186 py: Python<'py>,
1187 parameters: Vec<f64>,
1188 threads: Option<usize>,
1189 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1190 let values = install_with_threads(threads, || self.0.evaluate(¶meters))?;
1191 Ok(PyArray1::from_slice(py, &values?))
1192 }
1193 #[pyo3(signature = (parameters, indices, *, threads=None))]
1217 fn evaluate_batch<'py>(
1218 &self,
1219 py: Python<'py>,
1220 parameters: Vec<f64>,
1221 indices: Vec<usize>,
1222 threads: Option<usize>,
1223 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1224 let values =
1225 install_with_threads(threads, || self.0.evaluate_batch(¶meters, &indices))?;
1226 Ok(PyArray1::from_slice(py, &values?))
1227 }
1228 #[pyo3(signature = (parameters, *, threads=None))]
1251 fn evaluate_gradient<'py>(
1252 &self,
1253 py: Python<'py>,
1254 parameters: Vec<f64>,
1255 threads: Option<usize>,
1256 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1257 let gradients: LadduResult<_> = install_with_threads(threads, || {
1258 Ok(self
1259 .0
1260 .evaluate_gradient(¶meters)?
1261 .iter()
1262 .map(|grad| grad.data.as_vec().to_vec())
1263 .collect::<Vec<Vec<Complex64>>>())
1264 })?;
1265 Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1266 }
1267 #[pyo3(signature = (parameters, indices, *, threads=None))]
1292 fn evaluate_gradient_batch<'py>(
1293 &self,
1294 py: Python<'py>,
1295 parameters: Vec<f64>,
1296 indices: Vec<usize>,
1297 threads: Option<usize>,
1298 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1299 let gradients: LadduResult<_> = install_with_threads(threads, || {
1300 Ok(self
1301 .0
1302 .evaluate_gradient_batch(¶meters, &indices)?
1303 .iter()
1304 .map(|grad| grad.data.as_vec().to_vec())
1305 .collect::<Vec<Vec<Complex64>>>())
1306 })?;
1307 Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1308 }
1309}
1310
1311#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
1318#[derive(Clone)]
1319pub struct PyCompiledExpression(pub CompiledExpression);
1320
1321#[pymethods]
1322impl PyCompiledExpression {
1323 fn __str__(&self) -> String {
1324 format!("{}", self.0)
1325 }
1326 fn __repr__(&self) -> String {
1327 format!("{:?}", self.0)
1328 }
1329}
1330
1331#[pyclass(name = "ParameterMap", module = "laddu", from_py_object)]
1332#[derive(Clone)]
1333pub struct PyParameterMap(pub ParameterMap);
1334
1335#[pymethods]
1336impl PyParameterMap {
1337 #[getter]
1338 fn names<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
1339 PyTuple::new(py, self.0.names())
1340 }
1341
1342 #[getter]
1343 fn free(&self) -> PyParameterMap {
1344 PyParameterMap(self.0.free())
1345 }
1346
1347 #[getter]
1348 fn fixed(&self) -> PyParameterMap {
1349 PyParameterMap(self.0.fixed())
1350 }
1351
1352 fn contains(&self, name: &str) -> bool {
1353 self.0.contains_key(name)
1354 }
1355
1356 fn __contains__(&self, name: &str) -> bool {
1357 self.contains(name)
1358 }
1359
1360 fn __len__(&self) -> usize {
1361 self.0.len()
1362 }
1363
1364 fn __bool__(&self) -> bool {
1365 !self.0.is_empty()
1366 }
1367
1368 fn __getitem__(&self, index: &Bound<'_, PyAny>) -> PyResult<PyParameter> {
1369 if let Ok(name) = index.extract::<String>() {
1370 return self
1371 .0
1372 .get(&name)
1373 .cloned()
1374 .map(PyParameter)
1375 .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err(name));
1376 }
1377 if let Ok(index) = index.extract::<isize>() {
1378 let len = self.0.len() as isize;
1379 let normalized = if index < 0 { len + index } else { index };
1380 if normalized < 0 || normalized >= len {
1381 return Err(pyo3::exceptions::PyIndexError::new_err(format!(
1382 "parameter index out of range: {index}"
1383 )));
1384 }
1385 return Ok(PyParameter(self.0[normalized as usize].clone()));
1386 }
1387 Err(pyo3::exceptions::PyTypeError::new_err(
1388 "ParameterMap indices must be str or int",
1389 ))
1390 }
1391
1392 fn __str__(&self) -> String {
1393 self.0.to_string()
1394 }
1395
1396 fn __repr__(&self) -> String {
1397 format!("{:?}", self.0)
1398 }
1399
1400 fn index(&self, name: &str) -> PyResult<usize> {
1401 self.0
1402 .index(name)
1403 .ok_or_else(|| PyValueError::new_err(format!("parameter not found: {name}")))
1404 }
1405
1406 fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
1407 let params: Vec<PyParameter> = self.0.clone().into_iter().map(PyParameter).collect();
1408 let list = PyList::new(py, params)?;
1409 PyIterator::from_object(&list)
1410 }
1411}
1412
1413#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
1414#[derive(Clone)]
1415pub struct PyParameter(pub Parameter);
1416
1417#[pymethods]
1418impl PyParameter {
1419 #[getter]
1420 fn name(&self) -> String {
1421 self.0.name()
1422 }
1423 #[getter]
1424 fn fixed(&self) -> Option<f64> {
1425 self.0.fixed()
1426 }
1427 #[getter]
1428 fn initial(&self) -> Option<f64> {
1429 self.0.initial()
1430 }
1431 #[getter]
1432 fn bounds(&self) -> (Option<f64>, Option<f64>) {
1433 self.0.bounds()
1434 }
1435 #[getter]
1436 fn unit(&self) -> Option<String> {
1437 self.0.unit()
1438 }
1439 #[getter]
1440 fn latex(&self) -> Option<String> {
1441 self.0.latex()
1442 }
1443 #[getter]
1444 fn description(&self) -> Option<String> {
1445 self.0.description()
1446 }
1447}
1448
1449#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
1481pub fn py_parameter(
1482 name: &str,
1483 fixed: Option<f64>,
1484 initial: Option<f64>,
1485 bounds: (Option<f64>, Option<f64>),
1486 unit: Option<&str>,
1487 latex: Option<&str>,
1488 description: Option<&str>,
1489) -> PyParameter {
1490 let par = Parameter::new(name);
1491 if let Some(value) = initial {
1492 par.set_initial(value);
1493 }
1494 if let Some(value) = fixed {
1495 par.set_fixed_value(Some(value)); }
1497 par.set_bounds(bounds.0, bounds.1);
1498 if let Some(unit) = unit {
1499 par.set_unit(unit);
1500 }
1501 if let Some(latex) = latex {
1502 par.set_latex(latex);
1503 }
1504 if let Some(description) = description {
1505 par.set_description(description);
1506 }
1507 PyParameter(par)
1508}
1509
1510#[pyfunction(name = "TestAmplitude", signature = (*tags, re, im))]
1512pub fn py_test_amplitude(
1513 tags: &Bound<'_, PyTuple>,
1514 re: PyParameter,
1515 im: PyParameter,
1516) -> PyResult<PyExpression> {
1517 Ok(PyExpression(TestAmplitude::new(
1518 py_tags(tags)?,
1519 re.0,
1520 im.0,
1521 )?))
1522}