1use serde::{Deserialize, Serialize};
2
3use laddu_core::{
4 amplitudes::{Amplitude, AmplitudeID},
5 data::Event,
6 resources::{Cache, ComplexScalarID, Parameters, Resources},
7 utils::{
8 functions::spherical_harmonic,
9 variables::{Angles, Variable},
10 },
11 Complex, DVector, Float, LadduError, Polarization, Sign,
12};
13
14#[cfg(feature = "python")]
15use laddu_python::{
16 amplitudes::PyAmplitude,
17 utils::variables::{PyAngles, PyPolarization},
18};
19#[cfg(feature = "python")]
20use pyo3::prelude::*;
21
/// An [`Amplitude`] built from a spherical harmonic that is rotated by the
/// polarization angle and weighted by the polarization magnitude.
///
/// The value depends only on event data: it is evaluated once per event in
/// `precompute`, stored in the cache, and read back in `compute`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Zlm {
    /// Name under which the amplitude is registered with the resources.
    name: String,
    /// First index handed to `spherical_harmonic`.
    l: usize,
    /// Second index handed to `spherical_harmonic`.
    m: isize,
    /// Selects which of the two polarization weights multiplies the real part
    /// and which multiplies the imaginary part of the rotated harmonic.
    r: Sign,
    /// Source of the `costheta` and `phi` values evaluated on each event.
    angles: Angles,
    /// Source of the polarization angle and magnitude evaluated on each event.
    polarization: Polarization,
    /// Cache slot for the precomputed value; a default placeholder until
    /// `register` assigns the real one.
    csid: ComplexScalarID,
}
37
38impl Zlm {
39 pub fn new(
42 name: &str,
43 l: usize,
44 m: isize,
45 r: Sign,
46 angles: &Angles,
47 polarization: &Polarization,
48 ) -> Box<Self> {
49 Self {
50 name: name.to_string(),
51 l,
52 m,
53 r,
54 angles: angles.clone(),
55 polarization: polarization.clone(),
56 csid: ComplexScalarID::default(),
57 }
58 .into()
59 }
60}
61
62#[typetag::serde]
63impl Amplitude for Zlm {
64 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
65 self.csid = resources.register_complex_scalar(None);
66 resources.register_amplitude(&self.name)
67 }
68
69 fn precompute(&self, event: &Event, cache: &mut Cache) {
70 let ylm = spherical_harmonic(
71 self.l,
72 self.m,
73 self.angles.costheta.value(event),
74 self.angles.phi.value(event),
75 );
76 let pol_angle = self.polarization.pol_angle.value(event);
77 let pgamma = self.polarization.pol_magnitude.value(event);
78 let phase = Complex::new(Float::cos(-pol_angle), Float::sin(-pol_angle));
79 let zlm = ylm * phase;
80 cache.store_complex_scalar(
81 self.csid,
82 match self.r {
83 Sign::Positive => Complex::new(
84 Float::sqrt(1.0 + pgamma) * zlm.re,
85 Float::sqrt(1.0 - pgamma) * zlm.im,
86 ),
87 Sign::Negative => Complex::new(
88 Float::sqrt(1.0 - pgamma) * zlm.re,
89 Float::sqrt(1.0 + pgamma) * zlm.im,
90 ),
91 },
92 );
93 }
94
95 fn compute(&self, _parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
96 cache.get_complex_scalar(self.csid)
97 }
98
99 fn compute_gradient(
100 &self,
101 _parameters: &Parameters,
102 _event: &Event,
103 _cache: &Cache,
104 _gradient: &mut DVector<Complex<Float>>,
105 ) {
106 }
108}
109
/// Construct a ``Zlm`` amplitude.
///
/// ``name`` is the name the amplitude is registered under, ``l`` and ``m`` are
/// the spherical-harmonic indices, ``r`` is a string that is parsed into the
/// sign used to weight the real and imaginary parts, and ``angles`` and
/// ``polarization`` supply the per-event angular and polarization values.
///
/// An error is raised if ``r`` cannot be parsed.
#[cfg(feature = "python")]
#[pyfunction(name = "Zlm")]
pub fn py_zlm(
    name: &str,
    l: usize,
    m: isize,
    r: &str,
    angles: &PyAngles,
    polarization: &PyPolarization,
) -> PyResult<PyAmplitude> {
    // Parse first so a bad string is reported before anything is built.
    // NOTE(review): the set of accepted strings is defined by `Sign`'s
    // `FromStr` impl, which is not visible here.
    let reflectivity: Sign = r.parse()?;
    let amplitude = Zlm::new(name, l, m, reflectivity, &angles.0, &polarization.0);
    Ok(PyAmplitude(amplitude))
}
169
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, Frame, Manager};

    /// Builds the amplitude shared by every test in this module.
    fn sample_zlm() -> Box<Zlm> {
        let angles = Angles::new(0, [1], [2], [2, 3], Frame::Helicity);
        let polarization = Polarization::new(0, [1]);
        Zlm::new("zlm", 1, 1, Sign::Positive, &angles, &polarization)
    }

    /// The amplitude evaluated on the first test event matches the reference
    /// real and imaginary parts.
    #[test]
    fn test_zlm_evaluation() {
        let mut manager = Manager::default();
        let aid = manager.register(sample_zlm()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let values = evaluator.evaluate(&[]);
        let tolerance = Float::EPSILON.sqrt();

        assert_relative_eq!(values[0].re, 0.04284127, epsilon = tolerance);
        assert_relative_eq!(values[0].im, -0.23859638, epsilon = tolerance);
    }

    /// With no free parameters the gradient for the first event is empty.
    #[test]
    fn test_zlm_gradient() {
        let mut manager = Manager::default();
        let aid = manager.register(sample_zlm()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradients = evaluator.evaluate_gradient(&[]);
        assert_eq!(gradients[0].len(), 0);
    }
}