laddu_core/
amplitudes.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17    data::{Dataset, Event},
18    resources::{Cache, Parameters, Resources},
19    Float, LadduError,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
/// A parameter-like value given to an [`Amplitude`] constructor: either a named free
/// parameter, a fixed constant, or an uninitialized placeholder.
#[derive(Clone, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A named free parameter.
    Parameter(String),
    /// A constant value.
    Constant(Float),
    /// An uninitialized parameter-like structure (typically used as the value given in an
    /// [`Amplitude`] constructor before the [`Amplitude::register`] method is called).
    ///
    /// This is the [`Default`] variant.
    #[default]
    Uninit,
}
40
41/// Shorthand for generating a named free parameter.
42pub fn parameter(name: &str) -> ParameterLike {
43    ParameterLike::Parameter(name.to_string())
44}
45
46/// Shorthand for generating a constant value (which acts like a fixed parameter).
47pub fn constant(value: Float) -> ParameterLike {
48    ParameterLike::Constant(value)
49}
50
51/// This is the only required trait for writing new amplitude-like structures for this
52/// crate. Users need only implement the [`register`](Amplitude::register)
53/// method to register parameters, cached values, and the amplitude itself with an input
54/// [`Resources`] struct and the [`compute`](Amplitude::compute) method to actually carry
55/// out the calculation. [`Amplitude`]-implementors are required to implement [`Clone`] and can
56/// optionally implement a [`precompute`](Amplitude::precompute) method to calculate and
57/// cache values which do not depend on free parameters.
58#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60    /// This method should be used to tell the [`Resources`] manager about all of
61    /// the free parameters and cached values used by this [`Amplitude`]. It should end by
62    /// returning an [`AmplitudeID`], which can be obtained from the
63    /// [`Resources::register_amplitude`] method.
64    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65    /// This method can be used to do some critical calculations ahead of time and
66    /// store them in a [`Cache`]. These values can only depend on the data in an [`Event`],
67    /// not on any free parameters in the fit. This method is opt-in since it is not required
68    /// to make a functioning [`Amplitude`].
69    #[allow(unused_variables)]
70    fn precompute(&self, event: &Event, cache: &mut Cache) {}
71    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
72    #[cfg(feature = "rayon")]
73    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74        dataset
75            .events
76            .par_iter()
77            .zip(resources.caches.par_iter_mut())
78            .for_each(|(event, cache)| {
79                self.precompute(event, cache);
80            })
81    }
82    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
83    #[cfg(not(feature = "rayon"))]
84    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85        dataset
86            .events
87            .iter()
88            .zip(resources.caches.iter_mut())
89            .for_each(|(event, cache)| self.precompute(event, cache))
90    }
91    /// This method constitutes the main machinery of an [`Amplitude`], returning the actual
92    /// calculated value for a particular [`Event`] and set of [`Parameters`]. See those
93    /// structs, as well as [`Cache`], for documentation on their available methods. For the
94    /// most part, [`Event`]s can be interacted with via
95    /// [`Variable`](crate::utils::variables::Variable)s, while [`Parameters`] and the
96    /// [`Cache`] are more like key-value storage accessed by
97    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
98    /// IDs.
99    fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101    /// This method yields the gradient of a particular [`Amplitude`] at a point specified
102    /// by a particular [`Event`] and set of [`Parameters`]. See those structs, as well as
103    /// [`Cache`], for documentation on their available methods. For the most part,
104    /// [`Event`]s can be interacted with via [`Variable`](crate::utils::variables::Variable)s,
105    /// while [`Parameters`] and the [`Cache`] are more like key-value storage accessed by
106    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
107    /// IDs. If the analytic version of the gradient is known, this method can be overwritten to
108    /// improve performance for some derivative-using methods of minimization. The default
109    /// implementation calculates a central finite difference across all parameters, regardless of
110    /// whether or not they are used in the [`Amplitude`].
111    ///
112    /// In the future, it may be possible to automatically implement this with the indices of
113    /// registered free parameters, but until then, the [`Amplitude::central_difference_with_indices`]
114    /// method can be used to conveniently only calculate central differences for the parameters
115    /// which are used by the [`Amplitude`].
116    fn compute_gradient(
117        &self,
118        parameters: &Parameters,
119        event: &Event,
120        cache: &Cache,
121        gradient: &mut DVector<Complex<Float>>,
122    ) {
123        self.central_difference_with_indices(
124            &Vec::from_iter(0..parameters.len()),
125            parameters,
126            event,
127            cache,
128            gradient,
129        )
130    }
131
132    /// A helper function to implement a central difference only on indices which correspond to
133    /// free parameters in the [`Amplitude`]. For example, if an [`Amplitude`] contains free
134    /// parameters registered to indices 1, 3, and 5 of the its internal parameters array, then
135    /// running this with those indices will compute a central finite difference derivative for
136    /// those coordinates only, since the rest can be safely assumed to be zero.
137    fn central_difference_with_indices(
138        &self,
139        indices: &[usize],
140        parameters: &Parameters,
141        event: &Event,
142        cache: &Cache,
143        gradient: &mut DVector<Complex<Float>>,
144    ) {
145        let x = parameters.parameters.to_owned();
146        let constants = parameters.constants.to_owned();
147        let h: DVector<Float> = x
148            .iter()
149            .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150            .collect::<Vec<_>>()
151            .into();
152        for i in indices {
153            let mut x_plus = x.clone();
154            let mut x_minus = x.clone();
155            x_plus[*i] += h[*i];
156            x_minus[*i] -= h[*i];
157            let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158            let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159            gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160        }
161    }
162}
163
164/// Utility function to calculate a central finite difference gradient.
165pub fn central_difference<F: Fn(&[Float]) -> Float>(
166    parameters: &[Float],
167    func: F,
168) -> DVector<Float> {
169    let mut gradient = DVector::zeros(parameters.len());
170    let x = parameters.to_owned();
171    let h: DVector<Float> = x
172        .iter()
173        .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174        .collect::<Vec<_>>()
175        .into();
176    for i in 0..parameters.len() {
177        let mut x_plus = x.clone();
178        let mut x_minus = x.clone();
179        x_plus[i] += h[i];
180        x_minus[i] -= h[i];
181        let f_plus = func(&x_plus);
182        let f_minus = func(&x_minus);
183        gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184    }
185    gradient
186}
187
// Implements `Clone` for `Box<dyn Amplitude>` (through the `DynClone` supertrait). This is what
// allows `Manager`, `Model`, and `Evaluator` to derive `Clone` while holding boxed amplitudes.
dyn_clone::clone_trait_object!(Amplitude);
189
/// A helper struct that contains the value of each amplitude for a particular event.
///
/// The vector is indexed by the `usize` stored in an [`AmplitudeID`] (see
/// [`Expression::evaluate`]).
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);

/// A helper struct that contains the gradient of each amplitude for a particular event.
///
/// The first field is the length of every gradient vector (the evaluator passes
/// `parameters.len()`). It is used to build zero gradients for constant sub-expressions. The
/// second field holds one gradient per amplitude, indexed like [`AmplitudeValues`].
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex<Float>>>);
197
/// A tag which refers to a registered [`Amplitude`]. This is the base object which can be used to
/// build [`Expression`]s and should be obtained from the [`Manager::register`] method.
///
/// The first field is the label printed by the [`Display`] implementation (presumably the
/// amplitude's registered name). The second field is the amplitude's index into
/// [`AmplitudeValues`] and [`GradientValues`].
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        write!(f, "{}(id={})", self.0, self.1)
206    }
207}
208
209impl From<AmplitudeID> for Expression {
210    fn from(value: AmplitudeID) -> Self {
211        Self::Amp(value)
212    }
213}
214
/// An expression tree which contains [`AmplitudeID`]s and operators over them.
///
/// Trees are built with the `+` / `*` operators and the `real` / `imag` / `norm_sqr` helpers,
/// then evaluated per event by [`Expression::evaluate`].
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    /// An expression equal to zero (the [`Default`] expression).
    #[default]
    Zero,
    /// An expression equal to one.
    One,
    /// A registered [`Amplitude`] referenced by an [`AmplitudeID`].
    Amp(AmplitudeID),
    /// The sum of two [`Expression`]s.
    Add(Box<Expression>, Box<Expression>),
    /// The product of two [`Expression`]s.
    Mul(Box<Expression>, Box<Expression>),
    /// The real part of an [`Expression`].
    Real(Box<Expression>),
    /// The imaginary part of an [`Expression`].
    Imag(Box<Expression>),
    /// The absolute square of an [`Expression`].
    NormSqr(Box<Expression>),
}
236
237impl Debug for Expression {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        self.write_tree(f, "", "", "")
240    }
241}
242
243impl Display for Expression {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        write!(f, "{:?}", self)
246    }
247}
248
// Operator overloads (via `auto_ops`) for building expression trees. Per the `auto_ops` docs, the
// `_ex` forms also cover owned/borrowed operand combinations, and the `_commutative` forms
// additionally generate the swapped operand order. Every operator clones its operands into a
// new `Add` / `Mul` node.

// `Expression + Expression`
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression { Expression::Add(Box::new(a.clone()), Box::new(b.clone()))});
// `Expression * Expression`
impl_op_ex!(*|a: &Expression, b: &Expression| -> Expression {
    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
});
// `AmplitudeID + Expression` (and `Expression + AmplitudeID`)
impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))});
// `AmplitudeID * Expression` (and `Expression * AmplitudeID`)
impl_op_ex_commutative!(*|a: &AmplitudeID, b: &Expression| -> Expression {
    Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
});
// `AmplitudeID + AmplitudeID`
impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(Expression::Amp(b.clone())))});
// `AmplitudeID * AmplitudeID`
impl_op_ex!(*|a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Mul(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone())),
    )
});
264
265impl AmplitudeID {
266    /// Takes the real part of the given [`Amplitude`].
267    pub fn real(&self) -> Expression {
268        Expression::Real(Box::new(Expression::Amp(self.clone())))
269    }
270    /// Takes the imaginary part of the given [`Amplitude`].
271    pub fn imag(&self) -> Expression {
272        Expression::Imag(Box::new(Expression::Amp(self.clone())))
273    }
274    /// Takes the absolute square of the given [`Amplitude`].
275    pub fn norm_sqr(&self) -> Expression {
276        Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
277    }
278}
279
280impl Expression {
281    /// Evaluate an [`Expression`] over a single event using calculated [`AmplitudeValues`]
282    ///
283    /// This method parses the underlying [`Expression`] but doesn't actually calculate the values
284    /// from the [`Amplitude`]s themselves.
285    pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
286        match self {
287            Expression::Amp(aid) => amplitude_values.0[aid.1],
288            Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
289            Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
290            Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
291            Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
292            Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
293            Expression::Zero => Complex::ZERO,
294            Expression::One => Complex::ONE,
295        }
296    }
297    /// Evaluate the gradient of an [`Expression`] over a single event using calculated [`AmplitudeValues`]
298    ///
299    /// This method parses the underlying [`Expression`] but doesn't actually calculate the
300    /// gradient from the [`Amplitude`]s themselves.
301    pub fn evaluate_gradient(
302        &self,
303        amplitude_values: &AmplitudeValues,
304        gradient_values: &GradientValues,
305    ) -> DVector<Complex<Float>> {
306        match self {
307            Expression::Amp(aid) => gradient_values.1[aid.1].clone(),
308            Expression::Add(a, b) => {
309                a.evaluate_gradient(amplitude_values, gradient_values)
310                    + b.evaluate_gradient(amplitude_values, gradient_values)
311            }
312            Expression::Mul(a, b) => {
313                let f_a = a.evaluate(amplitude_values);
314                let f_b = b.evaluate(amplitude_values);
315                b.evaluate_gradient(amplitude_values, gradient_values)
316                    .map(|g| g * f_a)
317                    + a.evaluate_gradient(amplitude_values, gradient_values)
318                        .map(|g| g * f_b)
319            }
320            Expression::Real(a) => a
321                .evaluate_gradient(amplitude_values, gradient_values)
322                .map(|g| Complex::new(g.re, 0.0)),
323            Expression::Imag(a) => a
324                .evaluate_gradient(amplitude_values, gradient_values)
325                .map(|g| Complex::new(g.im, 0.0)),
326            Expression::NormSqr(a) => {
327                let conj_f_a = a.evaluate(amplitude_values).conjugate();
328                a.evaluate_gradient(amplitude_values, gradient_values)
329                    .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
330            }
331            Expression::Zero | Expression::One => DVector::zeros(gradient_values.0),
332        }
333    }
334    /// Takes the real part of the given [`Expression`].
335    pub fn real(&self) -> Self {
336        Self::Real(Box::new(self.clone()))
337    }
338    /// Takes the imaginary part of the given [`Expression`].
339    pub fn imag(&self) -> Self {
340        Self::Imag(Box::new(self.clone()))
341    }
342    /// Takes the absolute square of the given [`Expression`].
343    pub fn norm_sqr(&self) -> Self {
344        Self::NormSqr(Box::new(self.clone()))
345    }
346
347    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
348    fn write_tree(
349        &self,
350        f: &mut std::fmt::Formatter<'_>,
351        parent_prefix: &str,
352        immediate_prefix: &str,
353        parent_suffix: &str,
354    ) -> std::fmt::Result {
355        let display_string = match self {
356            Self::Amp(aid) => aid.to_string(),
357            Self::Add(_, _) => "+".to_string(),
358            Self::Mul(_, _) => "*".to_string(),
359            Self::Real(_) => "Re".to_string(),
360            Self::Imag(_) => "Im".to_string(),
361            Self::NormSqr(_) => "NormSqr".to_string(),
362            Self::Zero => "0".to_string(),
363            Self::One => "1".to_string(),
364        };
365        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
366        match self {
367            Self::Amp(_) | Self::Zero | Self::One => {}
368            Self::Add(a, b) | Self::Mul(a, b) => {
369                let terms = [a, b];
370                let mut it = terms.iter().peekable();
371                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
372                while let Some(child) = it.next() {
373                    match it.peek() {
374                        Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│  "),
375                        None => child.write_tree(f, &child_prefix, "└─ ", "   "),
376                    }?;
377                }
378            }
379            Self::Real(a) | Self::Imag(a) | Self::NormSqr(a) => {
380                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
381                a.write_tree(f, &child_prefix, "└─ ", "   ")?;
382            }
383        }
384        Ok(())
385    }
386}
387
/// A manager which can be used to register [`Amplitude`]s with [`Resources`]. This structure is
/// essential to any analysis and should be constructed using the [`Manager::default()`] method.
///
/// [`Manager::model`] clones the manager together with an [`Expression`] into a [`Model`].
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    // Registered amplitudes, in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Shared registry the amplitudes registered into (parameter names, caches, etc.).
    resources: Resources,
}
395
396impl Manager {
397    /// Get the list of parameter names in the order they appear in the [`Manager`]'s [`Resources`] field.
398    pub fn parameters(&self) -> Vec<String> {
399        self.resources.parameters.iter().cloned().collect()
400    }
401    /// Register the given [`Amplitude`] and return an [`AmplitudeID`] that can be used to build
402    /// [`Expression`]s.
403    ///
404    /// # Errors
405    ///
406    /// The [`Amplitude`]'s name must be unique and not already
407    /// registered, else this will return a [`RegistrationError`][LadduError::RegistrationError].
408    pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
409        let mut amp = amplitude.clone();
410        let aid = amp.register(&mut self.resources)?;
411        self.amplitudes.push(amp);
412        Ok(aid)
413    }
414    /// Turns an [`Expression`] made from registered [`Amplitude`]s into a [`Model`].
415    pub fn model(&self, expression: &Expression) -> Model {
416        Model {
417            manager: self.clone(),
418            expression: expression.clone(),
419        }
420    }
421}
422
/// A struct which contains a set of registered [`Amplitude`]s (inside a [`Manager`])
/// and an [`Expression`].
///
/// This struct implements [`serde::Serialize`] and [`serde::Deserialize`] and is intended
/// to be used to store models to disk. Use [`Model::load`] to bind it to a [`Dataset`].
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    pub(crate) manager: Manager,
    pub(crate) expression: Expression,
}
433
434impl Model {
435    /// Get the list of parameter names in the order they appear in the [`Model`]'s [`Manager`] field.
436    pub fn parameters(&self) -> Vec<String> {
437        self.manager.parameters()
438    }
439    /// Create an [`Evaluator`] which can compute the result of the internal [`Expression`] built on
440    /// registered [`Amplitude`]s over the given [`Dataset`]. This method precomputes any relevant
441    /// information over the [`Event`]s in the [`Dataset`].
442    pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
443        let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
444        loaded_resources.write().reserve_cache(dataset.n_events());
445        for amplitude in &self.manager.amplitudes {
446            amplitude.precompute_all(dataset, &mut loaded_resources.write());
447        }
448        Evaluator {
449            amplitudes: self.manager.amplitudes.clone(),
450            resources: loaded_resources.clone(),
451            dataset: dataset.clone(),
452            expression: self.expression.clone(),
453        }
454    }
455}
456
/// A structure which can be used to evaluate the stored [`Expression`] built on registered
/// [`Amplitude`]s. This contains a [`Resources`] struct which already contains cached values for
/// precomputed [`Amplitude`]s and any relevant free parameters and constants.
///
/// Cloning an [`Evaluator`] clones the [`Arc`] handles. Clones therefore share the same
/// [`Resources`] (including activation state) and the same [`Dataset`].
#[derive(Clone)]
pub struct Evaluator {
    /// A list of [`Amplitude`]s which were registered with the [`Manager`] used to create the
    /// internal [`Expression`]. This includes but is not limited to those which are actually used
    /// in the [`Expression`].
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// The internal [`Resources`] where precalculated values are stored
    pub resources: Arc<RwLock<Resources>>,
    /// The internal [`Dataset`]
    pub dataset: Arc<Dataset>,
    /// The internal [`Expression`]
    pub expression: Expression,
}
473
474impl Evaluator {
475    /// Get the list of parameter names in the order they appear in the [`Evaluator::evaluate`]
476    /// method.
477    pub fn parameters(&self) -> Vec<String> {
478        self.resources.read().parameters.iter().cloned().collect()
479    }
480    /// Activate an [`Amplitude`] by name.
481    pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
482        self.resources.write().activate(name)
483    }
484    /// Activate several [`Amplitude`]s by name.
485    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
486        self.resources.write().activate_many(names)
487    }
488    /// Activate all registered [`Amplitude`]s.
489    pub fn activate_all(&self) {
490        self.resources.write().activate_all();
491    }
492    /// Dectivate an [`Amplitude`] by name.
493    pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
494        self.resources.write().deactivate(name)
495    }
496    /// Deactivate several [`Amplitude`]s by name.
497    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
498        self.resources.write().deactivate_many(names)
499    }
500    /// Deactivate all registered [`Amplitude`]s.
501    pub fn deactivate_all(&self) {
502        self.resources.write().deactivate_all();
503    }
504    /// Isolate an [`Amplitude`] by name (deactivate the rest).
505    pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
506        self.resources.write().isolate(name)
507    }
508    /// Isolate several [`Amplitude`]s by name (deactivate the rest).
509    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
510        self.resources.write().isolate_many(names)
511    }
512
513    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
514    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
515    ///
516    /// # Notes
517    ///
518    /// This method is not intended to be called in analyses but rather in writing methods
519    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
520    pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
521        let resources = self.resources.read();
522        let parameters = Parameters::new(parameters, &resources.constants);
523        #[cfg(feature = "rayon")]
524        {
525            let amplitude_values_vec: Vec<AmplitudeValues> = self
526                .dataset
527                .events
528                .par_iter()
529                .zip(resources.caches.par_iter())
530                .map(|(event, cache)| {
531                    AmplitudeValues(
532                        self.amplitudes
533                            .iter()
534                            .zip(resources.active.iter())
535                            .map(|(amp, active)| {
536                                if *active {
537                                    amp.compute(&parameters, event, cache)
538                                } else {
539                                    Complex::new(0.0, 0.0)
540                                }
541                            })
542                            .collect(),
543                    )
544                })
545                .collect();
546            amplitude_values_vec
547                .par_iter()
548                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
549                .collect()
550        }
551        #[cfg(not(feature = "rayon"))]
552        {
553            let amplitude_values_vec: Vec<AmplitudeValues> = self
554                .dataset
555                .events
556                .iter()
557                .zip(resources.caches.iter())
558                .map(|(event, cache)| {
559                    AmplitudeValues(
560                        self.amplitudes
561                            .iter()
562                            .zip(resources.active.iter())
563                            .map(|(amp, active)| {
564                                if *active {
565                                    amp.compute(&parameters, event, cache)
566                                } else {
567                                    Complex::new(0.0, 0.0)
568                                }
569                            })
570                            .collect(),
571                    )
572                })
573                .collect();
574            amplitude_values_vec
575                .iter()
576                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
577                .collect()
578        }
579    }
580
581    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
582    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
583    ///
584    /// # Notes
585    ///
586    /// This method is not intended to be called in analyses but rather in writing methods
587    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
588    #[cfg(feature = "mpi")]
589    fn evaluate_mpi(
590        &self,
591        parameters: &[Float],
592        world: &SimpleCommunicator,
593    ) -> Vec<Complex<Float>> {
594        let local_evaluation = self.evaluate_local(parameters);
595        let n_events = self.dataset.n_events();
596        let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
597        let (counts, displs) = world.get_counts_displs(n_events);
598        {
599            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
600            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
601        }
602        buffer
603    }
604
605    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
606    /// [`Evaluator`] with the given values for free parameters.
607    pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
608        #[cfg(feature = "mpi")]
609        {
610            if let Some(world) = crate::mpi::get_world() {
611                return self.evaluate_mpi(parameters, &world);
612            }
613        }
614        self.evaluate_local(parameters)
615    }
616
617    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
618    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
619    ///
620    /// # Notes
621    ///
622    /// This method is not intended to be called in analyses but rather in writing methods
623    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
624    pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
625        let resources = self.resources.read();
626        let parameters = Parameters::new(parameters, &resources.constants);
627        #[cfg(feature = "rayon")]
628        {
629            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
630                .dataset
631                .events
632                .par_iter()
633                .zip(resources.caches.par_iter())
634                .map(|(event, cache)| {
635                    let mut gradient_values =
636                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
637                    self.amplitudes
638                        .iter()
639                        .zip(resources.active.iter())
640                        .zip(gradient_values.iter_mut())
641                        .for_each(|((amp, active), grad)| {
642                            if *active {
643                                amp.compute_gradient(&parameters, event, cache, grad)
644                            }
645                        });
646                    (
647                        AmplitudeValues(
648                            self.amplitudes
649                                .iter()
650                                .zip(resources.active.iter())
651                                .map(|(amp, active)| {
652                                    if *active {
653                                        amp.compute(&parameters, event, cache)
654                                    } else {
655                                        Complex::new(0.0, 0.0)
656                                    }
657                                })
658                                .collect(),
659                        ),
660                        GradientValues(parameters.len(), gradient_values),
661                    )
662                })
663                .collect();
664            amplitude_values_and_gradient_vec
665                .par_iter()
666                .map(|(amplitude_values, gradient_values)| {
667                    self.expression
668                        .evaluate_gradient(amplitude_values, gradient_values)
669                })
670                .collect()
671        }
672        #[cfg(not(feature = "rayon"))]
673        {
674            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
675                .dataset
676                .events
677                .iter()
678                .zip(resources.caches.iter())
679                .map(|(event, cache)| {
680                    let mut gradient_values =
681                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
682                    self.amplitudes
683                        .iter()
684                        .zip(resources.active.iter())
685                        .zip(gradient_values.iter_mut())
686                        .for_each(|((amp, active), grad)| {
687                            if *active {
688                                amp.compute_gradient(&parameters, event, cache, grad)
689                            }
690                        });
691                    (
692                        AmplitudeValues(
693                            self.amplitudes
694                                .iter()
695                                .zip(resources.active.iter())
696                                .map(|(amp, active)| {
697                                    if *active {
698                                        amp.compute(&parameters, event, cache)
699                                    } else {
700                                        Complex::new(0.0, 0.0)
701                                    }
702                                })
703                                .collect(),
704                        ),
705                        GradientValues(parameters.len(), gradient_values),
706                    )
707                })
708                .collect();
709
710            amplitude_values_and_gradient_vec
711                .iter()
712                .map(|(amplitude_values, gradient_values)| {
713                    self.expression
714                        .evaluate_gradient(amplitude_values, gradient_values)
715                })
716                .collect()
717        }
718    }
719
720    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
721    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
722    ///
723    /// # Notes
724    ///
725    /// This method is not intended to be called in analyses but rather in writing methods
726    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
727    #[cfg(feature = "mpi")]
728    fn evaluate_gradient_mpi(
729        &self,
730        parameters: &[Float],
731        world: &SimpleCommunicator,
732    ) -> Vec<DVector<Complex<Float>>> {
733        let flattened_local_evaluation = self
734            .evaluate_gradient_local(parameters)
735            .iter()
736            .flat_map(|g| g.data.as_vec().to_vec())
737            .collect::<Vec<Complex<Float>>>();
738        let n_events = self.dataset.n_events();
739        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
740        let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
741        let mut partitioned_flattened_result_buffer =
742            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
743        world.all_gather_varcount_into(
744            &flattened_local_evaluation,
745            &mut partitioned_flattened_result_buffer,
746        );
747        flattened_result_buffer
748            .chunks(parameters.len())
749            .map(DVector::from_row_slice)
750            .collect()
751    }
752
753    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
754    /// [`Evaluator`] with the given values for free parameters.
755    pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
756        #[cfg(feature = "mpi")]
757        {
758            if let Some(world) = crate::mpi::get_world() {
759                return self.evaluate_gradient_mpi(parameters, &world);
760            }
761        }
762        self.evaluate_gradient_local(parameters)
763    }
764}
765
/// Unit tests for amplitude registration, expression evaluation, amplitude activation,
/// analytic gradients, and expression-tree printing.
#[cfg(test)]
mod tests {
    use crate::data::{test_dataset, test_event};

    use super::*;
    use crate::{
        data::Event,
        resources::{Cache, ParameterID, Parameters, Resources},
        Complex, DVector, Float, LadduError,
    };
    use approx::assert_relative_eq;
    use serde::{Deserialize, Serialize};

    /// Test amplitude whose value is the event-independent complex number `re + i * im`.
    /// Each component may be a free parameter or a constant.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        pid_re: ParameterID,
        im: ParameterLike,
        pid_im: ParameterID,
    }

    impl ComplexScalar {
        /// Creates an unregistered amplitude. The parameter IDs keep their default values until
        /// [`Amplitude::register`] is called.
        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
            Box::new(Self {
                name: name.to_string(),
                re,
                pid_re: ParameterID::default(),
                im,
                pid_im: ParameterID::default(),
            })
        }
    }

    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
            // The real component is registered before the imaginary one.
            self.pid_re = resources.register_parameter(&self.re);
            self.pid_im = resources.register_parameter(&self.im);
            resources.register_amplitude(&self.name)
        }

        fn compute(
            &self,
            parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
        ) -> Complex<Float> {
            let re = parameters.get(self.pid_re);
            let im = parameters.get(self.pid_im);
            Complex::new(re, im)
        }

        fn compute_gradient(
            &self,
            _parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
            gradient: &mut DVector<Complex<Float>>,
        ) {
            // The partial derivatives are d/d(re) = 1 and d/d(im) = i. Entries for constant
            // components are left untouched.
            let partials: [(ParameterID, Complex<Float>); 2] =
                [(self.pid_re, Complex::ONE), (self.pid_im, Complex::I)];
            for (pid, partial) in partials {
                if let ParameterID::Parameter(index) = pid {
                    gradient[index] = partial;
                }
            }
        }
    }

    /// Evaluates `expr` on `dataset` at `params` and returns the value for the first event.
    fn first_value(
        manager: &Manager,
        expr: &Expression,
        dataset: &Arc<Dataset>,
        params: &[Float],
    ) -> Complex<Float> {
        manager.model(expr).load(dataset).evaluate(params)[0]
    }

    /// Evaluates the gradient of `expr` on `dataset` at `params` and returns the first event's
    /// gradient vector.
    fn first_gradient(
        manager: &Manager,
        expr: &Expression,
        dataset: &Arc<Dataset>,
        params: &[Float],
    ) -> DVector<Complex<Float>> {
        manager
            .model(expr)
            .load(dataset)
            .evaluate_gradient(params)
            .remove(0)
    }

    /// Compares each gradient component with the expected `(re, im)` pair at the same index.
    fn assert_gradient_eq(gradient: &DVector<Complex<Float>>, expected: &[(Float, Float)]) {
        for (index, &(re, im)) in expected.iter().enumerate() {
            assert_relative_eq!(gradient[index].re, re);
            assert_relative_eq!(gradient[index].im, im);
        }
    }

    #[test]
    fn test_constant_amplitude() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new("constant", constant(2.0), constant(3.0)))
            .unwrap();
        let dataset = Arc::new(Dataset {
            events: vec![Arc::new(test_event())],
        });
        let value = first_value(&manager, &Expression::Amp(aid), &dataset, &[]);
        assert_eq!(value, Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_parametric_amplitude() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new(
                "parametric",
                parameter("test_param_re"),
                parameter("test_param_im"),
            ))
            .unwrap();
        let dataset = Arc::new(test_dataset());
        let value = first_value(&manager, &Expression::Amp(aid), &dataset, &[2.0, 3.0]);
        assert_eq!(value, Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_expression_operations() {
        let mut manager = Manager::default();
        // Registration order fixes the amplitude IDs: 2, i, and 3 + 4i.
        let aid1 = manager
            .register(ComplexScalar::new("const1", constant(2.0), constant(0.0)))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new("const2", constant(0.0), constant(1.0)))
            .unwrap();
        let aid3 = manager
            .register(ComplexScalar::new("const3", constant(3.0), constant(4.0)))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let value_of = |expr: &Expression| first_value(&manager, expr, &dataset, &[]);

        // Amplitude sum: 2 + i
        let expr_add = &aid1 + &aid2;
        assert_eq!(value_of(&expr_add), Complex::new(2.0, 1.0));

        // Amplitude product: 2 * i = 2i
        let expr_mul = &aid1 * &aid2;
        assert_eq!(value_of(&expr_mul), Complex::new(0.0, 2.0));

        // Expression sum: (2 + i) + 2i
        assert_eq!(value_of(&(&expr_add + &expr_mul)), Complex::new(2.0, 3.0));

        // Expression product: (2 + i) * 2i = -2 + 4i
        let expr_mul2 = &expr_add * &expr_mul;
        assert_eq!(value_of(&expr_mul2), Complex::new(-2.0, 4.0));

        // Real / imaginary parts of amplitudes and expressions
        assert_eq!(value_of(&aid3.real()), Complex::new(3.0, 0.0));
        assert_eq!(value_of(&expr_mul2.real()), Complex::new(-2.0, 0.0));
        assert_eq!(value_of(&expr_mul2.imag()), Complex::new(4.0, 0.0));
        assert_eq!(value_of(&aid3.imag()), Complex::new(4.0, 0.0));

        // Squared norms: |2|^2 = 4 and |-2 + 4i|^2 = 20
        assert_eq!(value_of(&aid1.norm_sqr()), Complex::new(4.0, 0.0));
        assert_eq!(value_of(&expr_mul2.norm_sqr()), Complex::new(20.0, 0.0));
    }

    #[test]
    fn test_amplitude_activation() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new("const1", constant(1.0), constant(0.0)))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new("const2", constant(2.0), constant(0.0)))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let model = manager.model(&(&aid1 + &aid2));
        let evaluator = model.load(&dataset);
        let current = || evaluator.evaluate(&[])[0];

        // Both amplitudes are active at first: 1 + 2
        assert_eq!(current(), Complex::new(3.0, 0.0));

        // Turning off const1 leaves only const2
        evaluator.deactivate("const1").unwrap();
        assert_eq!(current(), Complex::new(2.0, 0.0));

        // Isolating const1 leaves only const1
        evaluator.isolate("const1").unwrap();
        assert_eq!(current(), Complex::new(1.0, 0.0));

        // Reactivating everything restores the full sum
        evaluator.activate_all();
        assert_eq!(current(), Complex::new(3.0, 0.0));
    }

    #[test]
    fn test_gradient() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new(
                "parametric_1",
                parameter("test_param_re_1"),
                parameter("test_param_im_1"),
            ))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new(
                "parametric_2",
                parameter("test_param_re_2"),
                parameter("test_param_im_2"),
            ))
            .unwrap();
        let dataset = Arc::new(test_dataset());
        // f1 = 2 + 3i and f2 = 4 + 5i
        let params = [2.0, 3.0, 4.0, 5.0];
        let gradient_of = |expr: &Expression| first_gradient(&manager, expr, &dataset, &params);

        let product = &aid1 * &aid2;

        // d(f1 f2) over (re1, im1, re2, im2) = (f2, i f2, f1, i f1)
        assert_gradient_eq(
            &gradient_of(&product),
            &[(4.0, 5.0), (-5.0, 4.0), (2.0, 3.0), (-3.0, 2.0)],
        );

        // Real part of the above
        assert_gradient_eq(
            &gradient_of(&product.real()),
            &[(4.0, 0.0), (-5.0, 0.0), (2.0, 0.0), (-3.0, 0.0)],
        );

        // Imaginary part of the above
        assert_gradient_eq(
            &gradient_of(&product.imag()),
            &[(5.0, 0.0), (4.0, 0.0), (3.0, 0.0), (2.0, 0.0)],
        );

        // |f1|^2 |f2|^2 = 13 * 41, e.g. d/d(re1) = 2 * 2 * 41 = 164
        assert_gradient_eq(
            &gradient_of(&product.norm_sqr()),
            &[(164.0, 0.0), (246.0, 0.0), (104.0, 0.0), (130.0, 0.0)],
        );
    }

    #[test]
    fn test_zeros_and_ones() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new(
                "parametric",
                parameter("test_param_re"),
                constant(2.0),
            ))
            .unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = (aid * Expression::One + Expression::Zero).norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = [2.0];
        let value = evaluator.evaluate(&params)[0];
        let gradient = evaluator.evaluate_gradient(&params).remove(0);

        // |f(x) * 1 + 0|^2 with f(x) = x + 2i equals x^2 + 4, i.e. 8 at x = 2
        assert_relative_eq!(value.re, 8.0);
        assert_relative_eq!(value.im, 0.0);

        // ...and its derivative is 2x, i.e. 4 at x = 2
        assert_relative_eq!(gradient[0].re, 4.0);
        assert_relative_eq!(gradient[0].im, 0.0);
    }

    #[test]
    fn test_parameter_registration() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new(
                "parametric",
                parameter("test_param_re"),
                constant(2.0),
            ))
            .unwrap();

        let parameters = manager.parameters();
        let model = manager.model(&aid.into());
        let model_parameters = model.parameters();

        // Only the free real component is a parameter; the constant imaginary part is not.
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0], "test_param_re");
        assert_eq!(model_parameters.len(), 1);
        assert_eq!(model_parameters[0], "test_param_re");
    }

    #[test]
    fn test_duplicate_amplitude_registration() {
        let mut manager = Manager::default();
        manager
            .register(ComplexScalar::new("same_name", constant(1.0), constant(0.0)))
            .unwrap();
        let duplicate =
            manager.register(ComplexScalar::new("same_name", constant(2.0), constant(0.0)));
        assert!(duplicate.is_err());
    }

    #[test]
    fn test_tree_printing() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new(
                "parametric_1",
                parameter("test_param_re_1"),
                parameter("test_param_im_1"),
            ))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new(
                "parametric_2",
                parameter("test_param_re_2"),
                parameter("test_param_im_2"),
            ))
            .unwrap();
        let expr = &aid1.real()
            + &aid2.imag()
            + Expression::One * Expression::Zero
            + (&aid1 * &aid2).norm_sqr();
        let expected = "+
├─ +
│  ├─ +
│  │  ├─ Re
│  │  │  └─ parametric_1(id=0)
│  │  └─ Im
│  │     └─ parametric_2(id=1)
│  └─ *
│     ├─ 1
│     └─ 0
└─ NormSqr
   └─ *
      ├─ parametric_1(id=0)
      └─ parametric_2(id=1)
";
        assert_eq!(expr.to_string(), expected);
    }
}