1use std::{array, collections::HashMap};
2
3use laddu_amplitudes::{
4 angular::{
5 BlattWeisskopf, ClebschGordan, PhotonHelicity, PhotonPolarization, PhotonSDME, PolPhase,
6 Wigner3j, WignerD, Ylm, Zlm,
7 },
8 kmatrix::{
9 KopfKMatrixA0, KopfKMatrixA0Channel, KopfKMatrixA2, KopfKMatrixA2Channel, KopfKMatrixF0,
10 KopfKMatrixF0Channel, KopfKMatrixF2, KopfKMatrixF2Channel, KopfKMatrixPi1,
11 KopfKMatrixPi1Channel, KopfKMatrixRho, KopfKMatrixRhoChannel,
12 },
13 lookup::{LookupAxis, LookupTable},
14 resonance::{BreitWigner, BreitWignerNonRelativistic, Flatte, PhaseSpaceFactor, Voigt},
15 scalar::{ComplexScalar, PolarComplexScalar, Scalar, VariableScalar},
16};
17use laddu_core::{
18 amplitude::{Evaluator, Expression, Parameter, ParameterMap, TestAmplitude},
19 math::{BarrierKind, Sheet, QR_DEFAULT},
20 traits::Variable,
21 CompiledExpression, LadduError, LadduResult, ThreadPoolManager,
22};
23use num::complex::Complex64;
24use numpy::{PyArray1, PyArray2};
25use pyo3::{
26 exceptions::{PyTypeError, PyValueError},
27 prelude::*,
28 types::{PyAny, PyBytes, PyIterator, PyList, PyTuple},
29};
30
31use crate::{
32 data::PyDataset,
33 quantum::angular_momentum::{
34 parse_angular_momentum, parse_orbital_angular_momentum, parse_projection,
35 },
36 variables::{PyAngles, PyDecay, PyMandelstam, PyMass, PyPolarization, PyVariable},
37};
38
/// Boxed variables and validated axes returned by `py_lookup_inputs` and consumed by the
/// `LookupTable*` constructors below.
type LookupInputs = (Vec<Box<dyn Variable>>, Vec<LookupAxis>);
40
// Declares a Python-exposed mirror of a Rust K-matrix channel enum.
//
// Usage: `py_kmatrix_channel!(PyType, "PythonName", rust::Enum { VariantA, VariantB })`.
// Expands to a `#[pyclass]` enum `PyType` (exposed as `laddu.PythonName`) carrying the listed
// variants, plus a `From<PyType>` impl mapping every variant onto the same-named variant of the
// Rust enum.
macro_rules! py_kmatrix_channel {
    (
        $wrapper:ident,
        $exposed_name:literal,
        $target:path { $($channel:ident),+ $(,)? }
    ) => {
        #[pyclass(eq, name = $exposed_name, module = "laddu", from_py_object)]
        #[derive(Clone, PartialEq)]
        pub enum $wrapper {
            $($channel,)+
        }

        impl From<$wrapper> for $target {
            fn from(channel: $wrapper) -> Self {
                // Both enums share variant names, so each arm is a one-to-one rename.
                match channel {
                    $( $wrapper::$channel => Self::$channel, )+
                }
            }
        }
    };
}
58
// Python mirrors of every K-matrix channel enum. Variant order follows the Rust declarations
// listed here; keep it stable so the Python-side enum ordering does not change.

// a0 channels.
py_kmatrix_channel!(
    PyKopfKMatrixA0Channel,
    "KopfKMatrixA0Channel",
    KopfKMatrixA0Channel { PiEta, KKbar }
);
// a2 channels.
py_kmatrix_channel!(
    PyKopfKMatrixA2Channel,
    "KopfKMatrixA2Channel",
    KopfKMatrixA2Channel {
        PiEta,
        KKbar,
        PiEtaPrime
    }
);
// f0 channels.
py_kmatrix_channel!(
    PyKopfKMatrixF0Channel,
    "KopfKMatrixF0Channel",
    KopfKMatrixF0Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta,
        EtaEtaPrime
    }
);
// f2 channels.
py_kmatrix_channel!(
    PyKopfKMatrixF2Channel,
    "KopfKMatrixF2Channel",
    KopfKMatrixF2Channel {
        PiPi,
        FourPi,
        KKbar,
        EtaEta
    }
);
// pi1 channels.
py_kmatrix_channel!(
    PyKopfKMatrixPi1Channel,
    "KopfKMatrixPi1Channel",
    KopfKMatrixPi1Channel { PiEta, PiEtaPrime }
);
// rho channels.
py_kmatrix_channel!(
    PyKopfKMatrixRhoChannel,
    "KopfKMatrixRhoChannel",
    KopfKMatrixRhoChannel {
        PiPi,
        FourPi,
        KKbar
    }
);
108
109fn install_with_threads<R: Send>(
110 threads: Option<usize>,
111 op: impl FnOnce() -> R + Send,
112) -> LadduResult<R> {
113 ThreadPoolManager::shared().install(threads, op)
114}
115
116pub(crate) fn py_tags(tags: &Bound<'_, PyTuple>) -> PyResult<Vec<String>> {
117 tags.iter()
118 .map(|tag| tag.extract::<String>())
119 .collect::<PyResult<Vec<_>>>()
120}
121
/// Python wrapper (`laddu.Expression`) around a core [`Expression`].
///
/// The default pyclass extraction is skipped (`skip_from_py_object`) because the manual
/// `FromPyObject` impl below also accepts plain real and complex numbers.
#[pyclass(name = "Expression", module = "laddu", skip_from_py_object)]
#[derive(Clone)]
pub struct PyExpression(pub Expression);
127
128impl<'py> FromPyObject<'_, 'py> for PyExpression {
129 type Error = PyErr;
130
131 fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
132 if let Ok(obj) = obj.cast::<PyExpression>() {
133 Ok(obj.borrow().clone())
134 } else if let Ok(obj) = obj.extract::<f64>() {
135 Ok(Self(obj.into()))
136 } else if let Ok(obj) = obj.extract::<Complex64>() {
137 Ok(Self(obj.into()))
138 } else {
139 Err(PyTypeError::new_err("Failed to extract Expression"))
140 }
141 }
142}
143
144#[pyfunction(name = "expr_sum")]
147pub fn py_expr_sum(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
148 if terms.is_empty() {
149 return Ok(PyExpression(Expression::zero()));
150 }
151 if terms.len() == 1 {
152 let term = &terms[0];
153 if let Ok(expression) = term.extract::<PyExpression>() {
154 return Ok(expression);
155 }
156 return Err(PyTypeError::new_err("Item is not a PyExpression"));
157 }
158 let mut iter = terms.iter();
159 let Some(first_term) = iter.next() else {
160 return Ok(PyExpression(Expression::zero()));
161 };
162 let PyExpression(mut summation) = first_term
163 .extract::<PyExpression>()
164 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
165 for term in iter {
166 let PyExpression(expr) = term
167 .extract::<PyExpression>()
168 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
169 summation = summation + expr;
170 }
171 Ok(PyExpression(summation))
172}
173
174#[pyfunction(name = "expr_product")]
177pub fn py_expr_product(terms: Vec<Bound<'_, PyAny>>) -> PyResult<PyExpression> {
178 if terms.is_empty() {
179 return Ok(PyExpression(Expression::one()));
180 }
181 if terms.len() == 1 {
182 let term = &terms[0];
183 if let Ok(expression) = term.extract::<PyExpression>() {
184 return Ok(expression);
185 }
186 return Err(PyTypeError::new_err("Item is not a PyExpression"));
187 }
188 let mut iter = terms.iter();
189 let Some(first_term) = iter.next() else {
190 return Ok(PyExpression(Expression::one()));
191 };
192 let PyExpression(mut product) = first_term
193 .extract::<PyExpression>()
194 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
195 for term in iter {
196 let PyExpression(expr) = term
197 .extract::<PyExpression>()
198 .map_err(|_| PyTypeError::new_err("Elements must be PyExpression"))?;
199 product = product * expr;
200 }
201 Ok(PyExpression(product))
202}
203
204#[pyfunction(name = "Zero")]
207pub fn py_expr_zero() -> PyExpression {
208 PyExpression(Expression::zero())
209}
210
211#[pyfunction(name = "One")]
214pub fn py_expr_one() -> PyExpression {
215 PyExpression(Expression::one())
216}
217
218#[pyfunction(name = "Scalar", signature = (*tags, value))]
220pub fn py_scalar(tags: &Bound<'_, PyTuple>, value: PyParameter) -> PyResult<PyExpression> {
221 Ok(PyExpression(Scalar::new(py_tags(tags)?, value.0)?))
222}
223
224#[pyfunction(name = "VariableScalar", signature = (*tags, variable))]
226pub fn py_variable_scalar(
227 tags: &Bound<'_, PyTuple>,
228 variable: Bound<'_, PyAny>,
229) -> PyResult<PyExpression> {
230 let variable = variable.extract::<PyVariable>()?;
231 Ok(PyExpression(VariableScalar::new(
232 py_tags(tags)?,
233 &variable,
234 )?))
235}
236
237#[pyfunction(name = "ComplexScalar", signature = (*tags, re, im))]
239pub fn py_complex_scalar(
240 tags: &Bound<'_, PyTuple>,
241 re: PyParameter,
242 im: PyParameter,
243) -> PyResult<PyExpression> {
244 Ok(PyExpression(ComplexScalar::new(
245 py_tags(tags)?,
246 re.0,
247 im.0,
248 )?))
249}
250
251#[pyfunction(name = "PolarComplexScalar", signature = (*tags, r, theta))]
253pub fn py_polar_complex_scalar(
254 tags: &Bound<'_, PyTuple>,
255 r: PyParameter,
256 theta: PyParameter,
257) -> PyResult<PyExpression> {
258 Ok(PyExpression(PolarComplexScalar::new(
259 py_tags(tags)?,
260 r.0,
261 theta.0,
262 )?))
263}
264
265#[pyfunction(name = "BreitWigner", signature = (*tags, mass, width, l, daughter_1_mass, daughter_2_mass, resonance_mass, barrier_factors=true))]
267#[allow(clippy::too_many_arguments)]
268pub fn py_breit_wigner(
269 tags: &Bound<'_, PyTuple>,
270 mass: PyParameter,
271 width: PyParameter,
272 l: &Bound<'_, PyAny>,
273 daughter_1_mass: &PyMass,
274 daughter_2_mass: &PyMass,
275 resonance_mass: &PyMass,
276 barrier_factors: bool,
277) -> PyResult<PyExpression> {
278 let l = parse_orbital_angular_momentum(l)?.value() as usize;
279 if barrier_factors {
280 Ok(PyExpression(BreitWigner::new(
281 py_tags(tags)?,
282 mass.0,
283 width.0,
284 l,
285 &daughter_1_mass.0,
286 &daughter_2_mass.0,
287 &resonance_mass.0,
288 )?))
289 } else {
290 Ok(PyExpression(BreitWigner::new_without_barrier_factors(
291 py_tags(tags)?,
292 mass.0,
293 width.0,
294 l,
295 &daughter_1_mass.0,
296 &daughter_2_mass.0,
297 &resonance_mass.0,
298 )?))
299 }
300}
301
302#[pyfunction(name = "BreitWignerNonRelativistic", signature = (*tags, mass, width, resonance_mass))]
304pub fn py_breit_wigner_non_relativistic(
305 tags: &Bound<'_, PyTuple>,
306 mass: PyParameter,
307 width: PyParameter,
308 resonance_mass: &PyMass,
309) -> PyResult<PyExpression> {
310 Ok(PyExpression(BreitWignerNonRelativistic::new(
311 py_tags(tags)?,
312 mass.0,
313 width.0,
314 &resonance_mass.0,
315 )?))
316}
317
318#[pyfunction(name = "Flatte", signature = (*tags, mass, observed_channel_coupling, alternate_channel_coupling, observed_channel_daughter_masses, alternate_channel_daughter_masses, resonance_mass))]
320pub fn py_flatte(
321 tags: &Bound<'_, PyTuple>,
322 mass: PyParameter,
323 observed_channel_coupling: PyParameter,
324 alternate_channel_coupling: PyParameter,
325 observed_channel_daughter_masses: (PyMass, PyMass),
326 alternate_channel_daughter_masses: (f64, f64),
327 resonance_mass: &PyMass,
328) -> PyResult<PyExpression> {
329 Ok(PyExpression(Flatte::new(
330 py_tags(tags)?,
331 mass.0,
332 observed_channel_coupling.0,
333 alternate_channel_coupling.0,
334 (
335 &observed_channel_daughter_masses.0 .0,
336 &observed_channel_daughter_masses.1 .0,
337 ),
338 alternate_channel_daughter_masses,
339 &resonance_mass.0,
340 )?))
341}
342
343#[pyfunction(name = "Voigt", signature = (*tags, mass, width, sigma, resonance_mass))]
345pub fn py_voigt(
346 tags: &Bound<'_, PyTuple>,
347 mass: PyParameter,
348 width: PyParameter,
349 sigma: PyParameter,
350 resonance_mass: &PyMass,
351) -> PyResult<PyExpression> {
352 Ok(PyExpression(Voigt::new(
353 py_tags(tags)?,
354 mass.0,
355 width.0,
356 sigma.0,
357 &resonance_mass.0,
358 )?))
359}
360
361#[pyfunction(name = "Ylm", signature = (*tags, l, m, angles))]
363pub fn py_ylm(
364 tags: &Bound<'_, PyTuple>,
365 l: usize,
366 m: isize,
367 angles: &PyAngles,
368) -> PyResult<PyExpression> {
369 Ok(PyExpression(Ylm::new(py_tags(tags)?, l, m, &angles.0)?))
370}
371
372#[pyfunction(name = "Zlm", signature = (*tags, l, m, r, angles, polarization))]
374pub fn py_zlm(
375 tags: &Bound<'_, PyTuple>,
376 l: usize,
377 m: isize,
378 r: &str,
379 angles: &PyAngles,
380 polarization: &PyPolarization,
381) -> PyResult<PyExpression> {
382 Ok(PyExpression(Zlm::new(
383 py_tags(tags)?,
384 l,
385 m,
386 r.parse()?,
387 &angles.0,
388 &polarization.0,
389 )?))
390}
391
392#[pyfunction(name = "PolPhase", signature = (*tags, polarization))]
394pub fn py_polphase(
395 tags: &Bound<'_, PyTuple>,
396 polarization: &PyPolarization,
397) -> PyResult<PyExpression> {
398 Ok(PyExpression(PolPhase::new(
399 py_tags(tags)?,
400 &polarization.0,
401 )?))
402}
403
404#[pyfunction(name = "WignerD", signature = (*tags, spin, row_projection, column_projection, angles))]
406pub fn py_wigner_d(
407 tags: &Bound<'_, PyTuple>,
408 spin: &Bound<'_, PyAny>,
409 row_projection: &Bound<'_, PyAny>,
410 column_projection: &Bound<'_, PyAny>,
411 angles: &PyAngles,
412) -> PyResult<PyExpression> {
413 Ok(PyExpression(WignerD::new(
414 py_tags(tags)?,
415 parse_angular_momentum(spin)?,
416 parse_projection(row_projection)?,
417 parse_projection(column_projection)?,
418 &angles.0,
419 )?))
420}
421
422#[pyfunction(name = "BlattWeisskopf", signature = (*tags, decay, l, reference_mass, q_r = QR_DEFAULT, sheet = "physical", kind = "full"))]
424pub fn py_blatt_weisskopf(
425 tags: &Bound<'_, PyTuple>,
426 decay: &PyDecay,
427 l: &Bound<'_, PyAny>,
428 reference_mass: f64,
429 q_r: f64,
430 sheet: &str,
431 kind: &str,
432) -> PyResult<PyExpression> {
433 let sheet = match sheet.to_ascii_lowercase().as_str() {
434 "physical" => Sheet::Physical,
435 "unphysical" => Sheet::Unphysical,
436 _ => {
437 return Err(PyValueError::new_err(
438 "sheet must be 'physical' or 'unphysical'",
439 ));
440 }
441 };
442 let kind = match kind.to_ascii_lowercase().as_str() {
443 "full" => BarrierKind::Full,
444 "tensor" => BarrierKind::Tensor,
445 _ => {
446 return Err(PyValueError::new_err("kind must be 'full' or 'tensor'"));
447 }
448 };
449 Ok(PyExpression(BlattWeisskopf::new(
450 py_tags(tags)?,
451 &decay.0,
452 parse_orbital_angular_momentum(l)?,
453 reference_mass,
454 q_r,
455 sheet,
456 kind,
457 )?))
458}
459
460#[pyfunction(name = "ClebschGordan", signature = (*tags, j1, m1, j2, m2, j, m))]
462pub fn py_clebsch_gordan(
463 tags: &Bound<'_, PyTuple>,
464 j1: &Bound<'_, PyAny>,
465 m1: &Bound<'_, PyAny>,
466 j2: &Bound<'_, PyAny>,
467 m2: &Bound<'_, PyAny>,
468 j: &Bound<'_, PyAny>,
469 m: &Bound<'_, PyAny>,
470) -> PyResult<PyExpression> {
471 Ok(PyExpression(ClebschGordan::new(
472 py_tags(tags)?,
473 parse_angular_momentum(j1)?,
474 parse_projection(m1)?,
475 parse_angular_momentum(j2)?,
476 parse_projection(m2)?,
477 parse_angular_momentum(j)?,
478 parse_projection(m)?,
479 )?))
480}
481
482#[pyfunction(name = "Wigner3j", signature = (*tags, j1, m1, j2, m2, j3, m3))]
484pub fn py_wigner_3j(
485 tags: &Bound<'_, PyTuple>,
486 j1: &Bound<'_, PyAny>,
487 m1: &Bound<'_, PyAny>,
488 j2: &Bound<'_, PyAny>,
489 m2: &Bound<'_, PyAny>,
490 j3: &Bound<'_, PyAny>,
491 m3: &Bound<'_, PyAny>,
492) -> PyResult<PyExpression> {
493 Ok(PyExpression(Wigner3j::new(
494 py_tags(tags)?,
495 parse_angular_momentum(j1)?,
496 parse_projection(m1)?,
497 parse_angular_momentum(j2)?,
498 parse_projection(m2)?,
499 parse_angular_momentum(j3)?,
500 parse_projection(m3)?,
501 )?))
502}
503
504#[pyfunction(name = "PhotonSDME", signature = (*tags, helicity, helicity_prime, polarization = None))]
506pub fn py_photon_sdme(
507 tags: &Bound<'_, PyTuple>,
508 helicity: i32,
509 helicity_prime: i32,
510 polarization: Option<&PyPolarization>,
511) -> PyResult<PyExpression> {
512 let polarization = polarization
513 .map(|polarization| PhotonPolarization::Linear(Box::new(polarization.0.clone())))
514 .unwrap_or(PhotonPolarization::Unpolarized);
515 Ok(PyExpression(PhotonSDME::new(
516 py_tags(tags)?,
517 polarization,
518 PhotonHelicity::new(helicity)?,
519 PhotonHelicity::new(helicity_prime)?,
520 )?))
521}
522
523#[pyfunction(name = "PhaseSpaceFactor", signature = (*tags, recoil_mass, daughter_1_mass, daughter_2_mass, resonance_mass, mandelstam_s))]
525pub fn py_phase_space_factor(
526 tags: &Bound<'_, PyTuple>,
527 recoil_mass: &PyMass,
528 daughter_1_mass: &PyMass,
529 daughter_2_mass: &PyMass,
530 resonance_mass: &PyMass,
531 mandelstam_s: &PyMandelstam,
532) -> PyResult<PyExpression> {
533 Ok(PyExpression(PhaseSpaceFactor::new(
534 py_tags(tags)?,
535 &recoil_mass.0,
536 &daughter_1_mass.0,
537 &daughter_2_mass.0,
538 &resonance_mass.0,
539 &mandelstam_s.0,
540 )?))
541}
542
543fn py_lookup_inputs(
544 variables: Vec<PyVariable>,
545 axis_coordinates: Vec<Vec<f64>>,
546) -> LadduResult<LookupInputs> {
547 let axis_coordinates = axis_coordinates
548 .into_iter()
549 .map(LookupAxis::new)
550 .collect::<LadduResult<Vec<_>>>()?;
551 let variables = variables
552 .into_iter()
553 .map(|variable| Box::new(variable) as Box<dyn Variable>)
554 .collect();
555 Ok((variables, axis_coordinates))
556}
557
558#[pyfunction(name = "LookupTable", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
560pub fn py_lookup_table(
561 tags: &Bound<'_, PyTuple>,
562 variables: Vec<PyVariable>,
563 axis_coordinates: Vec<Vec<f64>>,
564 values: Vec<Complex64>,
565 interpolation: &str,
566 boundary_mode: &str,
567) -> PyResult<PyExpression> {
568 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
569 Ok(PyExpression(LookupTable::new(
570 py_tags(tags)?,
571 variables,
572 axis_coordinates,
573 values,
574 interpolation.parse()?,
575 boundary_mode.parse()?,
576 )?))
577}
578
579#[pyfunction(name = "LookupTableScalar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
581pub fn py_lookup_table_scalar(
582 tags: &Bound<'_, PyTuple>,
583 variables: Vec<PyVariable>,
584 axis_coordinates: Vec<Vec<f64>>,
585 values: Vec<PyParameter>,
586 interpolation: &str,
587 boundary_mode: &str,
588) -> PyResult<PyExpression> {
589 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
590 Ok(PyExpression(LookupTable::new_scalar(
591 py_tags(tags)?,
592 variables,
593 axis_coordinates,
594 values.into_iter().map(|value| value.0).collect(),
595 interpolation.parse()?,
596 boundary_mode.parse()?,
597 )?))
598}
599
600#[pyfunction(name = "LookupTableComplex", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
602pub fn py_lookup_table_complex(
603 tags: &Bound<'_, PyTuple>,
604 variables: Vec<PyVariable>,
605 axis_coordinates: Vec<Vec<f64>>,
606 values: Vec<(PyParameter, PyParameter)>,
607 interpolation: &str,
608 boundary_mode: &str,
609) -> PyResult<PyExpression> {
610 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
611 Ok(PyExpression(LookupTable::new_cartesian_complex(
612 py_tags(tags)?,
613 variables,
614 axis_coordinates,
615 values
616 .into_iter()
617 .map(|(value_re, value_im)| (value_re.0, value_im.0))
618 .collect(),
619 interpolation.parse()?,
620 boundary_mode.parse()?,
621 )?))
622}
623
624#[pyfunction(name = "LookupTablePolar", signature = (*tags, variables, axis_coordinates, values, interpolation = "nearest", boundary_mode = "zero"))]
626pub fn py_lookup_table_polar(
627 tags: &Bound<'_, PyTuple>,
628 variables: Vec<PyVariable>,
629 axis_coordinates: Vec<Vec<f64>>,
630 values: Vec<(PyParameter, PyParameter)>,
631 interpolation: &str,
632 boundary_mode: &str,
633) -> PyResult<PyExpression> {
634 let (variables, axis_coordinates) = py_lookup_inputs(variables, axis_coordinates)?;
635 Ok(PyExpression(LookupTable::new_polar_complex(
636 py_tags(tags)?,
637 variables,
638 axis_coordinates,
639 values
640 .into_iter()
641 .map(|(value_r, value_theta)| (value_r.0, value_theta.0))
642 .collect(),
643 interpolation.parse()?,
644 boundary_mode.parse()?,
645 )?))
646}
647
648#[pyfunction(name = "KopfKMatrixA0", signature = (*tags, couplings, channel, mass, seed = None))]
650pub fn py_kopf_kmatrix_a0(
651 tags: &Bound<'_, PyTuple>,
652 couplings: [[PyParameter; 2]; 2],
653 channel: PyKopfKMatrixA0Channel,
654 mass: PyMass,
655 seed: Option<usize>,
656) -> PyResult<PyExpression> {
657 Ok(PyExpression(KopfKMatrixA0::new(
658 py_tags(tags)?,
659 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
660 channel.into(),
661 &mass.0,
662 seed,
663 )?))
664}
665
666#[pyfunction(name = "KopfKMatrixA2", signature = (*tags, couplings, channel, mass, seed = None))]
668pub fn py_kopf_kmatrix_a2(
669 tags: &Bound<'_, PyTuple>,
670 couplings: [[PyParameter; 2]; 2],
671 channel: PyKopfKMatrixA2Channel,
672 mass: PyMass,
673 seed: Option<usize>,
674) -> PyResult<PyExpression> {
675 Ok(PyExpression(KopfKMatrixA2::new(
676 py_tags(tags)?,
677 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
678 channel.into(),
679 &mass.0,
680 seed,
681 )?))
682}
683
684#[pyfunction(name = "KopfKMatrixF0", signature = (*tags, couplings, channel, mass, seed = None))]
686pub fn py_kopf_kmatrix_f0(
687 tags: &Bound<'_, PyTuple>,
688 couplings: [[PyParameter; 2]; 5],
689 channel: PyKopfKMatrixF0Channel,
690 mass: PyMass,
691 seed: Option<usize>,
692) -> PyResult<PyExpression> {
693 Ok(PyExpression(KopfKMatrixF0::new(
694 py_tags(tags)?,
695 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
696 channel.into(),
697 &mass.0,
698 seed,
699 )?))
700}
701
702#[pyfunction(name = "KopfKMatrixF2", signature = (*tags, couplings, channel, mass, seed = None))]
704pub fn py_kopf_kmatrix_f2(
705 tags: &Bound<'_, PyTuple>,
706 couplings: [[PyParameter; 2]; 4],
707 channel: PyKopfKMatrixF2Channel,
708 mass: PyMass,
709 seed: Option<usize>,
710) -> PyResult<PyExpression> {
711 Ok(PyExpression(KopfKMatrixF2::new(
712 py_tags(tags)?,
713 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
714 channel.into(),
715 &mass.0,
716 seed,
717 )?))
718}
719
720#[pyfunction(name = "KopfKMatrixPi1", signature = (*tags, couplings, channel, mass))]
722pub fn py_kopf_kmatrix_pi1(
723 tags: &Bound<'_, PyTuple>,
724 couplings: [[PyParameter; 2]; 1],
725 channel: PyKopfKMatrixPi1Channel,
726 mass: PyMass,
727) -> PyResult<PyExpression> {
728 Ok(PyExpression(KopfKMatrixPi1::new(
729 py_tags(tags)?,
730 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
731 channel.into(),
732 &mass.0,
733 )?))
734}
735
736#[pyfunction(name = "KopfKMatrixRho", signature = (*tags, couplings, channel, mass))]
738pub fn py_kopf_kmatrix_rho(
739 tags: &Bound<'_, PyTuple>,
740 couplings: [[PyParameter; 2]; 2],
741 channel: PyKopfKMatrixRhoChannel,
742 mass: PyMass,
743) -> PyResult<PyExpression> {
744 Ok(PyExpression(KopfKMatrixRho::new(
745 py_tags(tags)?,
746 array::from_fn(|i| array::from_fn(|j| couplings[i][j].clone().0)),
747 channel.into(),
748 &mass.0,
749 )?))
750}
751
752#[pymethods]
753impl PyExpression {
754 #[getter]
761 fn parameters(&self) -> PyParameterMap {
762 PyParameterMap(self.0.parameters())
763 }
764 #[getter]
766 fn n_free(&self) -> usize {
767 self.0.n_free()
768 }
769 #[getter]
771 fn n_fixed(&self) -> usize {
772 self.0.n_fixed()
773 }
774 #[getter]
776 fn n_parameters(&self) -> usize {
777 self.0.n_parameters()
778 }
779 fn load(&self, dataset: &PyDataset) -> PyResult<PyEvaluator> {
792 Ok(PyEvaluator(self.0.load(&dataset.0)?))
793 }
794 fn real(&self) -> PyExpression {
796 PyExpression(self.0.real())
797 }
798 fn imag(&self) -> PyExpression {
800 PyExpression(self.0.imag())
801 }
802 fn conj(&self) -> PyExpression {
804 PyExpression(self.0.conj())
805 }
806 fn norm_sqr(&self) -> PyExpression {
808 PyExpression(self.0.norm_sqr())
809 }
810 fn sqrt(&self) -> PyExpression {
812 PyExpression(self.0.sqrt())
813 }
814 fn power(&self, power: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
816 if let Ok(value) = power.extract::<i32>() {
817 Ok(PyExpression(self.0.powi(value)))
818 } else if let Ok(value) = power.extract::<f64>() {
819 Ok(PyExpression(self.0.powf(value)))
820 } else if let Ok(expression) = power.extract::<PyExpression>() {
821 Ok(PyExpression(self.0.pow(&expression.0)))
822 } else {
823 Err(PyTypeError::new_err(
824 "power must be an int, float, or Expression",
825 ))
826 }
827 }
828 fn exp(&self) -> PyExpression {
830 PyExpression(self.0.exp())
831 }
832 fn sin(&self) -> PyExpression {
834 PyExpression(self.0.sin())
835 }
836 fn cos(&self) -> PyExpression {
838 PyExpression(self.0.cos())
839 }
840 fn log(&self) -> PyExpression {
842 PyExpression(self.0.log())
843 }
844 fn cis(&self) -> PyExpression {
846 PyExpression(self.0.cis())
847 }
848 fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
850 Ok(self.0.fix_parameter(name, value)?)
851 }
852 fn free_parameter(&self, name: &str) -> PyResult<()> {
854 Ok(self.0.free_parameter(name)?)
855 }
856 fn rename_parameter(&mut self, old: &str, new: &str) -> PyResult<()> {
858 Ok(self.0.rename_parameter(old, new)?)
859 }
860 fn rename_parameters(&mut self, mapping: HashMap<String, String>) -> PyResult<()> {
862 Ok(self.0.rename_parameters(&mapping)?)
863 }
864 #[getter]
866 fn compiled_expression(&self) -> PyCompiledExpression {
867 PyCompiledExpression(self.0.compiled_expression())
868 }
869 fn __add__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
870 if let Ok(other_expr) = other.extract::<PyExpression>() {
871 Ok(PyExpression(self.0.clone() + other_expr.0))
872 } else {
873 Err(PyTypeError::new_err("Unsupported operand type for +"))
874 }
875 }
876 fn __radd__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
877 if let Ok(other_expr) = other.extract::<PyExpression>() {
878 Ok(PyExpression(other_expr.0 + self.0.clone()))
879 } else {
880 Err(PyTypeError::new_err("Unsupported operand type for +"))
881 }
882 }
883 fn __sub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
884 if let Ok(other_expr) = other.extract::<PyExpression>() {
885 Ok(PyExpression(self.0.clone() - other_expr.0))
886 } else {
887 Err(PyTypeError::new_err("Unsupported operand type for -"))
888 }
889 }
890 fn __rsub__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
891 if let Ok(other_expr) = other.extract::<PyExpression>() {
892 Ok(PyExpression(other_expr.0 - self.0.clone()))
893 } else {
894 Err(PyTypeError::new_err("Unsupported operand type for -"))
895 }
896 }
897 fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
898 if let Ok(other_expr) = other.extract::<PyExpression>() {
899 Ok(PyExpression(self.0.clone() * other_expr.0))
900 } else {
901 Err(PyTypeError::new_err("Unsupported operand type for *"))
902 }
903 }
904 fn __rmul__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
905 if let Ok(other_expr) = other.extract::<PyExpression>() {
906 Ok(PyExpression(other_expr.0 * self.0.clone()))
907 } else {
908 Err(PyTypeError::new_err("Unsupported operand type for *"))
909 }
910 }
911 fn __truediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
912 if let Ok(other_expr) = other.extract::<PyExpression>() {
913 Ok(PyExpression(self.0.clone() / other_expr.0))
914 } else {
915 Err(PyTypeError::new_err("Unsupported operand type for /"))
916 }
917 }
918 fn __rtruediv__(&self, other: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
919 if let Ok(other_expr) = other.extract::<PyExpression>() {
920 Ok(PyExpression(other_expr.0 / self.0.clone()))
921 } else {
922 Err(PyTypeError::new_err("Unsupported operand type for /"))
923 }
924 }
925 fn __neg__(&self) -> PyExpression {
926 PyExpression(-self.0.clone())
927 }
928 fn __str__(&self) -> String {
929 format!("{}", self.0)
930 }
931 fn __repr__(&self) -> String {
932 format!("{:?}", self.0)
933 }
934
935 #[new]
936 fn new() -> Self {
937 Self(Expression::default())
938 }
939 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
940 Ok(PyBytes::new(
941 py,
942 serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
943 .map_err(LadduError::PickleError)?
944 .as_slice(),
945 ))
946 }
947 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
948 *self = Self(
949 serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
950 .map_err(LadduError::PickleError)?,
951 );
952 Ok(())
953 }
954}
955
/// Python wrapper (`laddu.Evaluator`) around a core [`Evaluator`], as returned by
/// `Expression.load`.
#[pyclass(name = "Evaluator", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyEvaluator(pub Evaluator);
965
966#[pymethods]
967impl PyEvaluator {
968 #[getter]
976 fn parameters(&self) -> PyParameterMap {
977 PyParameterMap(self.0.parameters())
978 }
979 #[getter]
981 fn n_free(&self) -> usize {
982 self.0.n_free()
983 }
984 #[getter]
986 fn n_fixed(&self) -> usize {
987 self.0.n_fixed()
988 }
989 #[getter]
991 fn n_parameters(&self) -> usize {
992 self.0.n_parameters()
993 }
994 fn fix_parameter(&self, name: &str, value: f64) -> PyResult<()> {
996 Ok(self.0.fix_parameter(name, value)?)
997 }
998 fn free_parameter(&self, name: &str) -> PyResult<()> {
1000 Ok(self.0.free_parameter(name)?)
1001 }
1002 fn rename_parameter(&self, old: &str, new: &str) -> PyResult<()> {
1004 Ok(self.0.rename_parameter(old, new)?)
1005 }
1006 fn rename_parameters(&self, mapping: HashMap<String, String>) -> PyResult<()> {
1008 Ok(self.0.rename_parameters(&mapping)?)
1009 }
1010 #[pyo3(signature = (arg, *, strict=true))]
1027 fn activate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1028 if let Ok(string_arg) = arg.extract::<String>() {
1029 if strict {
1030 self.0.activate_strict(&string_arg)?;
1031 } else {
1032 self.0.activate(&string_arg);
1033 }
1034 } else if let Ok(list_arg) = arg.cast::<PyList>() {
1035 let vec: Vec<String> = list_arg.extract()?;
1036 if strict {
1037 self.0.activate_many_strict(&vec)?;
1038 } else {
1039 self.0.activate_many(&vec);
1040 }
1041 } else {
1042 return Err(PyTypeError::new_err(
1043 "Argument must be either a string or a list of strings",
1044 ));
1045 }
1046 Ok(())
1047 }
1048 fn activate_all(&self) {
1051 self.0.activate_all();
1052 }
1053 #[pyo3(signature = (arg, *, strict=true))]
1072 fn deactivate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1073 if let Ok(string_arg) = arg.extract::<String>() {
1074 if strict {
1075 self.0.deactivate_strict(&string_arg)?;
1076 } else {
1077 self.0.deactivate(&string_arg);
1078 }
1079 } else if let Ok(list_arg) = arg.cast::<PyList>() {
1080 let vec: Vec<String> = list_arg.extract()?;
1081 if strict {
1082 self.0.deactivate_many_strict(&vec)?;
1083 } else {
1084 self.0.deactivate_many(&vec);
1085 }
1086 } else {
1087 return Err(PyTypeError::new_err(
1088 "Argument must be either a string or a list of strings",
1089 ));
1090 }
1091 Ok(())
1092 }
1093 fn deactivate_all(&self) {
1096 self.0.deactivate_all();
1097 }
1098 #[pyo3(signature = (arg, *, strict=true))]
1117 fn isolate(&self, arg: &Bound<'_, PyAny>, strict: bool) -> PyResult<()> {
1118 if let Ok(string_arg) = arg.extract::<String>() {
1119 if strict {
1120 self.0.isolate_strict(&string_arg)?;
1121 } else {
1122 self.0.isolate(&string_arg);
1123 }
1124 } else if let Ok(list_arg) = arg.cast::<PyList>() {
1125 let vec: Vec<String> = list_arg.extract()?;
1126 if strict {
1127 self.0.isolate_many_strict(&vec)?;
1128 } else {
1129 self.0.isolate_many(&vec);
1130 }
1131 } else {
1132 return Err(PyTypeError::new_err(
1133 "Argument must be either a string or a list of strings",
1134 ));
1135 }
1136 Ok(())
1137 }
1138
1139 #[getter]
1141 fn active_mask(&self) -> Vec<bool> {
1142 self.0.active_mask()
1143 }
1144
1145 fn set_active_mask(&self, mask: Vec<bool>) -> PyResult<()> {
1147 self.0.set_active_mask(&mask)?;
1148 Ok(())
1149 }
1150
1151 #[getter]
1153 fn compiled_expression(&self) -> PyCompiledExpression {
1154 PyCompiledExpression(self.0.compiled_expression())
1155 }
1156
1157 #[getter]
1159 fn expression(&self) -> PyExpression {
1160 PyExpression(self.0.expression())
1161 }
1162
1163 #[pyo3(signature = (parameters, *, threads=None))]
1185 fn evaluate<'py>(
1186 &self,
1187 py: Python<'py>,
1188 parameters: Vec<f64>,
1189 threads: Option<usize>,
1190 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1191 let values = install_with_threads(threads, || self.0.evaluate(¶meters))?;
1192 Ok(PyArray1::from_slice(py, &values?))
1193 }
1194 #[pyo3(signature = (parameters, indices, *, threads=None))]
1218 fn evaluate_batch<'py>(
1219 &self,
1220 py: Python<'py>,
1221 parameters: Vec<f64>,
1222 indices: Vec<usize>,
1223 threads: Option<usize>,
1224 ) -> PyResult<Bound<'py, PyArray1<Complex64>>> {
1225 let values =
1226 install_with_threads(threads, || self.0.evaluate_batch(¶meters, &indices))?;
1227 Ok(PyArray1::from_slice(py, &values?))
1228 }
1229 #[pyo3(signature = (parameters, *, threads=None))]
1252 fn evaluate_gradient<'py>(
1253 &self,
1254 py: Python<'py>,
1255 parameters: Vec<f64>,
1256 threads: Option<usize>,
1257 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1258 let gradients: LadduResult<_> = install_with_threads(threads, || {
1259 Ok(self
1260 .0
1261 .evaluate_gradient(¶meters)?
1262 .iter()
1263 .map(|grad| grad.data.as_vec().to_vec())
1264 .collect::<Vec<Vec<Complex64>>>())
1265 })?;
1266 Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1267 }
1268 #[pyo3(signature = (parameters, indices, *, threads=None))]
1293 fn evaluate_gradient_batch<'py>(
1294 &self,
1295 py: Python<'py>,
1296 parameters: Vec<f64>,
1297 indices: Vec<usize>,
1298 threads: Option<usize>,
1299 ) -> PyResult<Bound<'py, PyArray2<Complex64>>> {
1300 let gradients: LadduResult<_> = install_with_threads(threads, || {
1301 Ok(self
1302 .0
1303 .evaluate_gradient_batch(¶meters, &indices)?
1304 .iter()
1305 .map(|grad| grad.data.as_vec().to_vec())
1306 .collect::<Vec<Vec<Complex64>>>())
1307 })?;
1308 Ok(PyArray2::from_vec2(py, &gradients?).map_err(LadduError::NumpyError)?)
1309 }
1310}
1311
/// Python wrapper (`laddu.CompiledExpression`) around a core [`CompiledExpression`], exposed
/// read-only through `compiled_expression` getters.
#[pyclass(name = "CompiledExpression", module = "laddu", from_py_object)]
#[derive(Clone)]
pub struct PyCompiledExpression(pub CompiledExpression);
1321
1322#[pymethods]
1323impl PyCompiledExpression {
1324 fn __str__(&self) -> String {
1325 format!("{}", self.0)
1326 }
1327 fn __repr__(&self) -> String {
1328 format!("{:?}", self.0)
1329 }
1330}
1331
1332#[pyclass(name = "ParameterMap", module = "laddu", from_py_object)]
1333#[derive(Clone)]
1334pub struct PyParameterMap(pub ParameterMap);
1335
1336#[pymethods]
1337impl PyParameterMap {
1338 #[getter]
1339 fn names<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
1340 PyTuple::new(py, self.0.names())
1341 }
1342
1343 #[getter]
1344 fn free(&self) -> PyParameterMap {
1345 PyParameterMap(self.0.free())
1346 }
1347
1348 #[getter]
1349 fn fixed(&self) -> PyParameterMap {
1350 PyParameterMap(self.0.fixed())
1351 }
1352
1353 fn contains(&self, name: &str) -> bool {
1354 self.0.contains_key(name)
1355 }
1356
1357 fn __contains__(&self, name: &str) -> bool {
1358 self.contains(name)
1359 }
1360
1361 fn __len__(&self) -> usize {
1362 self.0.len()
1363 }
1364
1365 fn __bool__(&self) -> bool {
1366 !self.0.is_empty()
1367 }
1368
1369 fn __getitem__(&self, index: &Bound<'_, PyAny>) -> PyResult<PyParameter> {
1370 if let Ok(name) = index.extract::<String>() {
1371 return self
1372 .0
1373 .get(&name)
1374 .cloned()
1375 .map(PyParameter)
1376 .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err(name));
1377 }
1378 if let Ok(index) = index.extract::<isize>() {
1379 let len = self.0.len() as isize;
1380 let normalized = if index < 0 { len + index } else { index };
1381 if normalized < 0 || normalized >= len {
1382 return Err(pyo3::exceptions::PyIndexError::new_err(format!(
1383 "parameter index out of range: {index}"
1384 )));
1385 }
1386 return Ok(PyParameter(self.0[normalized as usize].clone()));
1387 }
1388 Err(pyo3::exceptions::PyTypeError::new_err(
1389 "ParameterMap indices must be str or int",
1390 ))
1391 }
1392
1393 fn __str__(&self) -> String {
1394 self.0.to_string()
1395 }
1396
1397 fn __repr__(&self) -> String {
1398 format!("{:?}", self.0)
1399 }
1400
1401 fn index(&self, name: &str) -> PyResult<usize> {
1402 self.0
1403 .index(name)
1404 .ok_or_else(|| PyValueError::new_err(format!("parameter not found: {name}")))
1405 }
1406
1407 fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyIterator>> {
1408 let params: Vec<PyParameter> = self.0.clone().into_iter().map(PyParameter).collect();
1409 let list = PyList::new(py, params)?;
1410 PyIterator::from_object(&list)
1411 }
1412}
1413
1414#[pyclass(name = "Parameter", module = "laddu", from_py_object)]
1415#[derive(Clone)]
1416pub struct PyParameter(pub Parameter);
1417
1418#[pymethods]
1419impl PyParameter {
1420 #[getter]
1421 fn name(&self) -> String {
1422 self.0.name()
1423 }
1424 #[getter]
1425 fn fixed(&self) -> Option<f64> {
1426 self.0.fixed()
1427 }
1428 #[getter]
1429 fn initial(&self) -> Option<f64> {
1430 self.0.initial()
1431 }
1432 #[getter]
1433 fn bounds(&self) -> (Option<f64>, Option<f64>) {
1434 self.0.bounds()
1435 }
1436 #[getter]
1437 fn unit(&self) -> Option<String> {
1438 self.0.unit()
1439 }
1440 #[getter]
1441 fn latex(&self) -> Option<String> {
1442 self.0.latex()
1443 }
1444 #[getter]
1445 fn description(&self) -> Option<String> {
1446 self.0.description()
1447 }
1448}
1449
1450#[pyfunction(name = "parameter", signature = (name, fixed=None, *, initial=None, bounds=(None, None), unit=None, latex=None, description=None))]
1482pub fn py_parameter(
1483 name: &str,
1484 fixed: Option<f64>,
1485 initial: Option<f64>,
1486 bounds: (Option<f64>, Option<f64>),
1487 unit: Option<&str>,
1488 latex: Option<&str>,
1489 description: Option<&str>,
1490) -> PyParameter {
1491 let par = Parameter::new(name);
1492 if let Some(value) = initial {
1493 par.set_initial(value);
1494 }
1495 if let Some(value) = fixed {
1496 par.set_fixed_value(Some(value)); }
1498 par.set_bounds(bounds.0, bounds.1);
1499 if let Some(unit) = unit {
1500 par.set_unit(unit);
1501 }
1502 if let Some(latex) = latex {
1503 par.set_latex(latex);
1504 }
1505 if let Some(description) = description {
1506 par.set_description(description);
1507 }
1508 PyParameter(par)
1509}
1510
1511#[pyfunction(name = "TestAmplitude", signature = (*tags, re, im))]
1513pub fn py_test_amplitude(
1514 tags: &Bound<'_, PyTuple>,
1515 re: PyParameter,
1516 im: PyParameter,
1517) -> PyResult<PyExpression> {
1518 Ok(PyExpression(TestAmplitude::new(
1519 py_tags(tags)?,
1520 re.0,
1521 im.0,
1522 )?))
1523}