Skip to main content

laddu_python/
quantum.rs

1pub mod angular_momentum {
2    use laddu_core::{
3        allowed_projections, helicity_combinations, AngularMomentum, AngularMomentumProjection,
4        LadduError, LadduResult, OrbitalAngularMomentum,
5    };
6    use num::rational::Ratio;
7    use pyo3::{
8        prelude::*,
9        types::{PyAny, PyBool, PyModule},
10        IntoPyObjectExt,
11    };
12    type PyQuantumNumber = Py<PyAny>;
13    type PyHelicityCombination = (PyQuantumNumber, PyQuantumNumber, PyQuantumNumber);
14
15    pub fn parse_angular_momentum(input: &Bound<'_, PyAny>) -> PyResult<AngularMomentum> {
16        Ok(parse_ratio_like(input).and_then(AngularMomentum::from_ratio)?)
17    }
18
19    fn parse_ratio_like(input: &Bound<'_, PyAny>) -> LadduResult<Ratio<i32>> {
20        if input.is_instance_of::<PyBool>() {
21            return Err(LadduError::Custom(
22                "quantum number cannot be a bool".to_string(),
23            ));
24        }
25        if let Ok(value) = input.extract::<i32>() {
26            return Ok(Ratio::from_integer(value));
27        }
28        if let Ok(value) = input.extract::<f64>() {
29            let twice = AngularMomentumProjection::from_f64(value)?.value();
30            return Ok(Ratio::new(twice, 2));
31        }
32        let numerator = input
33            .getattr("numerator")
34            .and_then(|value| value.extract::<i32>());
35        let denominator = input
36            .getattr("denominator")
37            .and_then(|value| value.extract::<i32>());
38        if let (Ok(numerator), Ok(denominator)) = (numerator, denominator) {
39            if denominator == 0 {
40                return Err(LadduError::Custom(
41                    "quantum number denominator cannot be zero".to_string(),
42                ));
43            }
44            return Ok(Ratio::new(numerator, denominator));
45        }
46        Err(LadduError::Custom(
47            "quantum number must be an int, float, or fractions.Fraction".to_string(),
48        ))
49    }
50
51    pub fn parse_projection(input: &Bound<'_, PyAny>) -> PyResult<AngularMomentumProjection> {
52        Ok(parse_ratio_like(input).and_then(AngularMomentumProjection::from_ratio)?)
53    }
54
55    pub fn parse_orbital_angular_momentum(
56        input: &Bound<'_, PyAny>,
57    ) -> PyResult<OrbitalAngularMomentum> {
58        Ok(parse_ratio_like(input).and_then(OrbitalAngularMomentum::from_ratio)?)
59    }
60
61    pub fn projection_to_python(
62        py: Python<'_>,
63        projection: AngularMomentumProjection,
64    ) -> PyResult<PyQuantumNumber> {
65        let twice = projection.value();
66        if twice % 2 == 0 {
67            Ok((twice / 2).into_bound_py_any(py)?.unbind())
68        } else {
69            let fractions = PyModule::import(py, "fractions")?;
70            let fraction = fractions.getattr("Fraction")?;
71            Ok(fraction.call1((twice, 2))?.unbind())
72        }
73    }
74
75    /// Enumerate allowed spin projections.
76    #[pyfunction(name = "allowed_projections")]
77    pub fn py_allowed_projections(
78        py: Python<'_>,
79        spin: &Bound<'_, PyAny>,
80    ) -> PyResult<Vec<PyQuantumNumber>> {
81        allowed_projections(parse_angular_momentum(spin)?)
82            .into_iter()
83            .map(|projection| projection_to_python(py, projection))
84            .collect()
85    }
86
87    /// Enumerate daughter helicity combinations.
88    #[pyfunction(name = "helicity_combinations")]
89    pub fn py_helicity_combinations(
90        py: Python<'_>,
91        spin_1: &Bound<'_, PyAny>,
92        spin_2: &Bound<'_, PyAny>,
93    ) -> PyResult<Vec<PyHelicityCombination>> {
94        helicity_combinations(
95            parse_angular_momentum(spin_1)?,
96            parse_angular_momentum(spin_2)?,
97        )
98        .into_iter()
99        .map(|combination| {
100            Ok((
101                projection_to_python(py, combination.lambda_1())?,
102                projection_to_python(py, combination.lambda_2())?,
103                projection_to_python(py, combination.helicity())?,
104            ))
105        })
106        .collect()
107    }
108}