// laddu_amplitudes/piecewise.rs

use serde::{Deserialize, Serialize};

use laddu_core::{
    amplitudes::{Amplitude, AmplitudeID, ParameterLike},
    data::Event,
    resources::{Cache, ParameterID, Parameters, Resources},
    traits::Variable,
    utils::get_bin_index,
    Complex, DVector, Float, LadduError, ScalarID,
};

#[cfg(feature = "python")]
use laddu_python::{
    amplitudes::{PyAmplitude, PyParameterLike},
    utils::variables::PyVariable,
};
#[cfg(feature = "python")]
use pyo3::{exceptions::PyAssertionError, prelude::*};

/// A piecewise scalar-valued [`Amplitude`] which just contains a single parameter for each bin as its value.
///
/// Each event is assigned to one of `bins` bins over `range` according to the value of
/// `variable` (via `get_bin_index`). An event in bin `i` evaluates to the real value of
/// `values[i]`. Events for which no bin is found (e.g. outside `range`) evaluate to zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct PiecewiseScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Variable whose per-event value selects the bin.
    variable: Box<dyn Variable>,
    /// Number of bins; equal to `values.len()`.
    bins: usize,
    /// `(min, max)` bin edges passed to `get_bin_index`.
    range: (Float, Float),
    /// One parameter per bin.
    values: Vec<ParameterLike>,
    /// Parameter IDs matching `values`, filled in by `register`.
    pids: Vec<ParameterID>,
    /// Cache slot holding each event's precomputed bin index.
    bin_index: ScalarID,
}
31impl PiecewiseScalar {
32    /// Create a new [`PiecewiseScalar`] with the given name and parameter value.
33    pub fn new<V: Variable + 'static>(
34        name: &str,
35        variable: &V,
36        bins: usize,
37        range: (Float, Float),
38        values: Vec<ParameterLike>,
39    ) -> Box<Self> {
40        assert_eq!(
41            bins,
42            values.len(),
43            "Number of bins must match number of parameters!"
44        );
45        Self {
46            name: name.to_string(),
47            variable: dyn_clone::clone_box(variable),
48            bins,
49            range,
50            values,
51            pids: Default::default(),
52            bin_index: Default::default(),
53        }
54        .into()
55    }
56}
57
58#[typetag::serde]
59impl Amplitude for PiecewiseScalar {
60    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
61        self.pids = self
62            .values
63            .iter()
64            .map(|value| resources.register_parameter(value))
65            .collect();
66        self.bin_index = resources.register_scalar(None);
67        resources.register_amplitude(&self.name)
68    }
69
70    fn precompute(&self, event: &Event, cache: &mut Cache) {
71        let maybe_bin_index = get_bin_index(self.variable.value(event), self.bins, self.range);
72        if let Some(bin_index) = maybe_bin_index {
73            cache.store_scalar(self.bin_index, bin_index as Float);
74        } else {
75            cache.store_scalar(self.bin_index, (self.bins + 1) as Float);
76            // store ibin = nbins + 1 if outside range
77        }
78    }
79
80    fn compute(&self, parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
81        let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
82        if bin_index == self.bins + 1 {
83            Complex::ZERO
84        } else {
85            Complex::from(parameters.get(self.pids[bin_index]))
86        }
87    }
88
89    fn compute_gradient(
90        &self,
91        _parameters: &Parameters,
92        _event: &Event,
93        cache: &Cache,
94        gradient: &mut DVector<Complex<Float>>,
95    ) {
96        let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
97        if bin_index < self.bins + 1 {
98            gradient[bin_index] = Complex::ONE;
99        }
100    }
101}
102
/// An Amplitude which represents a piecewise function of single scalar values
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
///     The variable to use for binning
/// bins : int
///     The number of bins to use
/// range : tuple of float
///     The minimum and maximum bin edges
/// values : list of ParameterLike
///     The scalar parameters contained in each bin of the Amplitude (one per bin)
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// Raises
/// ------
/// AssertionError
///     If the number of bins does not match the number of parameters
/// TypeError
///     If the given `variable` is not a valid variable
///
/// Notes
/// -----
/// Events which do not fall into any bin evaluate to zero.
///
/// See Also
/// --------
/// laddu.Manager
/// laddu.Mass
/// laddu.CosTheta
/// laddu.Phi
/// laddu.PolAngle
/// laddu.PolMagnitude
/// laddu.Mandelstam
///
#[cfg(feature = "python")]
#[pyfunction(name = "PiecewiseScalar")]
pub fn py_piecewise_scalar(
    name: &str,
    variable: Bound<'_, PyAny>,
    bins: usize,
    range: (Float, Float),
    values: Vec<PyParameterLike>,
) -> PyResult<PyAmplitude> {
    let variable = variable.extract::<PyVariable>()?;
    // Validate here instead of relying on the `assert_eq!` in `PiecewiseScalar::new`. A Rust
    // panic reaches Python as `PanicException` (a `BaseException`), not the documented
    // `AssertionError`.
    if values.len() != bins {
        return Err(PyAssertionError::new_err(
            "Number of bins must match number of parameters!",
        ));
    }
    Ok(PyAmplitude(PiecewiseScalar::new(
        name,
        &variable,
        bins,
        range,
        values.into_iter().map(|value| value.0).collect(),
    )))
}

/// A piecewise complex-valued [`Amplitude`] which contains two parameters for each bin,
/// representing the real and imaginary parts of its value in that bin.
///
/// Each event is assigned to one of `bins` bins over `range` according to the value of
/// `variable` (via `get_bin_index`). Events for which no bin is found (e.g. outside `range`)
/// evaluate to zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct PiecewiseComplexScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Variable whose per-event value selects the bin.
    variable: Box<dyn Variable>,
    /// Number of bins; equal to `re_ims.len()`.
    bins: usize,
    /// `(min, max)` bin edges passed to `get_bin_index`.
    range: (Float, Float),
    /// `(real, imaginary)` parameter pair for each bin.
    re_ims: Vec<(ParameterLike, ParameterLike)>,
    /// Parameter IDs matching `re_ims`, filled in by `register`.
    pids_re_im: Vec<(ParameterID, ParameterID)>,
    /// Cache slot holding each event's precomputed bin index.
    bin_index: ScalarID,
}
171impl PiecewiseComplexScalar {
172    /// Create a new [`PiecewiseComplexScalar`] with the given name and parameter value.
173    pub fn new<V: Variable + 'static>(
174        name: &str,
175        variable: &V,
176        bins: usize,
177        range: (Float, Float),
178        re_ims: Vec<(ParameterLike, ParameterLike)>,
179    ) -> Box<Self> {
180        assert_eq!(
181            bins,
182            re_ims.len(),
183            "Number of bins must match number of parameters!"
184        );
185        Self {
186            name: name.to_string(),
187            variable: dyn_clone::clone_box(variable),
188            bins,
189            range,
190            re_ims,
191            pids_re_im: Default::default(),
192            bin_index: Default::default(),
193        }
194        .into()
195    }
196}
197
198#[typetag::serde]
199impl Amplitude for PiecewiseComplexScalar {
200    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
201        self.pids_re_im = self
202            .re_ims
203            .iter()
204            .map(|(re, im)| {
205                (
206                    resources.register_parameter(re),
207                    resources.register_parameter(im),
208                )
209            })
210            .collect();
211        self.bin_index = resources.register_scalar(None);
212        resources.register_amplitude(&self.name)
213    }
214
215    fn precompute(&self, event: &Event, cache: &mut Cache) {
216        let maybe_bin_index = get_bin_index(self.variable.value(event), self.bins, self.range);
217        if let Some(bin_index) = maybe_bin_index {
218            cache.store_scalar(self.bin_index, bin_index as Float);
219        } else {
220            cache.store_scalar(self.bin_index, (self.bins + 1) as Float);
221            // store ibin = nbins + 1 if outside range
222        }
223    }
224
225    fn compute(&self, parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
226        let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
227        if bin_index == self.bins + 1 {
228            Complex::ZERO
229        } else {
230            let pid_re_im = self.pids_re_im[bin_index];
231            Complex::new(parameters.get(pid_re_im.0), parameters.get(pid_re_im.1))
232        }
233    }
234
235    fn compute_gradient(
236        &self,
237        _parameters: &Parameters,
238        _event: &Event,
239        cache: &Cache,
240        gradient: &mut DVector<Complex<Float>>,
241    ) {
242        let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
243        if bin_index < self.bins + 1 {
244            let pid_re_im = self.pids_re_im[bin_index];
245            if let ParameterID::Parameter(ind) = pid_re_im.0 {
246                gradient[ind] = Complex::ONE;
247            }
248            if let ParameterID::Parameter(ind) = pid_re_im.1 {
249                gradient[ind] = Complex::I;
250            }
251        }
252    }
253}
254
/// An Amplitude which represents a piecewise function of complex values
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
///     The variable to use for binning
/// bins : int
///     The number of bins to use
/// range : tuple of float
///     The minimum and maximum bin edges
/// values : list of tuple of ParameterLike
///     The complex parameters contained in each bin of the Amplitude (each tuple contains the
///     real and imaginary part of a single bin)
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// Raises
/// ------
/// AssertionError
///     If the number of bins does not match the number of parameters
/// TypeError
///     If the given `variable` is not a valid variable
///
/// Notes
/// -----
/// Events which do not fall into any bin evaluate to zero.
///
/// See Also
/// --------
/// laddu.Manager
/// laddu.Mass
/// laddu.CosTheta
/// laddu.Phi
/// laddu.PolAngle
/// laddu.PolMagnitude
/// laddu.Mandelstam
///
#[cfg(feature = "python")]
#[pyfunction(name = "PiecewiseComplexScalar")]
pub fn py_piecewise_complex_scalar(
    name: &str,
    variable: Bound<'_, PyAny>,
    bins: usize,
    range: (Float, Float),
    values: Vec<(PyParameterLike, PyParameterLike)>,
) -> PyResult<PyAmplitude> {
    let variable = variable.extract::<PyVariable>()?;
    // Validate here instead of relying on the `assert_eq!` in `PiecewiseComplexScalar::new`.
    // A Rust panic reaches Python as `PanicException` (a `BaseException`), not the documented
    // `AssertionError`.
    if values.len() != bins {
        return Err(PyAssertionError::new_err(
            "Number of bins must match number of parameters!",
        ));
    }
    Ok(PyAmplitude(PiecewiseComplexScalar::new(
        name,
        &variable,
        bins,
        range,
        values
            .into_iter()
            .map(|(value_re, value_im)| (value_re.0, value_im.0))
            .collect(),
    )))
}

/// A piecewise complex-valued [`Amplitude`] which contains two parameters for each bin,
/// representing the magnitude and phase of its value in that bin.
///
/// Each event is assigned to one of `bins` bins over `range` according to the value of
/// `variable` (via `get_bin_index`). An event in bin `i` evaluates to `r * exp(i * theta)` using
/// the pair `r_thetas[i]`. Events for which no bin is found (e.g. outside `range`) evaluate to zero.
#[derive(Clone, Serialize, Deserialize)]
pub struct PiecewisePolarComplexScalar {
    /// Name under which the amplitude is registered.
    name: String,
    /// Variable whose per-event value selects the bin.
    variable: Box<dyn Variable>,
    /// Number of bins; equal to `r_thetas.len()`.
    bins: usize,
    /// `(min, max)` bin edges passed to `get_bin_index`.
    range: (Float, Float),
    /// `(magnitude, phase)` parameter pair for each bin.
    r_thetas: Vec<(ParameterLike, ParameterLike)>,
    /// Parameter IDs matching `r_thetas`, filled in by `register`.
    pids_r_theta: Vec<(ParameterID, ParameterID)>,
    /// Cache slot holding each event's precomputed bin index.
    bin_index: ScalarID,
}
327impl PiecewisePolarComplexScalar {
328    /// Create a new [`PiecewiseComplexScalar`] with the given name and parameter value.
329    pub fn new<V: Variable + 'static>(
330        name: &str,
331        variable: &V,
332        bins: usize,
333        range: (Float, Float),
334        r_thetas: Vec<(ParameterLike, ParameterLike)>,
335    ) -> Box<Self> {
336        assert_eq!(
337            bins,
338            r_thetas.len(),
339            "Number of bins must match number of parameters!"
340        );
341        Self {
342            name: name.to_string(),
343            variable: dyn_clone::clone_box(variable),
344            bins,
345            range,
346            r_thetas,
347            pids_r_theta: Default::default(),
348            bin_index: Default::default(),
349        }
350        .into()
351    }
352}
353
354#[typetag::serde]
355impl Amplitude for PiecewisePolarComplexScalar {
356    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
357        self.pids_r_theta = self
358            .r_thetas
359            .iter()
360            .map(|(r, theta)| {
361                (
362                    resources.register_parameter(r),
363                    resources.register_parameter(theta),
364                )
365            })
366            .collect();
367        self.bin_index = resources.register_scalar(None);
368        resources.register_amplitude(&self.name)
369    }
370
371    fn precompute(&self, event: &Event, cache: &mut Cache) {
372        let maybe_bin_index = get_bin_index(self.variable.value(event), self.bins, self.range);
373        if let Some(bin_index) = maybe_bin_index {
374            cache.store_scalar(self.bin_index, bin_index as Float);
375        } else {
376            cache.store_scalar(self.bin_index, (self.bins + 1) as Float);
377            // store ibin = nbins + 1 if outside range
378        }
379    }
380
381    fn compute(&self, parameters: &Parameters, _event: &Event, cache: &Cache) -> Complex<Float> {
382        let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
383        if bin_index == self.bins + 1 {
384            Complex::ZERO
385        } else {
386            let pid_r_theta = self.pids_r_theta[bin_index];
387            Complex::from_polar(parameters.get(pid_r_theta.0), parameters.get(pid_r_theta.1))
388        }
389    }
390
391    fn compute_gradient(
392        &self,
393        parameters: &Parameters,
394        _event: &Event,
395        cache: &Cache,
396        gradient: &mut DVector<Complex<Float>>,
397    ) {
398        let bin_index: usize = cache.get_scalar(self.bin_index) as usize;
399        if bin_index < self.bins + 1 {
400            let pid_r_theta = self.pids_r_theta[bin_index];
401            let r = parameters.get(pid_r_theta.0);
402            let theta = parameters.get(pid_r_theta.1);
403            let exp_i_theta = Complex::cis(theta);
404            if let ParameterID::Parameter(ind) = pid_r_theta.0 {
405                gradient[ind] = exp_i_theta;
406            }
407            if let ParameterID::Parameter(ind) = pid_r_theta.1 {
408                gradient[ind] = Complex::<Float>::I * Complex::from_polar(r, theta);
409            }
410        }
411    }
412}
413
/// An Amplitude which represents a piecewise function of polar complex values
///
/// Parameters
/// ----------
/// name : str
///     The Amplitude name
/// variable : {laddu.Mass, laddu.CosTheta, laddu.Phi, laddu.PolAngle, laddu.PolMagnitude, laddu.Mandelstam}
///     The variable to use for binning
/// bins : int
///     The number of bins to use
/// range : tuple of float
///     The minimum and maximum bin edges
/// values : list of tuple of ParameterLike
///     The polar complex parameters contained in each bin of the Amplitude (each tuple contains the
///     magnitude and argument of a single bin)
///
/// Returns
/// -------
/// laddu.Amplitude
///     An Amplitude which can be registered by a laddu.Manager
///
/// Raises
/// ------
/// AssertionError
///     If the number of bins does not match the number of parameters
/// TypeError
///     If the given `variable` is not a valid variable
///
/// Notes
/// -----
/// Events which do not fall into any bin evaluate to zero.
///
/// See Also
/// --------
/// laddu.Manager
/// laddu.Mass
/// laddu.CosTheta
/// laddu.Phi
/// laddu.PolAngle
/// laddu.PolMagnitude
/// laddu.Mandelstam
///
#[cfg(feature = "python")]
#[pyfunction(name = "PiecewisePolarComplexScalar")]
pub fn py_piecewise_polar_complex_scalar(
    name: &str,
    variable: Bound<'_, PyAny>,
    bins: usize,
    range: (Float, Float),
    values: Vec<(PyParameterLike, PyParameterLike)>,
) -> PyResult<PyAmplitude> {
    let variable = variable.extract::<PyVariable>()?;
    // Validate here instead of relying on the `assert_eq!` in `PiecewisePolarComplexScalar::new`.
    // A Rust panic reaches Python as `PanicException` (a `BaseException`), not the documented
    // `AssertionError`.
    if values.len() != bins {
        return Err(PyAssertionError::new_err(
            "Number of bins must match number of parameters!",
        ));
    }
    Ok(PyAmplitude(PiecewisePolarComplexScalar::new(
        name,
        &variable,
        bins,
        range,
        values
            .into_iter()
            .map(|(value_r, value_theta)| (value_r.0, value_theta.0))
            .collect(),
    )))
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use laddu_core::{data::test_dataset, parameter, Manager, Mass, PI};
    use std::sync::Arc;

    // Every amplitude below bins the mass of particle 2 into three bins on [0, 1). The expected
    // values assume the first event of `test_dataset` lands in the middle bin (index 1).
    const BINS: usize = 3;
    const RANGE: (Float, Float) = (0.0, 1.0);

    /// Piecewise real scalar with one free parameter per bin.
    fn scalar_amplitude() -> Box<PiecewiseScalar> {
        PiecewiseScalar::new(
            "test_scalar",
            &Mass::new([2]),
            BINS,
            RANGE,
            vec![
                parameter("test_param0"),
                parameter("test_param1"),
                parameter("test_param2"),
            ],
        )
    }

    /// Piecewise complex scalar with a free `(re, im)` pair per bin.
    fn complex_amplitude() -> Box<PiecewiseComplexScalar> {
        PiecewiseComplexScalar::new(
            "test_complex",
            &Mass::new([2]),
            BINS,
            RANGE,
            vec![
                (parameter("re_param0"), parameter("im_param0")),
                (parameter("re_param1"), parameter("im_param1")),
                (parameter("re_param2"), parameter("im_param2")),
            ],
        )
    }

    /// Piecewise polar complex scalar with a free `(r, theta)` pair per bin.
    fn polar_amplitude() -> Box<PiecewisePolarComplexScalar> {
        PiecewisePolarComplexScalar::new(
            "test_polar",
            &Mass::new([2]),
            BINS,
            RANGE,
            vec![
                (parameter("r_param0"), parameter("theta_param0")),
                (parameter("r_param1"), parameter("theta_param1")),
                (parameter("r_param2"), parameter("theta_param2")),
            ],
        )
    }

    /// Parameter vector `[1.1r, 1.2θ, 2.1r, 2.2θ, 3.1r, 3.2θ]` shared by the polar tests.
    fn polar_params(r: Float, theta: Float) -> Vec<Float> {
        vec![
            1.1 * r,
            1.2 * theta,
            2.1 * r,
            2.2 * theta,
            3.1 * r,
            3.2 * theta,
        ]
    }

    /// Compare both components of `z` against the expected real and imaginary parts.
    fn assert_complex_eq(z: Complex<Float>, re: Float, im: Float) {
        assert_relative_eq!(z.re, re);
        assert_relative_eq!(z.im, im);
    }

    #[test]
    fn test_piecewise_scalar_creation_and_evaluation() {
        let mut manager = Manager::default();
        let aid = manager.register(scalar_amplitude()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into(); // Direct amplitude evaluation
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let result = evaluator.evaluate(&vec![1.1, 2.2, 3.3]);
        assert_complex_eq(result[0], 2.2, 0.0);
    }

    #[test]
    fn test_piecewise_scalar_gradient() {
        let mut manager = Manager::default();
        let aid = manager.register(scalar_amplitude()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr(); // |f(x)|^2
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradient = evaluator.evaluate_gradient(&vec![1.0, 2.0, 3.0]);

        // d|x|^2/dx = 2x for the active bin's parameter, zero for the others.
        assert_complex_eq(gradient[0][1], 4.0, 0.0);
        for inactive in [0, 2] {
            assert_complex_eq(gradient[0][inactive], 0.0, 0.0);
        }
    }

    #[test]
    fn test_piecewise_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let aid = manager.register(complex_amplitude()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        // Interleaved real and imaginary parts, one pair per bin.
        let result = evaluator.evaluate(&vec![1.1, 1.2, 2.1, 2.2, 3.1, 3.2]);
        assert_complex_eq(result[0], 2.1, 2.2);
    }

    #[test]
    fn test_piecewise_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let aid = manager.register(complex_amplitude()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr(); // |f(x + iy)|^2
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradient = evaluator.evaluate_gradient(&vec![1.1, 1.2, 2.1, 2.2, 3.1, 3.2]);

        // Partial derivatives of x^2 + y^2 are 2x and 2y for the active bin only.
        assert_complex_eq(gradient[0][2], 4.2, 0.0);
        assert_complex_eq(gradient[0][3], 4.4, 0.0);
        for inactive in [0, 1, 4, 5] {
            assert_complex_eq(gradient[0][inactive], 0.0, 0.0);
        }
    }

    #[test]
    fn test_piecewise_polar_complex_scalar_evaluation() {
        let mut manager = Manager::default();
        let aid = manager.register(polar_amplitude()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r: Float = 2.0;
        let theta = PI / 4.3;
        let result = evaluator.evaluate(&polar_params(r, theta));

        // r * (cos(theta) + i*sin(theta)) for the active bin's pair.
        let (mag, phase) = (2.1 * r, 2.2 * theta);
        assert_complex_eq(result[0], mag * phase.cos(), mag * phase.sin());
    }

    #[test]
    fn test_piecewise_polar_complex_scalar_gradient() {
        let mut manager = Manager::default();
        let aid = manager.register(polar_amplitude()).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = aid.into(); // f(r,θ) = re^(iθ)
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let r: Float = 2.0;
        let theta = PI / 4.3;
        let gradient = evaluator.evaluate_gradient(&polar_params(r, theta));

        // d/dr re^(iθ) = e^(iθ), d/dθ re^(iθ) = ire^(iθ)
        let (mag, phase) = (2.1 * r, 2.2 * theta);
        assert_complex_eq(gradient[0][2], Float::cos(phase), Float::sin(phase));
        assert_complex_eq(gradient[0][3], -mag * Float::sin(phase), mag * Float::cos(phase));
        for inactive in [0, 1, 4, 5] {
            assert_complex_eq(gradient[0][inactive], 0.0, 0.0);
        }
    }
}