1use laddu_core::{
2 amplitudes::{Amplitude, AmplitudeID, ParameterLike},
3 data::Event,
4 resources::{Cache, ParameterID, Parameters, Resources},
5 Complex, DVector, Float, LadduError,
6};
7use serde::{Deserialize, Serialize};
8
9#[cfg(feature = "python")]
10use laddu_python::amplitudes::{PyAmplitude, PyParameterLike};
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13
14#[derive(Clone, Serialize, Deserialize)]
16pub struct Scalar {
17 name: String,
18 value: ParameterLike,
19 pid: ParameterID,
20}
21
22impl Scalar {
23 pub fn new(name: &str, value: ParameterLike) -> Box<Self> {
25 Self {
26 name: name.to_string(),
27 value,
28 pid: Default::default(),
29 }
30 .into()
31 }
32}
33
34#[typetag::serde]
35impl Amplitude for Scalar {
36 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
37 self.pid = resources.register_parameter(&self.value);
38 resources.register_amplitude(&self.name)
39 }
40
41 fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
42 Complex::new(parameters.get(self.pid), 0.0)
43 }
44
45 fn compute_gradient(
46 &self,
47 _parameters: &Parameters,
48 _event: &Event,
49 _cache: &Cache,
50 gradient: &mut DVector<Complex<Float>>,
51 ) {
52 if let ParameterID::Parameter(ind) = self.pid {
53 gradient[ind] = Complex::ONE;
54 }
55 }
56}
57
/// Builds a [`Scalar`] amplitude from Python arguments and wraps it in a [`PyAmplitude`].
///
/// Exposed to Python under the name `Scalar`.
#[cfg(feature = "python")]
#[pyfunction(name = "Scalar")]
pub fn py_scalar(name: &str, value: PyParameterLike) -> PyAmplitude {
    let amplitude = Scalar::new(name, value.0);
    PyAmplitude(amplitude)
}
81
82#[derive(Clone, Serialize, Deserialize)]
85pub struct ComplexScalar {
86 name: String,
87 re: ParameterLike,
88 pid_re: ParameterID,
89 im: ParameterLike,
90 pid_im: ParameterID,
91}
92
93impl ComplexScalar {
94 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
96 Self {
97 name: name.to_string(),
98 re,
99 pid_re: Default::default(),
100 im,
101 pid_im: Default::default(),
102 }
103 .into()
104 }
105}
106
107#[typetag::serde]
108impl Amplitude for ComplexScalar {
109 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
110 self.pid_re = resources.register_parameter(&self.re);
111 self.pid_im = resources.register_parameter(&self.im);
112 resources.register_amplitude(&self.name)
113 }
114
115 fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
116 Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
117 }
118
119 fn compute_gradient(
120 &self,
121 _parameters: &Parameters,
122 _event: &Event,
123 _cache: &Cache,
124 gradient: &mut DVector<Complex<Float>>,
125 ) {
126 if let ParameterID::Parameter(ind) = self.pid_re {
127 gradient[ind] = Complex::ONE;
128 }
129 if let ParameterID::Parameter(ind) = self.pid_im {
130 gradient[ind] = Complex::I;
131 }
132 }
133}
134
/// Builds a [`ComplexScalar`] amplitude from Python arguments and wraps it in a
/// [`PyAmplitude`].
///
/// Exposed to Python under the name `ComplexScalar`.
#[cfg(feature = "python")]
#[pyfunction(name = "ComplexScalar")]
pub fn py_complex_scalar(name: &str, re: PyParameterLike, im: PyParameterLike) -> PyAmplitude {
    let amplitude = ComplexScalar::new(name, re.0, im.0);
    PyAmplitude(amplitude)
}
160
161#[derive(Clone, Serialize, Deserialize)]
164pub struct PolarComplexScalar {
165 name: String,
166 r: ParameterLike,
167 pid_r: ParameterID,
168 theta: ParameterLike,
169 pid_theta: ParameterID,
170}
171
172impl PolarComplexScalar {
173 pub fn new(name: &str, r: ParameterLike, theta: ParameterLike) -> Box<Self> {
175 Self {
176 name: name.to_string(),
177 r,
178 pid_r: Default::default(),
179 theta,
180 pid_theta: Default::default(),
181 }
182 .into()
183 }
184}
185
186#[typetag::serde]
187impl Amplitude for PolarComplexScalar {
188 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
189 self.pid_r = resources.register_parameter(&self.r);
190 self.pid_theta = resources.register_parameter(&self.theta);
191 resources.register_amplitude(&self.name)
192 }
193
194 fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
195 Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta))
196 }
197
198 fn compute_gradient(
199 &self,
200 parameters: &Parameters,
201 _event: &Event,
202 _cache: &Cache,
203 gradient: &mut DVector<Complex<Float>>,
204 ) {
205 let exp_i_theta = Complex::cis(parameters.get(self.pid_theta));
206 if let ParameterID::Parameter(ind) = self.pid_r {
207 gradient[ind] = exp_i_theta;
208 }
209 if let ParameterID::Parameter(ind) = self.pid_theta {
210 gradient[ind] = Complex::<Float>::I
211 * Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta));
212 }
213 }
214}
215
/// Builds a [`PolarComplexScalar`] amplitude from Python arguments and wraps it in a
/// [`PyAmplitude`].
///
/// Exposed to Python under the name `PolarComplexScalar`.
#[cfg(feature = "python")]
#[pyfunction(name = "PolarComplexScalar")]
pub fn py_polar_complex_scalar(
    name: &str,
    r: PyParameterLike,
    theta: PyParameterLike,
) -> PyAmplitude {
    let amplitude = PolarComplexScalar::new(name, r.0, theta.0);
    PyAmplitude(amplitude)
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, parameter, Manager, PI};
    use std::sync::Arc;

    /// A `Scalar` evaluates to its parameter value with a zero imaginary part.
    #[test]
    fn test_scalar_creation_and_evaluation() {
        let mut manager = Manager::default();
        let amp = Scalar::new("test_scalar", parameter("test_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![2.5];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 2.5);
        assert_relative_eq!(result[0].im, 0.0);
    }

    /// The gradient of the squared norm of a `Scalar` at 2.0 is 4.0.
    #[test]
    fn test_scalar_gradient() {
        let mut manager = Manager::default();
        let amp = Scalar::new("test_scalar", parameter("test_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![2.0];
        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    /// A `ComplexScalar` evaluates to `re + i * im`.
    #[test]
    fn test_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![1.5, 2.5];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 1.5);
        assert_relative_eq!(result[0].im, 2.5);
    }

    /// The gradient of the squared norm of a `ComplexScalar` at (3.0, 4.0) is (6.0, 8.0).
    #[test]
    fn test_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![3.0, 4.0];
        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 6.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 8.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
    }

    /// A `PolarComplexScalar` evaluates to `r * (cos(theta) + i * sin(theta))`.
    #[test]
    fn test_polar_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let amp =
            PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let params = vec![r, theta];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, r * theta.cos());
        assert_relative_eq!(result[0].im, r * theta.sin());
    }

    /// The gradient of a `PolarComplexScalar` is `exp(i * theta)` with respect to `r`
    /// and `i * r * exp(i * theta)` with respect to `theta`.
    #[test]
    fn test_polar_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let amp =
            PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let params = vec![r, theta];
        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, Float::cos(theta));
        assert_relative_eq!(gradient[0][0].im, Float::sin(theta));
        assert_relative_eq!(gradient[0][1].re, -r * Float::sin(theta));
        assert_relative_eq!(gradient[0][1].im, r * Float::cos(theta));
    }
}