1use serde::{Deserialize, Serialize};
2
3use laddu_core::{
4 amplitudes::{Amplitude, AmplitudeID, ParameterLike},
5 data::Event,
6 resources::{Cache, ParameterID, Parameters, Resources},
7 traits::Variable,
8 utils::get_bin_index,
9 Complex, DVector, Float, LadduError, ScalarID,
10};
11
12#[cfg(feature = "python")]
13use laddu_python::{
14 amplitudes::{PyAmplitude, PyParameterLike},
15 utils::variables::PyVariable,
16};
17#[cfg(feature = "python")]
18use pyo3::prelude::*;
19
/// An [`Amplitude`] that is piecewise-constant in a [`Variable`]: every bin of the variable's
/// range carries its own real-valued parameter.
///
/// Events whose variable value is not assigned a bin evaluate to zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct PiecewiseScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Variable whose per-event value selects the bin.
    variable: Box<dyn Variable>,
    /// Number of bins (checked against `values.len()` in `new`).
    bins: usize,
    /// `(lower, upper)` limits of the binned region, forwarded to `get_bin_index`.
    range: (Float, Float),
    /// One parameter per bin, in bin order.
    values: Vec<ParameterLike>,
    /// Parameter IDs for `values`, filled in by `register`.
    pids: Vec<ParameterID>,
    /// Cache slot holding each event's bin index (the sentinel `bins + 1` when no bin applies).
    bin_index: ScalarID,
}
31impl PiecewiseScalar {
32 pub fn new<V: Variable + 'static>(
34 name: &str,
35 variable: &V,
36 bins: usize,
37 range: (Float, Float),
38 values: Vec<ParameterLike>,
39 ) -> Box<Self> {
40 assert_eq!(
41 bins,
42 values.len(),
43 "Number of bins must match number of parameters!"
44 );
45 Self {
46 name: name.to_string(),
47 variable: dyn_clone::clone_box(variable),
48 bins,
49 range,
50 values,
51 pids: Default::default(),
52 bin_index: Default::default(),
53 }
54 .into()
55 }
56}
57
58#[typetag::serde]
59impl Amplitude for PiecewiseScalar {
60 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
61 self.pids = self
62 .values
63 .iter()
64 .map(|value| resources.register_parameter(value))
65 .collect();
66 self.bin_index = resources.register_scalar(None);
67 resources.register_amplitude(&self.name)
68 }
69
70 fn precompute(&self, event: &Event, cache: &mut Cache) {
71 let maybe_bin_index = get_bin_index(self.variable.value(event), self.bins, self.range);
72 if let Some(bin_index) = maybe_bin_index {
73 cache.store_scalar(self.bin_index, bin_index as Float);
74 } else {
75 cache.store_scalar(self.bin_index, (self.bins + 1) as Float);
76 }
78 }
79
80 fn compute(&self, parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
81 let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
82 if bin_index == self.bins + 1 {
83 Complex::ZERO
84 } else {
85 Complex::from(parameters.get(self.pids[bin_index]))
86 }
87 }
88
89 fn compute_gradient(
90 &self,
91 _parameters: &Parameters,
92 _event: &Event,
93 cache: &Cache,
94 gradient: &mut DVector<Complex<Float>>,
95 ) {
96 let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
97 if bin_index < self.bins + 1 {
98 gradient[bin_index] = Complex::ONE;
99 }
100 }
101}
102
#[cfg(feature = "python")]
/// Python constructor for a piecewise-constant real scalar amplitude.
///
/// `variable` must be extractable as a `PyVariable` (otherwise the extraction error is
/// returned), and `values` supplies one parameter per bin. A length mismatch between `bins`
/// and `values` panics inside `PiecewiseScalar::new`.
#[pyfunction(name = "PiecewiseScalar")]
pub fn py_piecewise_scalar(
    name: &str,
    variable: Bound<'_, PyAny>,
    bins: usize,
    range: (Float, Float),
    values: Vec<PyParameterLike>,
) -> PyResult<PyAmplitude> {
    let variable: PyVariable = variable.extract()?;
    let parameters: Vec<ParameterLike> = values.into_iter().map(|py_value| py_value.0).collect();
    let amplitude = PiecewiseScalar::new(name, &variable, bins, range, parameters);
    Ok(PyAmplitude(amplitude))
}
158
/// An [`Amplitude`] that is piecewise-constant in a [`Variable`]: every bin of the variable's
/// range carries its own complex value, given in Cartesian form as a `(real, imaginary)`
/// parameter pair.
///
/// Events whose variable value is not assigned a bin evaluate to zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct PiecewiseComplexScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Variable whose per-event value selects the bin.
    variable: Box<dyn Variable>,
    /// Number of bins (checked against `re_ims.len()` in `new`).
    bins: usize,
    /// `(lower, upper)` limits of the binned region, forwarded to `get_bin_index`.
    range: (Float, Float),
    /// One `(real, imaginary)` parameter pair per bin, in bin order.
    re_ims: Vec<(ParameterLike, ParameterLike)>,
    /// Parameter IDs for `re_ims`, filled in by `register`.
    pids_re_im: Vec<(ParameterID, ParameterID)>,
    /// Cache slot holding each event's bin index (the sentinel `bins + 1` when no bin applies).
    bin_index: ScalarID,
}
171impl PiecewiseComplexScalar {
172 pub fn new<V: Variable + 'static>(
174 name: &str,
175 variable: &V,
176 bins: usize,
177 range: (Float, Float),
178 re_ims: Vec<(ParameterLike, ParameterLike)>,
179 ) -> Box<Self> {
180 assert_eq!(
181 bins,
182 re_ims.len(),
183 "Number of bins must match number of parameters!"
184 );
185 Self {
186 name: name.to_string(),
187 variable: dyn_clone::clone_box(variable),
188 bins,
189 range,
190 re_ims,
191 pids_re_im: Default::default(),
192 bin_index: Default::default(),
193 }
194 .into()
195 }
196}
197
198#[typetag::serde]
199impl Amplitude for PiecewiseComplexScalar {
200 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
201 self.pids_re_im = self
202 .re_ims
203 .iter()
204 .map(|(re, im)| {
205 (
206 resources.register_parameter(re),
207 resources.register_parameter(im),
208 )
209 })
210 .collect();
211 self.bin_index = resources.register_scalar(None);
212 resources.register_amplitude(&self.name)
213 }
214
215 fn precompute(&self, event: &Event, cache: &mut Cache) {
216 let maybe_bin_index = get_bin_index(self.variable.value(event), self.bins, self.range);
217 if let Some(bin_index) = maybe_bin_index {
218 cache.store_scalar(self.bin_index, bin_index as Float);
219 } else {
220 cache.store_scalar(self.bin_index, (self.bins + 1) as Float);
221 }
223 }
224
225 fn compute(&self, parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
226 let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
227 if bin_index == self.bins + 1 {
228 Complex::ZERO
229 } else {
230 let pid_re_im = self.pids_re_im[bin_index];
231 Complex::new(parameters.get(pid_re_im.0), parameters.get(pid_re_im.1))
232 }
233 }
234
235 fn compute_gradient(
236 &self,
237 _parameters: &Parameters,
238 _event: &Event,
239 cache: &Cache,
240 gradient: &mut DVector<Complex<Float>>,
241 ) {
242 let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
243 if bin_index < self.bins + 1 {
244 let pid_re_im = self.pids_re_im[bin_index];
245 if let ParameterID::Parameter(ind) = pid_re_im.0 {
246 gradient[ind] = Complex::ONE;
247 }
248 if let ParameterID::Parameter(ind) = pid_re_im.1 {
249 gradient[ind] = Complex::I;
250 }
251 }
252 }
253}
254
#[cfg(feature = "python")]
/// Python constructor for a piecewise-constant complex amplitude in Cartesian form.
///
/// `variable` must be extractable as a `PyVariable` (otherwise the extraction error is
/// returned), and `values` supplies one `(real, imaginary)` parameter pair per bin. A length
/// mismatch between `bins` and `values` panics inside `PiecewiseComplexScalar::new`.
#[pyfunction(name = "PiecewiseComplexScalar")]
pub fn py_piecewise_complex_scalar(
    name: &str,
    variable: Bound<'_, PyAny>,
    bins: usize,
    range: (Float, Float),
    values: Vec<(PyParameterLike, PyParameterLike)>,
) -> PyResult<PyAmplitude> {
    let variable: PyVariable = variable.extract()?;
    let re_ims: Vec<(ParameterLike, ParameterLike)> = values
        .into_iter()
        .map(|(re, im)| (re.0, im.0))
        .collect();
    let amplitude = PiecewiseComplexScalar::new(name, &variable, bins, range, re_ims);
    Ok(PyAmplitude(amplitude))
}
314
/// An [`Amplitude`] that is piecewise-constant in a [`Variable`]: every bin of the variable's
/// range carries its own complex value, given in polar form as a `(magnitude, phase)`
/// parameter pair and evaluated as `r·e^{iθ}`.
///
/// Events whose variable value is not assigned a bin evaluate to zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct PiecewisePolarComplexScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Variable whose per-event value selects the bin.
    variable: Box<dyn Variable>,
    /// Number of bins (checked against `r_thetas.len()` in `new`).
    bins: usize,
    /// `(lower, upper)` limits of the binned region, forwarded to `get_bin_index`.
    range: (Float, Float),
    /// One `(magnitude, phase)` parameter pair per bin, in bin order.
    r_thetas: Vec<(ParameterLike, ParameterLike)>,
    /// Parameter IDs for `r_thetas`, filled in by `register`.
    pids_r_theta: Vec<(ParameterID, ParameterID)>,
    /// Cache slot holding each event's bin index (the sentinel `bins + 1` when no bin applies).
    bin_index: ScalarID,
}
327impl PiecewisePolarComplexScalar {
328 pub fn new<V: Variable + 'static>(
330 name: &str,
331 variable: &V,
332 bins: usize,
333 range: (Float, Float),
334 r_thetas: Vec<(ParameterLike, ParameterLike)>,
335 ) -> Box<Self> {
336 assert_eq!(
337 bins,
338 r_thetas.len(),
339 "Number of bins must match number of parameters!"
340 );
341 Self {
342 name: name.to_string(),
343 variable: dyn_clone::clone_box(variable),
344 bins,
345 range,
346 r_thetas,
347 pids_r_theta: Default::default(),
348 bin_index: Default::default(),
349 }
350 .into()
351 }
352}
353
354#[typetag::serde]
355impl Amplitude for PiecewisePolarComplexScalar {
356 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
357 self.pids_r_theta = self
358 .r_thetas
359 .iter()
360 .map(|(r, theta)| {
361 (
362 resources.register_parameter(r),
363 resources.register_parameter(theta),
364 )
365 })
366 .collect();
367 self.bin_index = resources.register_scalar(None);
368 resources.register_amplitude(&self.name)
369 }
370
371 fn precompute(&self, event: &Event, cache: &mut Cache) {
372 let maybe_bin_index = get_bin_index(self.variable.value(event), self.bins, self.range);
373 if let Some(bin_index) = maybe_bin_index {
374 cache.store_scalar(self.bin_index, bin_index as Float);
375 } else {
376 cache.store_scalar(self.bin_index, (self.bins + 1) as Float);
377 }
379 }
380
381 fn compute(&self, parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
382 let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
383 if bin_index == self.bins + 1 {
384 Complex::ZERO
385 } else {
386 let pid_r_theta = self.pids_r_theta[bin_index];
387 Complex::from_polar(parameters.get(pid_r_theta.0), parameters.get(pid_r_theta.1))
388 }
389 }
390
391 fn compute_gradient(
392 &self,
393 parameters: &Parameters,
394 _event: &Event,
395 cache: &Cache,
396 gradient: &mut DVector<Complex<Float>>,
397 ) {
398 let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
399 if bin_index < self.bins + 1 {
400 let pid_r_theta = self.pids_r_theta[bin_index];
401 let r = parameters.get(pid_r_theta.0);
402 let theta = parameters.get(pid_r_theta.1);
403 let exp_i_theta = Complex::cis(theta);
404 if let ParameterID::Parameter(ind) = pid_r_theta.0 {
405 gradient[ind] = exp_i_theta;
406 }
407 if let ParameterID::Parameter(ind) = pid_r_theta.1 {
408 gradient[ind] = Complex::<Float>::I * Complex::from_polar(r, theta);
409 }
410 }
411 }
412}
413
#[cfg(feature = "python")]
/// Python constructor for a piecewise-constant complex amplitude in polar form.
///
/// `variable` must be extractable as a `PyVariable` (otherwise the extraction error is
/// returned), and `values` supplies one `(magnitude, phase)` parameter pair per bin. A length
/// mismatch between `bins` and `values` panics inside `PiecewisePolarComplexScalar::new`.
#[pyfunction(name = "PiecewisePolarComplexScalar")]
pub fn py_piecewise_polar_complex_scalar(
    name: &str,
    variable: Bound<'_, PyAny>,
    bins: usize,
    range: (Float, Float),
    values: Vec<(PyParameterLike, PyParameterLike)>,
) -> PyResult<PyAmplitude> {
    let variable: PyVariable = variable.extract()?;
    let r_thetas: Vec<(ParameterLike, ParameterLike)> = values
        .into_iter()
        .map(|(r, theta)| (r.0, theta.0))
        .collect();
    let amplitude = PiecewisePolarComplexScalar::new(name, &variable, bins, range, r_thetas);
    Ok(PyAmplitude(amplitude))
}
473
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, parameter, Manager, Mass, PI};
    use std::sync::Arc;

    // Every test bins `Mass::new([2])` over (0.0, 1.0) in three bins. The assertions below
    // expect the first event of `test_dataset()` to land in the middle bin (index 1).

    #[test]
    fn test_piecewise_scalar_creation_and_evaluation() {
        let mut manager = Manager::default();
        let v = Mass::new([2]);
        let amp = PiecewiseScalar::new(
            "test_scalar",
            &v,
            3,
            (0.0, 1.0),
            vec![
                parameter("test_param0"),
                parameter("test_param1"),
                parameter("test_param2"),
            ],
        );
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![1.1, 2.2, 3.3];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 2.2);
        assert_relative_eq!(result[0].im, 0.0);
    }

    #[test]
    fn test_piecewise_scalar_gradient() {
        let mut manager = Manager::default();
        let v = Mass::new([2]);
        let amp = PiecewiseScalar::new(
            "test_scalar",
            &v,
            3,
            (0.0, 1.0),
            vec![
                parameter("test_param0"),
                parameter("test_param1"),
                parameter("test_param2"),
            ],
        );
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![1.0, 2.0, 3.0];
        let gradient = evaluator.evaluate_gradient(&params);

        // d|A|^2/dp1 = 2·p1 = 4 for the occupied bin; other bins do not contribute.
        assert_relative_eq!(gradient[0][0].re, 0.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 4.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 0.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
    }

    #[test]
    fn test_piecewise_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let v = Mass::new([2]);
        let amp = PiecewiseComplexScalar::new(
            "test_complex",
            &v,
            3,
            (0.0, 1.0),
            vec![
                (parameter("re_param0"), parameter("im_param0")),
                (parameter("re_param1"), parameter("im_param1")),
                (parameter("re_param2"), parameter("im_param2")),
            ],
        );
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![1.1, 1.2, 2.1, 2.2, 3.1, 3.2];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 2.1);
        assert_relative_eq!(result[0].im, 2.2);
    }

    #[test]
    fn test_piecewise_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let v = Mass::new([2]);
        let amp = PiecewiseComplexScalar::new(
            "test_complex",
            &v,
            3,
            (0.0, 1.0),
            vec![
                (parameter("re_param0"), parameter("im_param0")),
                (parameter("re_param1"), parameter("im_param1")),
                (parameter("re_param2"), parameter("im_param2")),
            ],
        );
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![1.1, 1.2, 2.1, 2.2, 3.1, 3.2];
        let gradient = evaluator.evaluate_gradient(&params);

        // |A|^2 = re^2 + im^2, so the occupied bin gives 2·re = 4.2 and 2·im = 4.4.
        assert_relative_eq!(gradient[0][0].re, 0.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 0.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 4.2);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 4.4);
        assert_relative_eq!(gradient[0][3].im, 0.0);
        assert_relative_eq!(gradient[0][4].re, 0.0);
        assert_relative_eq!(gradient[0][4].im, 0.0);
        assert_relative_eq!(gradient[0][5].re, 0.0);
        assert_relative_eq!(gradient[0][5].im, 0.0);
    }

    #[test]
    fn test_piecewise_polar_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let v = Mass::new([2]);
        let amp = PiecewisePolarComplexScalar::new(
            "test_polar",
            &v,
            3,
            (0.0, 1.0),
            vec![
                (parameter("r_param0"), parameter("theta_param0")),
                (parameter("r_param1"), parameter("theta_param1")),
                (parameter("r_param2"), parameter("theta_param2")),
            ],
        );
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let params = vec![
            1.1 * r,
            1.2 * theta,
            2.1 * r,
            2.2 * theta,
            3.1 * r,
            3.2 * theta,
        ];
        let result = evaluator.evaluate(&params);

        assert_relative_eq!(result[0].re, 2.1 * r * (2.2 * theta).cos());
        assert_relative_eq!(result[0].im, 2.1 * r * (2.2 * theta).sin());
    }

    #[test]
    fn test_piecewise_polar_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let v = Mass::new([2]);
        let amp = PiecewisePolarComplexScalar::new(
            "test_polar",
            &v,
            3,
            (0.0, 1.0),
            vec![
                (parameter("r_param0"), parameter("theta_param0")),
                (parameter("r_param1"), parameter("theta_param1")),
                (parameter("r_param2"), parameter("theta_param2")),
            ],
        );
        let aid = manager.register(amp).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r = 2.0;
        let theta = PI / 4.3;
        let params = vec![
            1.1 * r,
            1.2 * theta,
            2.1 * r,
            2.2 * theta,
            3.1 * r,
            3.2 * theta,
        ];
        let gradient = evaluator.evaluate_gradient(&params);

        // A = r·e^{iθ}: ∂A/∂r = e^{iθ}, ∂A/∂θ = i·r·e^{iθ} for the occupied bin.
        assert_relative_eq!(gradient[0][0].re, 0.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 0.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, Float::cos(2.2 * theta));
        assert_relative_eq!(gradient[0][2].im, Float::sin(2.2 * theta));
        assert_relative_eq!(gradient[0][3].re, -2.1 * r * Float::sin(2.2 * theta));
        assert_relative_eq!(gradient[0][3].im, 2.1 * r * Float::cos(2.2 * theta));
        assert_relative_eq!(gradient[0][4].re, 0.0);
        assert_relative_eq!(gradient[0][4].im, 0.0);
        assert_relative_eq!(gradient[0][5].re, 0.0);
        assert_relative_eq!(gradient[0][5].im, 0.0);
    }
}