1pub mod angular_momentum {
2 use laddu_core::{
3 allowed_projections, helicity_combinations, AngularMomentum, AngularMomentumProjection,
4 LadduError, LadduResult, OrbitalAngularMomentum,
5 };
6 use num::rational::Ratio;
7 use pyo3::{
8 prelude::*,
9 types::{PyAny, PyBool, PyModule},
10 IntoPyObjectExt,
11 };
12 type PyQuantumNumber = Py<PyAny>;
13 type PyHelicityCombination = (PyQuantumNumber, PyQuantumNumber, PyQuantumNumber);
14
15 pub fn parse_angular_momentum(input: &Bound<'_, PyAny>) -> PyResult<AngularMomentum> {
16 Ok(parse_ratio_like(input).and_then(AngularMomentum::from_ratio)?)
17 }
18
19 fn parse_ratio_like(input: &Bound<'_, PyAny>) -> LadduResult<Ratio<i32>> {
20 if input.is_instance_of::<PyBool>() {
21 return Err(LadduError::Custom(
22 "quantum number cannot be a bool".to_string(),
23 ));
24 }
25 if let Ok(value) = input.extract::<i32>() {
26 return Ok(Ratio::from_integer(value));
27 }
28 if let Ok(value) = input.extract::<f64>() {
29 let twice = AngularMomentumProjection::from_f64(value)?.value();
30 return Ok(Ratio::new(twice, 2));
31 }
32 let numerator = input
33 .getattr("numerator")
34 .and_then(|value| value.extract::<i32>());
35 let denominator = input
36 .getattr("denominator")
37 .and_then(|value| value.extract::<i32>());
38 if let (Ok(numerator), Ok(denominator)) = (numerator, denominator) {
39 if denominator == 0 {
40 return Err(LadduError::Custom(
41 "quantum number denominator cannot be zero".to_string(),
42 ));
43 }
44 return Ok(Ratio::new(numerator, denominator));
45 }
46 Err(LadduError::Custom(
47 "quantum number must be an int, float, or fractions.Fraction".to_string(),
48 ))
49 }
50
51 pub fn parse_projection(input: &Bound<'_, PyAny>) -> PyResult<AngularMomentumProjection> {
52 Ok(parse_ratio_like(input).and_then(AngularMomentumProjection::from_ratio)?)
53 }
54
55 pub fn parse_orbital_angular_momentum(
56 input: &Bound<'_, PyAny>,
57 ) -> PyResult<OrbitalAngularMomentum> {
58 Ok(parse_ratio_like(input).and_then(OrbitalAngularMomentum::from_ratio)?)
59 }
60
61 pub fn projection_to_python(
62 py: Python<'_>,
63 projection: AngularMomentumProjection,
64 ) -> PyResult<PyQuantumNumber> {
65 let twice = projection.value();
66 if twice % 2 == 0 {
67 Ok((twice / 2).into_bound_py_any(py)?.unbind())
68 } else {
69 let fractions = PyModule::import(py, "fractions")?;
70 let fraction = fractions.getattr("Fraction")?;
71 Ok(fraction.call1((twice, 2))?.unbind())
72 }
73 }
74
75 #[pyfunction(name = "allowed_projections")]
77 pub fn py_allowed_projections(
78 py: Python<'_>,
79 spin: &Bound<'_, PyAny>,
80 ) -> PyResult<Vec<PyQuantumNumber>> {
81 allowed_projections(parse_angular_momentum(spin)?)
82 .into_iter()
83 .map(|projection| projection_to_python(py, projection))
84 .collect()
85 }
86
87 #[pyfunction(name = "helicity_combinations")]
89 pub fn py_helicity_combinations(
90 py: Python<'_>,
91 spin_1: &Bound<'_, PyAny>,
92 spin_2: &Bound<'_, PyAny>,
93 ) -> PyResult<Vec<PyHelicityCombination>> {
94 helicity_combinations(
95 parse_angular_momentum(spin_1)?,
96 parse_angular_momentum(spin_2)?,
97 )
98 .into_iter()
99 .map(|combination| {
100 Ok((
101 projection_to_python(py, combination.lambda_1())?,
102 projection_to_python(py, combination.lambda_2())?,
103 projection_to_python(py, combination.helicity())?,
104 ))
105 })
106 .collect()
107 }
108}