laddu_amplitudes/
common.rs

1use laddu_core::{
2    amplitudes::{Amplitude, AmplitudeID, ParameterLike},
3    data::Event,
4    resources::{Cache, ParameterID, Parameters, Resources},
5    Float, LadduError,
6};
7#[cfg(feature = "python")]
8use laddu_python::amplitudes::{PyAmplitude, PyParameterLike};
9use nalgebra::DVector;
10use num::Complex;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13use serde::{Deserialize, Serialize};
14
15/// A scalar-valued [`Amplitude`] which just contains a single parameter as its value.
16#[derive(Clone, Serialize, Deserialize)]
17pub struct Scalar {
18    name: String,
19    value: ParameterLike,
20    pid: ParameterID,
21}
22
23impl Scalar {
24    /// Create a new [`Scalar`] with the given name and parameter value.
25    pub fn new(name: &str, value: ParameterLike) -> Box<Self> {
26        Self {
27            name: name.to_string(),
28            value,
29            pid: Default::default(),
30        }
31        .into()
32    }
33}
34
35#[typetag::serde]
36impl Amplitude for Scalar {
37    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
38        self.pid = resources.register_parameter(&self.value);
39        resources.register_amplitude(&self.name)
40    }
41
42    fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
43        Complex::new(parameters.get(self.pid), 0.0)
44    }
45
46    fn compute_gradient(
47        &self,
48        _parameters: &Parameters,
49        _event: &Event,
50        _cache: &Cache,
51        gradient: &mut DVector<Complex<Float>>,
52    ) {
53        if let ParameterID::Parameter(ind) = self.pid {
54            gradient[ind] = Complex::ONE;
55        }
56    }
57}
58
/// An Amplitude whose value is a single real scalar parameter
///
/// Parameters
/// ----------
/// name : str
///     The name of the Amplitude
/// value : laddu.ParameterLike
///     The parameter whose value the Amplitude takes (as its real part)
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[cfg(feature = "python")]
#[pyfunction(name = "Scalar")]
pub fn py_scalar(name: &str, value: PyParameterLike) -> PyAmplitude {
    let amplitude = Scalar::new(name, value.0);
    PyAmplitude(amplitude)
}
82
83/// A complex-valued [`Amplitude`] which just contains two parameters representing its real and
84/// imaginary parts.
85#[derive(Clone, Serialize, Deserialize)]
86pub struct ComplexScalar {
87    name: String,
88    re: ParameterLike,
89    pid_re: ParameterID,
90    im: ParameterLike,
91    pid_im: ParameterID,
92}
93
94impl ComplexScalar {
95    /// Create a new [`ComplexScalar`] with the given name, real, and imaginary part.
96    pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
97        Self {
98            name: name.to_string(),
99            re,
100            pid_re: Default::default(),
101            im,
102            pid_im: Default::default(),
103        }
104        .into()
105    }
106}
107
108#[typetag::serde]
109impl Amplitude for ComplexScalar {
110    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
111        self.pid_re = resources.register_parameter(&self.re);
112        self.pid_im = resources.register_parameter(&self.im);
113        resources.register_amplitude(&self.name)
114    }
115
116    fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
117        Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
118    }
119
120    fn compute_gradient(
121        &self,
122        _parameters: &Parameters,
123        _event: &Event,
124        _cache: &Cache,
125        gradient: &mut DVector<Complex<Float>>,
126    ) {
127        if let ParameterID::Parameter(ind) = self.pid_re {
128            gradient[ind] = Complex::ONE;
129        }
130        if let ParameterID::Parameter(ind) = self.pid_im {
131            gradient[ind] = Complex::I;
132        }
133    }
134}
135
/// An Amplitude which represents a complex value
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// re : laddu.ParameterLike
///     The real part of the complex value contained in the Amplitude
/// im : laddu.ParameterLike
///     The imaginary part of the complex value contained in the Amplitude
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[cfg(feature = "python")]
#[pyfunction(name = "ComplexScalar")]
pub fn py_complex_scalar(name: &str, re: PyParameterLike, im: PyParameterLike) -> PyAmplitude {
    PyAmplitude(ComplexScalar::new(name, re.0, im.0))
}
161
162/// A complex-valued [`Amplitude`] which just contains two parameters representing its magnitude and
163/// phase.
164#[derive(Clone, Serialize, Deserialize)]
165pub struct PolarComplexScalar {
166    name: String,
167    r: ParameterLike,
168    pid_r: ParameterID,
169    theta: ParameterLike,
170    pid_theta: ParameterID,
171}
172
173impl PolarComplexScalar {
174    /// Create a new [`PolarComplexScalar`] with the given name, magnitude (`r`), and phase (`theta`).
175    pub fn new(name: &str, r: ParameterLike, theta: ParameterLike) -> Box<Self> {
176        Self {
177            name: name.to_string(),
178            r,
179            pid_r: Default::default(),
180            theta,
181            pid_theta: Default::default(),
182        }
183        .into()
184    }
185}
186
187#[typetag::serde]
188impl Amplitude for PolarComplexScalar {
189    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
190        self.pid_r = resources.register_parameter(&self.r);
191        self.pid_theta = resources.register_parameter(&self.theta);
192        resources.register_amplitude(&self.name)
193    }
194
195    fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
196        Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta))
197    }
198
199    fn compute_gradient(
200        &self,
201        parameters: &Parameters,
202        _event: &Event,
203        _cache: &Cache,
204        gradient: &mut DVector<Complex<Float>>,
205    ) {
206        let exp_i_theta = Complex::cis(parameters.get(self.pid_theta));
207        if let ParameterID::Parameter(ind) = self.pid_r {
208            gradient[ind] = exp_i_theta;
209        }
210        if let ParameterID::Parameter(ind) = self.pid_theta {
211            gradient[ind] = Complex::<Float>::I
212                * Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta));
213        }
214    }
215}
216
/// An Amplitude which represents a complex scalar value in polar form
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// r : laddu.ParameterLike
///     The magnitude of the complex value contained in the Amplitude
/// theta : laddu.ParameterLike
///     The argument (phase) of the complex value contained in the Amplitude
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// See Also
/// --------
/// laddu.Manager
///
#[cfg(feature = "python")]
#[pyfunction(name = "PolarComplexScalar")]
pub fn py_polar_complex_scalar(
    name: &str,
    r: PyParameterLike,
    theta: PyParameterLike,
) -> PyAmplitude {
    PyAmplitude(PolarComplexScalar::new(name, r.0, theta.0))
}
246
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, parameter, Manager, PI};
    use std::sync::Arc;

    #[test]
    fn test_scalar_creation_and_evaluation() {
        let mut manager = Manager::default();
        let scalar = Scalar::new("test_scalar", parameter("test_param"));
        let amp_id = manager.register(scalar).unwrap();
        let data = Arc::new(test_dataset());

        // Evaluate the bare amplitude, so the output is just the parameter value.
        let expression = amp_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&data);
        let inputs = vec![2.5];
        let values = evaluator.evaluate(&inputs);

        assert_relative_eq!(values[0].re, 2.5);
        assert_relative_eq!(values[0].im, 0.0);
    }

    #[test]
    fn test_scalar_gradient() {
        let mut manager = Manager::default();
        let scalar = Scalar::new("test_scalar", parameter("test_param"));
        let amp_id = manager.register(scalar).unwrap();
        let data = Arc::new(test_dataset());

        // g(x) = |f(x)|^2 = x^2 with f(x) = x, so dg/dx = 2x = 4 at x = 2.
        let expression = amp_id.norm_sqr();
        let model = manager.model(&expression);
        let evaluator = model.load(&data);
        let inputs = vec![2.0];
        let grads = evaluator.evaluate_gradient(&inputs);

        assert_relative_eq!(grads[0][0].re, 4.0);
        assert_relative_eq!(grads[0][0].im, 0.0);
    }

    #[test]
    fn test_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let complex = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
        let amp_id = manager.register(complex).unwrap();
        let data = Arc::new(test_dataset());

        let expression = amp_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&data);
        // Inputs are ordered (real part, imaginary part).
        let inputs = vec![1.5, 2.5];
        let values = evaluator.evaluate(&inputs);

        assert_relative_eq!(values[0].re, 1.5);
        assert_relative_eq!(values[0].im, 2.5);
    }

    #[test]
    fn test_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let complex = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
        let amp_id = manager.register(complex).unwrap();
        let data = Arc::new(test_dataset());

        // g(x, y) = |x + iy|^2 = x^2 + y^2, so the gradient is (2x, 2y) = (6, 8) at (3, 4).
        let expression = amp_id.norm_sqr();
        let model = manager.model(&expression);
        let evaluator = model.load(&data);
        let inputs = vec![3.0, 4.0];
        let grads = evaluator.evaluate_gradient(&inputs);

        assert_relative_eq!(grads[0][0].re, 6.0);
        assert_relative_eq!(grads[0][0].im, 0.0);
        assert_relative_eq!(grads[0][1].re, 8.0);
        assert_relative_eq!(grads[0][1].im, 0.0);
    }

    #[test]
    fn test_polar_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let polar =
            PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
        let amp_id = manager.register(polar).unwrap();
        let data = Arc::new(test_dataset());

        let expression = amp_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&data);
        let (r, theta) = (2.0, PI / 4.3);
        let inputs = vec![r, theta];
        let values = evaluator.evaluate(&inputs);

        // Expected value: r * (cos(theta) + i sin(theta)).
        assert_relative_eq!(values[0].re, r * theta.cos());
        assert_relative_eq!(values[0].im, r * theta.sin());
    }

    #[test]
    fn test_polar_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let polar =
            PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
        let amp_id = manager.register(polar).unwrap();
        let data = Arc::new(test_dataset());

        // f(r, theta) = r e^(i theta), differentiated directly (no norm).
        let expression = amp_id.into();
        let model = manager.model(&expression);
        let evaluator = model.load(&data);
        let (r, theta) = (2.0, PI / 4.3);
        let inputs = vec![r, theta];
        let grads = evaluator.evaluate_gradient(&inputs);

        // df/dr = e^(i theta); df/dtheta = i r e^(i theta).
        assert_relative_eq!(grads[0][0].re, Float::cos(theta));
        assert_relative_eq!(grads[0][0].im, Float::sin(theta));
        assert_relative_eq!(grads[0][1].re, -r * Float::sin(theta));
        assert_relative_eq!(grads[0][1].im, r * Float::cos(theta));
    }
}
375}