laddu_amplitudes/
common.rs

1use laddu_core::{
2    amplitudes::{Amplitude, AmplitudeID, ParameterLike},
3    data::Event,
4    resources::{Cache, ParameterID, Parameters, Resources},
5    Complex, DVector, Float, LadduError,
6};
7use serde::{Deserialize, Serialize};
8
9#[cfg(feature = "python")]
10use laddu_python::amplitudes::{PyAmplitude, PyParameterLike};
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13
/// A scalar-valued [`Amplitude`] which just contains a single parameter as its value.
///
/// When computed, the real part is the parameter's current value and the imaginary part is
/// always zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct Scalar {
    // NOTE: field names and order feed the derived serde implementation; changing them changes
    // the serialized form of this amplitude.
    /// Name passed to `Resources::register_amplitude` during registration.
    name: String,
    /// The parameter whose value this amplitude returns.
    value: ParameterLike,
    /// ID returned by `Resources::register_parameter` for `value`; holds `Default::default()`
    /// until `Amplitude::register` is called.
    pid: ParameterID,
}
21
22impl Scalar {
23    /// Create a new [`Scalar`] with the given name and parameter value.
24    pub fn new(name: &str, value: ParameterLike) -> Box<Self> {
25        Self {
26            name: name.to_string(),
27            value,
28            pid: Default::default(),
29        }
30        .into()
31    }
32}
33
34#[typetag::serde]
35impl Amplitude for Scalar {
36    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
37        self.pid = resources.register_parameter(&self.value);
38        resources.register_amplitude(&self.name)
39    }
40
41    fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
42        Complex::new(parameters.get(self.pid), 0.0)
43    }
44
45    fn compute_gradient(
46        &self,
47        _parameters: &Parameters,
48        _event: &Event,
49        _cache: &Cache,
50        gradient: &mut DVector<Complex<Float>>,
51    ) {
52        if let ParameterID::Parameter(ind) = self.pid {
53            gradient[ind] = Complex::ONE;
54        }
55    }
56}
57
/// An Amplitude which represents a single real-valued scalar
///
/// The Amplitude evaluates to the parameter's value; its imaginary part is always zero.
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// value : laddu.ParameterLike
///     The scalar parameter contained in the Amplitude
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[cfg(feature = "python")]
#[pyfunction(name = "Scalar")]
pub fn py_scalar(name: &str, value: PyParameterLike) -> PyAmplitude {
    PyAmplitude(Scalar::new(name, value.0))
}
81
/// A complex-valued [`Amplitude`] which just contains two parameters representing its real and
/// imaginary parts.
///
/// When computed, the amplitude evaluates to `re + i·im` using the current parameter values.
#[derive(Clone, Serialize, Deserialize)]
pub struct ComplexScalar {
    // NOTE: field names and order feed the derived serde implementation; changing them changes
    // the serialized form of this amplitude.
    /// Name passed to `Resources::register_amplitude` during registration.
    name: String,
    /// Parameter supplying the real part.
    re: ParameterLike,
    /// ID of `re`; holds `Default::default()` until `Amplitude::register` is called.
    pid_re: ParameterID,
    /// Parameter supplying the imaginary part.
    im: ParameterLike,
    /// ID of `im`; holds `Default::default()` until `Amplitude::register` is called.
    pid_im: ParameterID,
}
92
93impl ComplexScalar {
94    /// Create a new [`ComplexScalar`] with the given name, real, and imaginary part.
95    pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
96        Self {
97            name: name.to_string(),
98            re,
99            pid_re: Default::default(),
100            im,
101            pid_im: Default::default(),
102        }
103        .into()
104    }
105}
106
107#[typetag::serde]
108impl Amplitude for ComplexScalar {
109    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
110        self.pid_re = resources.register_parameter(&self.re);
111        self.pid_im = resources.register_parameter(&self.im);
112        resources.register_amplitude(&self.name)
113    }
114
115    fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
116        Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
117    }
118
119    fn compute_gradient(
120        &self,
121        _parameters: &Parameters,
122        _event: &Event,
123        _cache: &Cache,
124        gradient: &mut DVector<Complex<Float>>,
125    ) {
126        if let ParameterID::Parameter(ind) = self.pid_re {
127            gradient[ind] = Complex::ONE;
128        }
129        if let ParameterID::Parameter(ind) = self.pid_im {
130            gradient[ind] = Complex::I;
131        }
132    }
133}
134
/// An Amplitude which represents a complex value
///
/// The Amplitude evaluates to ``re + 1j * im``.
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// re : laddu.ParameterLike
///     The real part of the complex value contained in the Amplitude
/// im : laddu.ParameterLike
///     The imaginary part of the complex value contained in the Amplitude
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[cfg(feature = "python")]
#[pyfunction(name = "ComplexScalar")]
pub fn py_complex_scalar(name: &str, re: PyParameterLike, im: PyParameterLike) -> PyAmplitude {
    PyAmplitude(ComplexScalar::new(name, re.0, im.0))
}
160
/// A complex-valued [`Amplitude`] which just contains two parameters representing its magnitude and
/// phase.
///
/// When computed, the amplitude evaluates to `r·e^{iθ}` via `Complex::from_polar`.
#[derive(Clone, Serialize, Deserialize)]
pub struct PolarComplexScalar {
    // NOTE: field names and order feed the derived serde implementation; changing them changes
    // the serialized form of this amplitude.
    /// Name passed to `Resources::register_amplitude` during registration.
    name: String,
    /// Parameter supplying the magnitude `r`.
    r: ParameterLike,
    /// ID of `r`; holds `Default::default()` until `Amplitude::register` is called.
    pid_r: ParameterID,
    /// Parameter supplying the phase `theta` (passed to `Complex::from_polar` / `Complex::cis`).
    theta: ParameterLike,
    /// ID of `theta`; holds `Default::default()` until `Amplitude::register` is called.
    pid_theta: ParameterID,
}
171
172impl PolarComplexScalar {
173    /// Create a new [`PolarComplexScalar`] with the given name, magnitude (`r`), and phase (`theta`).
174    pub fn new(name: &str, r: ParameterLike, theta: ParameterLike) -> Box<Self> {
175        Self {
176            name: name.to_string(),
177            r,
178            pid_r: Default::default(),
179            theta,
180            pid_theta: Default::default(),
181        }
182        .into()
183    }
184}
185
186#[typetag::serde]
187impl Amplitude for PolarComplexScalar {
188    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
189        self.pid_r = resources.register_parameter(&self.r);
190        self.pid_theta = resources.register_parameter(&self.theta);
191        resources.register_amplitude(&self.name)
192    }
193
194    fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
195        Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta))
196    }
197
198    fn compute_gradient(
199        &self,
200        parameters: &Parameters,
201        _event: &Event,
202        _cache: &Cache,
203        gradient: &mut DVector<Complex<Float>>,
204    ) {
205        let exp_i_theta = Complex::cis(parameters.get(self.pid_theta));
206        if let ParameterID::Parameter(ind) = self.pid_r {
207            gradient[ind] = exp_i_theta;
208        }
209        if let ParameterID::Parameter(ind) = self.pid_theta {
210            gradient[ind] = Complex::<Float>::I
211                * Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta));
212        }
213    }
214}
215
/// An Amplitude which represents a complex scalar value in polar form
///
/// The Amplitude evaluates to ``r * exp(1j * theta)``.
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// r : laddu.ParameterLike
///     The magnitude of the complex value contained in the Amplitude
/// theta : laddu.ParameterLike
///     The argument of the complex value contained in the Amplitude (in radians)
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[cfg(feature = "python")]
#[pyfunction(name = "PolarComplexScalar")]
pub fn py_polar_complex_scalar(
    name: &str,
    r: PyParameterLike,
    theta: PyParameterLike,
) -> PyAmplitude {
    PyAmplitude(PolarComplexScalar::new(name, r.0, theta.0))
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, parameter, Manager, PI};
    use std::sync::Arc;

    #[test]
    fn test_scalar_creation_and_evaluation() {
        let mut manager = Manager::default();
        let scalar_id = manager
            .register(Scalar::new("test_scalar", parameter("test_param")))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        // Evaluate the bare amplitude, so the output is just the parameter value.
        let expression = scalar_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&dataset);

        let value = 2.5;
        let inputs = vec![value];
        let output = evaluator.evaluate(&inputs);

        assert_relative_eq!(output[0].re, value);
        assert_relative_eq!(output[0].im, 0.0);
    }

    #[test]
    fn test_scalar_gradient() {
        let mut manager = Manager::default();
        let scalar_id = manager
            .register(Scalar::new("test_scalar", parameter("test_param")))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let expression = scalar_id.norm_sqr();
        let model = manager.model(&expression);
        let evaluator = model.load(&dataset);

        let x = 2.0;
        let inputs = vec![x];
        let grad = evaluator.evaluate_gradient(&inputs);

        // d/dx |x|^2 = 2x
        assert_relative_eq!(grad[0][0].re, 2.0 * x);
        assert_relative_eq!(grad[0][0].im, 0.0);
    }

    #[test]
    fn test_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let complex_id = manager
            .register(ComplexScalar::new(
                "test_complex",
                parameter("re_param"),
                parameter("im_param"),
            ))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let expression = complex_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&dataset);

        let (re, im) = (1.5, 2.5);
        let inputs = vec![re, im];
        let output = evaluator.evaluate(&inputs);

        assert_relative_eq!(output[0].re, re);
        assert_relative_eq!(output[0].im, im);
    }

    #[test]
    fn test_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let complex_id = manager
            .register(ComplexScalar::new(
                "test_complex",
                parameter("re_param"),
                parameter("im_param"),
            ))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let expression = complex_id.norm_sqr();
        let model = manager.model(&expression);
        let evaluator = model.load(&dataset);

        let (re, im) = (3.0, 4.0);
        let inputs = vec![re, im];
        let grad = evaluator.evaluate_gradient(&inputs);

        // |x + iy|^2 = x^2 + y^2, whose partial derivatives are 2x and 2y.
        assert_relative_eq!(grad[0][0].re, 2.0 * re);
        assert_relative_eq!(grad[0][0].im, 0.0);
        assert_relative_eq!(grad[0][1].re, 2.0 * im);
        assert_relative_eq!(grad[0][1].im, 0.0);
    }

    #[test]
    fn test_polar_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let polar_id = manager
            .register(PolarComplexScalar::new(
                "test_polar",
                parameter("r_param"),
                parameter("theta_param"),
            ))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let expression = polar_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let inputs = vec![r, theta];
        let output = evaluator.evaluate(&inputs);

        // r·e^{iθ} = r·cos(θ) + i·r·sin(θ)
        assert_relative_eq!(output[0].re, r * theta.cos());
        assert_relative_eq!(output[0].im, r * theta.sin());
    }

    #[test]
    fn test_polar_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let polar_id = manager
            .register(PolarComplexScalar::new(
                "test_polar",
                parameter("r_param"),
                parameter("theta_param"),
            ))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        // f(r, θ) = r·e^{iθ}
        let expression = polar_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let inputs = vec![r, theta];
        let grad = evaluator.evaluate_gradient(&inputs);

        // ∂f/∂r = e^{iθ} and ∂f/∂θ = i·r·e^{iθ}
        assert_relative_eq!(grad[0][0].re, theta.cos());
        assert_relative_eq!(grad[0][0].im, theta.sin());
        assert_relative_eq!(grad[0][1].re, -r * theta.sin());
        assert_relative_eq!(grad[0][1].im, r * theta.cos());
    }
}