1use laddu_core::{
2 amplitudes::{Amplitude, AmplitudeID, ParameterLike},
3 data::Event,
4 resources::{Cache, ParameterID, Parameters, Resources},
5 Float, LadduError,
6};
7#[cfg(feature = "python")]
8use laddu_python::amplitudes::{PyAmplitude, PyParameterLike};
9use nalgebra::DVector;
10use num::Complex;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13use serde::{Deserialize, Serialize};
14
/// An [`Amplitude`] backed by a single parameter.
///
/// When computed, it yields the parameter's value as the real part of a complex number
/// whose imaginary part is zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct Scalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Parameter definition handed to `Resources::register_parameter` during registration.
    value: ParameterLike,
    /// ID returned for `value` at registration; holds `ParameterID`'s default until then.
    pid: ParameterID,
}
22
23impl Scalar {
24 pub fn new(name: &str, value: ParameterLike) -> Box<Self> {
26 Self {
27 name: name.to_string(),
28 value,
29 pid: Default::default(),
30 }
31 .into()
32 }
33}
34
35#[typetag::serde]
36impl Amplitude for Scalar {
37 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
38 self.pid = resources.register_parameter(&self.value);
39 resources.register_amplitude(&self.name)
40 }
41
42 fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
43 Complex::new(parameters.get(self.pid), 0.0)
44 }
45
46 fn compute_gradient(
47 &self,
48 _parameters: &Parameters,
49 _event: &Event,
50 _cache: &Cache,
51 gradient: &mut DVector<Complex<Float>>,
52 ) {
53 if let ParameterID::Parameter(ind) = self.pid {
54 gradient[ind] = Complex::ONE;
55 }
56 }
57}
58
#[cfg(feature = "python")]
/// Python-facing constructor for the `Scalar` amplitude (exported under the name `Scalar`).
///
/// Takes the amplitude name and the parameter wrapper, builds the Rust amplitude from the
/// wrapped parameter, and returns it wrapped in a `PyAmplitude`.
#[pyfunction(name = "Scalar")]
pub fn py_scalar(name: &str, value: PyParameterLike) -> PyAmplitude {
    let amplitude = Scalar::new(name, value.0);
    PyAmplitude(amplitude)
}
82
/// An [`Amplitude`] backed by two parameters in Cartesian form.
///
/// When computed, it yields a complex number whose real part is the value of `re` and
/// whose imaginary part is the value of `im`.
#[derive(Clone, Serialize, Deserialize)]
pub struct ComplexScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Parameter definition supplying the real part.
    re: ParameterLike,
    /// ID returned for `re` at registration; holds `ParameterID`'s default until then.
    pid_re: ParameterID,
    /// Parameter definition supplying the imaginary part.
    im: ParameterLike,
    /// ID returned for `im` at registration; holds `ParameterID`'s default until then.
    pid_im: ParameterID,
}
93
94impl ComplexScalar {
95 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
97 Self {
98 name: name.to_string(),
99 re,
100 pid_re: Default::default(),
101 im,
102 pid_im: Default::default(),
103 }
104 .into()
105 }
106}
107
108#[typetag::serde]
109impl Amplitude for ComplexScalar {
110 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
111 self.pid_re = resources.register_parameter(&self.re);
112 self.pid_im = resources.register_parameter(&self.im);
113 resources.register_amplitude(&self.name)
114 }
115
116 fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
117 Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
118 }
119
120 fn compute_gradient(
121 &self,
122 _parameters: &Parameters,
123 _event: &Event,
124 _cache: &Cache,
125 gradient: &mut DVector<Complex<Float>>,
126 ) {
127 if let ParameterID::Parameter(ind) = self.pid_re {
128 gradient[ind] = Complex::ONE;
129 }
130 if let ParameterID::Parameter(ind) = self.pid_im {
131 gradient[ind] = Complex::I;
132 }
133 }
134}
135
#[cfg(feature = "python")]
/// Python-facing constructor for the `ComplexScalar` amplitude (exported under the name
/// `ComplexScalar`).
///
/// Takes the amplitude name and the wrappers for the real and imaginary parameters, builds
/// the Rust amplitude from the wrapped parameters, and returns it wrapped in a
/// `PyAmplitude`.
#[pyfunction(name = "ComplexScalar")]
pub fn py_complex_scalar(name: &str, re: PyParameterLike, im: PyParameterLike) -> PyAmplitude {
    let amplitude = ComplexScalar::new(name, re.0, im.0);
    PyAmplitude(amplitude)
}
161
/// An [`Amplitude`] backed by two parameters in polar form.
///
/// When computed, it yields the complex number built by `Complex::from_polar` from the
/// value of `r` (magnitude) and the value of `theta` (phase).
#[derive(Clone, Serialize, Deserialize)]
pub struct PolarComplexScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Parameter definition supplying the magnitude.
    r: ParameterLike,
    /// ID returned for `r` at registration; holds `ParameterID`'s default until then.
    pid_r: ParameterID,
    /// Parameter definition supplying the phase.
    theta: ParameterLike,
    /// ID returned for `theta` at registration; holds `ParameterID`'s default until then.
    pid_theta: ParameterID,
}
172
173impl PolarComplexScalar {
174 pub fn new(name: &str, r: ParameterLike, theta: ParameterLike) -> Box<Self> {
176 Self {
177 name: name.to_string(),
178 r,
179 pid_r: Default::default(),
180 theta,
181 pid_theta: Default::default(),
182 }
183 .into()
184 }
185}
186
187#[typetag::serde]
188impl Amplitude for PolarComplexScalar {
189 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
190 self.pid_r = resources.register_parameter(&self.r);
191 self.pid_theta = resources.register_parameter(&self.theta);
192 resources.register_amplitude(&self.name)
193 }
194
195 fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
196 Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta))
197 }
198
199 fn compute_gradient(
200 &self,
201 parameters: &Parameters,
202 _event: &Event,
203 _cache: &Cache,
204 gradient: &mut DVector<Complex<Float>>,
205 ) {
206 let exp_i_theta = Complex::cis(parameters.get(self.pid_theta));
207 if let ParameterID::Parameter(ind) = self.pid_r {
208 gradient[ind] = exp_i_theta;
209 }
210 if let ParameterID::Parameter(ind) = self.pid_theta {
211 gradient[ind] = Complex::<Float>::I
212 * Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta));
213 }
214 }
215}
216
#[cfg(feature = "python")]
/// Python-facing constructor for the `PolarComplexScalar` amplitude (exported under the
/// name `PolarComplexScalar`).
///
/// Takes the amplitude name and the wrappers for the magnitude and phase parameters, builds
/// the Rust amplitude from the wrapped parameters, and returns it wrapped in a
/// `PyAmplitude`.
#[pyfunction(name = "PolarComplexScalar")]
pub fn py_polar_complex_scalar(
    name: &str,
    r: PyParameterLike,
    theta: PyParameterLike,
) -> PyAmplitude {
    let amplitude = PolarComplexScalar::new(name, r.0, theta.0);
    PyAmplitude(amplitude)
}
246
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, parameter, Manager, PI};
    use std::sync::Arc;

    /// A `Scalar` evaluates to its parameter value with a zero imaginary part.
    #[test]
    fn test_scalar_creation_and_evaluation() {
        let mut manager = Manager::default();
        let amp = Scalar::new("test_scalar", parameter("test_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        // Evaluate the bare amplitude.
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![2.5];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 2.5);
        assert_relative_eq!(result[0].im, 0.0);
    }

    /// The gradient of `|p|^2` with respect to `p` is `2 * p`.
    #[test]
    fn test_scalar_gradient() {
        let mut manager = Manager::default();
        let amp = Scalar::new("test_scalar", parameter("test_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        // Differentiate the squared norm of the amplitude.
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![2.0];
        let gradient = evaluator.evaluate_gradient(&params);

        // d/dp p^2 = 2p = 4 at p = 2.
        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    /// A `ComplexScalar` evaluates to `re + i * im`.
    #[test]
    fn test_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        // Real and imaginary parts, in registration order.
        let params = vec![1.5, 2.5];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 1.5);
        assert_relative_eq!(result[0].im, 2.5);
    }

    /// The gradient of `|re + i * im|^2 = re^2 + im^2` is `(2 * re, 2 * im)`.
    #[test]
    fn test_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        // Differentiate the squared norm of the amplitude.
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        // Real and imaginary parts, in registration order.
        let params = vec![3.0, 4.0];
        let gradient = evaluator.evaluate_gradient(&params);

        // d/d(re) = 2 * re = 6 and d/d(im) = 2 * im = 8; both are purely real.
        assert_relative_eq!(gradient[0][0].re, 6.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 8.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
    }

    /// A `PolarComplexScalar` evaluates to `r * (cos(theta) + i * sin(theta))`.
    #[test]
    fn test_polar_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let amp =
            PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let params = vec![r, theta];
        let result = evaluator.evaluate(&params);

        // Polar to Cartesian conversion.
        assert_relative_eq!(result[0].re, r * theta.cos());
        assert_relative_eq!(result[0].im, r * theta.sin());
    }

    /// The gradient of `r * exp(i * theta)` is `exp(i * theta)` in `r` and
    /// `i * r * exp(i * theta)` in `theta`.
    #[test]
    fn test_polar_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let amp =
            PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        // Differentiate the bare amplitude.
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let params = vec![r, theta];
        let gradient = evaluator.evaluate_gradient(&params);

        // d/dr = cos(theta) + i * sin(theta).
        assert_relative_eq!(gradient[0][0].re, Float::cos(theta));
        assert_relative_eq!(gradient[0][0].im, Float::sin(theta));
        // d/dtheta = -r * sin(theta) + i * r * cos(theta).
        assert_relative_eq!(gradient[0][1].re, -r * Float::sin(theta));
        assert_relative_eq!(gradient[0][1].im, r * Float::cos(theta));
    }
}