laddu_core/
amplitudes.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17    data::{Dataset, Event},
18    resources::{Cache, Parameters, Resources},
19    Float, LadduError,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
/// An enum containing either a named free parameter or a constant value.
///
/// The [`Default`] value is [`ParameterLike::Uninit`].
#[derive(Clone, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A named free parameter (the payload is the parameter's name).
    Parameter(String),
    /// A constant value.
    Constant(Float),
    /// An uninitialized parameter-like structure (typically used as the value given in an
    /// [`Amplitude`] constructor before the [`Amplitude::register`] method is called).
    #[default]
    Uninit,
}
40
41/// Shorthand for generating a named free parameter.
42pub fn parameter(name: &str) -> ParameterLike {
43    ParameterLike::Parameter(name.to_string())
44}
45
/// Shorthand for generating a constant value (which acts like a fixed parameter).
///
/// Wraps `value` in [`ParameterLike::Constant`] without modifying it.
pub fn constant(value: Float) -> ParameterLike {
    ParameterLike::Constant(value)
}
50
51/// This is the only required trait for writing new amplitude-like structures for this
52/// crate. Users need only implement the [`register`](Amplitude::register)
53/// method to register parameters, cached values, and the amplitude itself with an input
54/// [`Resources`] struct and the [`compute`](Amplitude::compute) method to actually carry
55/// out the calculation. [`Amplitude`]-implementors are required to implement [`Clone`] and can
56/// optionally implement a [`precompute`](Amplitude::precompute) method to calculate and
57/// cache values which do not depend on free parameters.
58#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60    /// This method should be used to tell the [`Resources`] manager about all of
61    /// the free parameters and cached values used by this [`Amplitude`]. It should end by
62    /// returning an [`AmplitudeID`], which can be obtained from the
63    /// [`Resources::register_amplitude`] method.
64    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65    /// This method can be used to do some critical calculations ahead of time and
66    /// store them in a [`Cache`]. These values can only depend on the data in an [`Event`],
67    /// not on any free parameters in the fit. This method is opt-in since it is not required
68    /// to make a functioning [`Amplitude`].
69    #[allow(unused_variables)]
70    fn precompute(&self, event: &Event, cache: &mut Cache) {}
71    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
72    #[cfg(feature = "rayon")]
73    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74        dataset
75            .events
76            .par_iter()
77            .zip(resources.caches.par_iter_mut())
78            .for_each(|(event, cache)| {
79                self.precompute(event, cache);
80            })
81    }
82    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
83    #[cfg(not(feature = "rayon"))]
84    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85        dataset
86            .events
87            .iter()
88            .zip(resources.caches.iter_mut())
89            .for_each(|(event, cache)| self.precompute(event, cache))
90    }
91    /// This method constitutes the main machinery of an [`Amplitude`], returning the actual
92    /// calculated value for a particular [`Event`] and set of [`Parameters`]. See those
93    /// structs, as well as [`Cache`], for documentation on their available methods. For the
94    /// most part, [`Event`]s can be interacted with via
95    /// [`Variable`](crate::utils::variables::Variable)s, while [`Parameters`] and the
96    /// [`Cache`] are more like key-value storage accessed by
97    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
98    /// IDs.
99    fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101    /// This method yields the gradient of a particular [`Amplitude`] at a point specified
102    /// by a particular [`Event`] and set of [`Parameters`]. See those structs, as well as
103    /// [`Cache`], for documentation on their available methods. For the most part,
104    /// [`Event`]s can be interacted with via [`Variable`](crate::utils::variables::Variable)s,
105    /// while [`Parameters`] and the [`Cache`] are more like key-value storage accessed by
106    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
107    /// IDs. If the analytic version of the gradient is known, this method can be overwritten to
108    /// improve performance for some derivative-using methods of minimization. The default
109    /// implementation calculates a central finite difference across all parameters, regardless of
110    /// whether or not they are used in the [`Amplitude`].
111    ///
112    /// In the future, it may be possible to automatically implement this with the indices of
113    /// registered free parameters, but until then, the [`Amplitude::central_difference_with_indices`]
114    /// method can be used to conveniently only calculate central differences for the parameters
115    /// which are used by the [`Amplitude`].
116    fn compute_gradient(
117        &self,
118        parameters: &Parameters,
119        event: &Event,
120        cache: &Cache,
121        gradient: &mut DVector<Complex<Float>>,
122    ) {
123        self.central_difference_with_indices(
124            &Vec::from_iter(0..parameters.len()),
125            parameters,
126            event,
127            cache,
128            gradient,
129        )
130    }
131
132    /// A helper function to implement a central difference only on indices which correspond to
133    /// free parameters in the [`Amplitude`]. For example, if an [`Amplitude`] contains free
134    /// parameters registered to indices 1, 3, and 5 of the its internal parameters array, then
135    /// running this with those indices will compute a central finite difference derivative for
136    /// those coordinates only, since the rest can be safely assumed to be zero.
137    fn central_difference_with_indices(
138        &self,
139        indices: &[usize],
140        parameters: &Parameters,
141        event: &Event,
142        cache: &Cache,
143        gradient: &mut DVector<Complex<Float>>,
144    ) {
145        let x = parameters.parameters.to_owned();
146        let constants = parameters.constants.to_owned();
147        let h: DVector<Float> = x
148            .iter()
149            .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150            .collect::<Vec<_>>()
151            .into();
152        for i in indices {
153            let mut x_plus = x.clone();
154            let mut x_minus = x.clone();
155            x_plus[*i] += h[*i];
156            x_minus[*i] -= h[*i];
157            let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158            let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159            gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160        }
161    }
162}
163
164/// Utility function to calculate a central finite difference gradient.
165pub fn central_difference<F: Fn(&[Float]) -> Float>(
166    parameters: &[Float],
167    func: F,
168) -> DVector<Float> {
169    let mut gradient = DVector::zeros(parameters.len());
170    let x = parameters.to_owned();
171    let h: DVector<Float> = x
172        .iter()
173        .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174        .collect::<Vec<_>>()
175        .into();
176    for i in 0..parameters.len() {
177        let mut x_plus = x.clone();
178        let mut x_minus = x.clone();
179        x_plus[i] += h[i];
180        x_minus[i] -= h[i];
181        let f_plus = func(&x_plus);
182        let f_minus = func(&x_minus);
183        gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184    }
185    gradient
186}
187
// Makes `Box<dyn Amplitude>` cloneable; the derived `Clone` impls on types holding
// `Vec<Box<dyn Amplitude>>` in this file rely on it.
dyn_clone::clone_trait_object!(Amplitude);
189
/// A helper struct that contains the value of each amplitude for a particular event
///
/// Entries are looked up by the index stored in an [`AmplitudeID`] when an [`Expression`]
/// is evaluated.
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
/// A helper struct that contains the gradient of each amplitude for a particular event
///
/// Entries are looked up by the index stored in an [`AmplitudeID`] when the gradient of an
/// [`Expression`] is evaluated.
#[derive(Debug)]
pub struct GradientValues(pub Vec<DVector<Complex<Float>>>);
197
/// A tag which refers to a registered [`Amplitude`]. This is the base object which can be used to
/// build [`Expression`]s and should be obtained from the [`Manager::register`] method.
///
/// The first field is the label printed by [`Display`] and in expression trees (presumably the
/// amplitude's registered name); the second is the index used to look the amplitude up in
/// [`AmplitudeValues`] and [`GradientValues`].
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        write!(f, "{}", self.0)
206    }
207}
208
209impl From<AmplitudeID> for Expression {
210    fn from(value: AmplitudeID) -> Self {
211        Self::Amp(value)
212    }
213}
214
/// An expression tree which contains [`AmplitudeID`]s and operators over them.
///
/// The [`Default`] value is [`Expression::Zero`].
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    #[default]
    /// An expression equal to zero.
    Zero,
    /// A registered [`Amplitude`] referenced by an [`AmplitudeID`].
    Amp(AmplitudeID),
    /// The sum of two [`Expression`]s.
    Add(Box<Expression>, Box<Expression>),
    /// The product of two [`Expression`]s.
    Mul(Box<Expression>, Box<Expression>),
    /// The real part of an [`Expression`].
    Real(Box<Expression>),
    /// The imaginary part of an [`Expression`].
    Imag(Box<Expression>),
    /// The absolute square of an [`Expression`].
    NormSqr(Box<Expression>),
}
234
235impl Debug for Expression {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        self.write_tree(f, "", "", "")
238    }
239}
240
241impl Display for Expression {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        write!(f, "{:?}", self)
244    }
245}
246
// Operator overloads for building expression trees. Each closure clones its operands into
// boxed children, so the operands themselves are never consumed by the closure body.

// Expression + Expression -> Expression::Add
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression { Expression::Add(Box::new(a.clone()), Box::new(b.clone()))});
// Expression * Expression -> Expression::Mul
impl_op_ex!(*|a: &Expression, b: &Expression| -> Expression {
    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
});
// AmplitudeID + Expression (and, via the commutative macro, Expression + AmplitudeID).
// NOTE(review): presumably the AmplitudeID is always bound to `a`, making it the left child
// regardless of operand order — confirm against the `auto_ops` documentation.
impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))});
// AmplitudeID * Expression (and the reversed operand order).
impl_op_ex_commutative!(*|a: &AmplitudeID, b: &Expression| -> Expression {
    Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
});
// AmplitudeID + AmplitudeID -> Expression::Add of two leaf nodes
impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(Expression::Amp(b.clone())))});
// AmplitudeID * AmplitudeID -> Expression::Mul of two leaf nodes
impl_op_ex!(*|a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Mul(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone())),
    )
});
262
263impl AmplitudeID {
264    /// Takes the real part of the given [`Amplitude`].
265    pub fn real(&self) -> Expression {
266        Expression::Real(Box::new(Expression::Amp(self.clone())))
267    }
268    /// Takes the imaginary part of the given [`Amplitude`].
269    pub fn imag(&self) -> Expression {
270        Expression::Imag(Box::new(Expression::Amp(self.clone())))
271    }
272    /// Takes the absolute square of the given [`Amplitude`].
273    pub fn norm_sqr(&self) -> Expression {
274        Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
275    }
276}
277
278impl Expression {
279    /// Evaluate an [`Expression`] over a single event using calculated [`AmplitudeValues`]
280    ///
281    /// This method parses the underlying [`Expression`] but doesn't actually calculate the values
282    /// from the [`Amplitude`]s themselves.
283    pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
284        match self {
285            Expression::Amp(aid) => amplitude_values.0[aid.1],
286            Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
287            Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
288            Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
289            Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
290            Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
291            Expression::Zero => Complex::ZERO,
292        }
293    }
294    /// Evaluate the gradient of an [`Expression`] over a single event using calculated [`AmplitudeValues`]
295    ///
296    /// This method parses the underlying [`Expression`] but doesn't actually calculate the
297    /// gradient from the [`Amplitude`]s themselves.
298    pub fn evaluate_gradient(
299        &self,
300        amplitude_values: &AmplitudeValues,
301        gradient_values: &GradientValues,
302    ) -> DVector<Complex<Float>> {
303        match self {
304            Expression::Amp(aid) => gradient_values.0[aid.1].clone(),
305            Expression::Add(a, b) => {
306                a.evaluate_gradient(amplitude_values, gradient_values)
307                    + b.evaluate_gradient(amplitude_values, gradient_values)
308            }
309            Expression::Mul(a, b) => {
310                let f_a = a.evaluate(amplitude_values);
311                let f_b = b.evaluate(amplitude_values);
312                b.evaluate_gradient(amplitude_values, gradient_values)
313                    .map(|g| g * f_a)
314                    + a.evaluate_gradient(amplitude_values, gradient_values)
315                        .map(|g| g * f_b)
316            }
317            Expression::Real(a) => a
318                .evaluate_gradient(amplitude_values, gradient_values)
319                .map(|g| Complex::new(g.re, 0.0)),
320            Expression::Imag(a) => a
321                .evaluate_gradient(amplitude_values, gradient_values)
322                .map(|g| Complex::new(g.im, 0.0)),
323            Expression::NormSqr(a) => {
324                let conj_f_a = a.evaluate(amplitude_values).conjugate();
325                a.evaluate_gradient(amplitude_values, gradient_values)
326                    .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
327            }
328            Expression::Zero => DVector::zeros(0),
329        }
330    }
331    /// Takes the real part of the given [`Expression`].
332    pub fn real(&self) -> Self {
333        Self::Real(Box::new(self.clone()))
334    }
335    /// Takes the imaginary part of the given [`Expression`].
336    pub fn imag(&self) -> Self {
337        Self::Imag(Box::new(self.clone()))
338    }
339    /// Takes the absolute square of the given [`Expression`].
340    pub fn norm_sqr(&self) -> Self {
341        Self::NormSqr(Box::new(self.clone()))
342    }
343
344    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
345    fn write_tree(
346        &self,
347        f: &mut std::fmt::Formatter<'_>,
348        parent_prefix: &str,
349        immediate_prefix: &str,
350        parent_suffix: &str,
351    ) -> std::fmt::Result {
352        let display_string = match self {
353            Self::Amp(aid) => aid.0.clone(),
354            Self::Add(_, _) => "+".to_string(),
355            Self::Mul(_, _) => "*".to_string(),
356            Self::Real(_) => "Re".to_string(),
357            Self::Imag(_) => "Im".to_string(),
358            Self::NormSqr(_) => "NormSqr".to_string(),
359            Self::Zero => "0".to_string(),
360        };
361        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
362        match self {
363            Self::Amp(_) | Self::Zero => {}
364            Self::Add(a, b) | Self::Mul(a, b) => {
365                let terms = [a, b];
366                let mut it = terms.iter().peekable();
367                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
368                while let Some(child) = it.next() {
369                    match it.peek() {
370                        Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│  "),
371                        None => child.write_tree(f, &child_prefix, "└─ ", "   "),
372                    }?;
373                }
374            }
375            Self::Real(a) | Self::Imag(a) | Self::NormSqr(a) => {
376                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
377                a.write_tree(f, &child_prefix, "└─ ", "   ")?;
378            }
379        }
380        Ok(())
381    }
382}
383
/// A manager which can be used to register [`Amplitude`]s with [`Resources`]. This structure is
/// essential to any analysis and should be constructed using the [`Manager::default()`] method.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    // Registered amplitudes, in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Shared bookkeeping that every registered amplitude has written its registrations into.
    resources: Resources,
}
391
392impl Manager {
393    /// Get the list of parameter names in the order they appear in the [`Manager`]'s [`Resources`] field.
394    pub fn parameters(&self) -> Vec<String> {
395        self.resources.parameters.iter().cloned().collect()
396    }
397    /// Register the given [`Amplitude`] and return an [`AmplitudeID`] that can be used to build
398    /// [`Expression`]s.
399    ///
400    /// # Errors
401    ///
402    /// The [`Amplitude`]'s name must be unique and not already
403    /// registered, else this will return a [`RegistrationError`][LadduError::RegistrationError].
404    pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
405        let mut amp = amplitude.clone();
406        let aid = amp.register(&mut self.resources)?;
407        self.amplitudes.push(amp);
408        Ok(aid)
409    }
410    /// Turns an [`Expression`] made from registered [`Amplitude`]s into a [`Model`].
411    pub fn model(&self, expression: &Expression) -> Model {
412        Model {
413            manager: self.clone(),
414            expression: expression.clone(),
415        }
416    }
417}
418
/// A struct which contains a set of registered [`Amplitude`]s (inside a [`Manager`])
/// and an [`Expression`].
///
/// This struct implements [`serde::Serialize`] and [`serde::Deserialize`] and is intended
/// to be used to store models to disk.
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    // The manager holding the registered amplitudes and their resources.
    pub(crate) manager: Manager,
    // The expression built from the manager's registered amplitudes.
    pub(crate) expression: Expression,
}
429
430impl Model {
431    /// Get the list of parameter names in the order they appear in the [`Model`]'s [`Manager`] field.
432    pub fn parameters(&self) -> Vec<String> {
433        self.manager.parameters()
434    }
435    /// Create an [`Evaluator`] which can compute the result of the internal [`Expression`] built on
436    /// registered [`Amplitude`]s over the given [`Dataset`]. This method precomputes any relevant
437    /// information over the [`Event`]s in the [`Dataset`].
438    pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
439        let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
440        loaded_resources.write().reserve_cache(dataset.n_events());
441        for amplitude in &self.manager.amplitudes {
442            amplitude.precompute_all(dataset, &mut loaded_resources.write());
443        }
444        Evaluator {
445            amplitudes: self.manager.amplitudes.clone(),
446            resources: loaded_resources.clone(),
447            dataset: dataset.clone(),
448            expression: self.expression.clone(),
449        }
450    }
451}
452
/// A structure which can be used to evaluate the stored [`Expression`] built on registered
/// [`Amplitude`]s. This contains a [`Resources`] struct which already contains cached values for
/// precomputed [`Amplitude`]s and any relevant free parameters and constants.
///
/// Cloning an [`Evaluator`] shares the [`Resources`] and [`Dataset`] through their [`Arc`]s.
#[derive(Clone)]
pub struct Evaluator {
    /// A list of [`Amplitude`]s which were registered with the [`Manager`] used to create the
    /// internal [`Expression`]. This includes but is not limited to those which are actually used
    /// in the [`Expression`].
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// The internal [`Resources`] where precalculated values are stored
    pub resources: Arc<RwLock<Resources>>,
    /// The internal [`Dataset`]
    pub dataset: Arc<Dataset>,
    /// The internal [`Expression`]
    pub expression: Expression,
}
469
470impl Evaluator {
471    /// Get the list of parameter names in the order they appear in the [`Evaluator::evaluate`]
472    /// method.
473    pub fn parameters(&self) -> Vec<String> {
474        self.resources.read().parameters.iter().cloned().collect()
475    }
476    /// Activate an [`Amplitude`] by name.
477    pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
478        self.resources.write().activate(name)
479    }
480    /// Activate several [`Amplitude`]s by name.
481    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
482        self.resources.write().activate_many(names)
483    }
484    /// Activate all registered [`Amplitude`]s.
485    pub fn activate_all(&self) {
486        self.resources.write().activate_all();
487    }
488    /// Dectivate an [`Amplitude`] by name.
489    pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
490        self.resources.write().deactivate(name)
491    }
492    /// Deactivate several [`Amplitude`]s by name.
493    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
494        self.resources.write().deactivate_many(names)
495    }
496    /// Deactivate all registered [`Amplitude`]s.
497    pub fn deactivate_all(&self) {
498        self.resources.write().deactivate_all();
499    }
500    /// Isolate an [`Amplitude`] by name (deactivate the rest).
501    pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
502        self.resources.write().isolate(name)
503    }
504    /// Isolate several [`Amplitude`]s by name (deactivate the rest).
505    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
506        self.resources.write().isolate_many(names)
507    }
508
509    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
510    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
511    ///
512    /// # Notes
513    ///
514    /// This method is not intended to be called in analyses but rather in writing methods
515    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
516    pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
517        let resources = self.resources.read();
518        let parameters = Parameters::new(parameters, &resources.constants);
519        #[cfg(feature = "rayon")]
520        {
521            let amplitude_values_vec: Vec<AmplitudeValues> = self
522                .dataset
523                .events
524                .par_iter()
525                .zip(resources.caches.par_iter())
526                .map(|(event, cache)| {
527                    AmplitudeValues(
528                        self.amplitudes
529                            .iter()
530                            .zip(resources.active.iter())
531                            .map(|(amp, active)| {
532                                if *active {
533                                    amp.compute(&parameters, event, cache)
534                                } else {
535                                    Complex::new(0.0, 0.0)
536                                }
537                            })
538                            .collect(),
539                    )
540                })
541                .collect();
542            amplitude_values_vec
543                .par_iter()
544                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
545                .collect()
546        }
547        #[cfg(not(feature = "rayon"))]
548        {
549            let amplitude_values_vec: Vec<AmplitudeValues> = self
550                .dataset
551                .events
552                .iter()
553                .zip(resources.caches.iter())
554                .map(|(event, cache)| {
555                    AmplitudeValues(
556                        self.amplitudes
557                            .iter()
558                            .zip(resources.active.iter())
559                            .map(|(amp, active)| {
560                                if *active {
561                                    amp.compute(&parameters, event, cache)
562                                } else {
563                                    Complex::new(0.0, 0.0)
564                                }
565                            })
566                            .collect(),
567                    )
568                })
569                .collect();
570            amplitude_values_vec
571                .iter()
572                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
573                .collect()
574        }
575    }
576
577    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
578    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
579    ///
580    /// # Notes
581    ///
582    /// This method is not intended to be called in analyses but rather in writing methods
583    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
584    #[cfg(feature = "mpi")]
585    fn evaluate_mpi(
586        &self,
587        parameters: &[Float],
588        world: &SimpleCommunicator,
589    ) -> Vec<Complex<Float>> {
590        let local_evaluation = self.evaluate_local(parameters);
591        let n_events = self.dataset.n_events();
592        let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
593        let (counts, displs) = world.get_counts_displs(n_events);
594        {
595            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
596            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
597        }
598        buffer
599    }
600
601    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
602    /// [`Evaluator`] with the given values for free parameters.
603    pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
604        #[cfg(feature = "mpi")]
605        {
606            if let Some(world) = crate::mpi::get_world() {
607                return self.evaluate_mpi(parameters, &world);
608            }
609        }
610        self.evaluate_local(parameters)
611    }
612
613    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
614    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
615    ///
616    /// # Notes
617    ///
618    /// This method is not intended to be called in analyses but rather in writing methods
619    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
620    pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
621        let resources = self.resources.read();
622        let parameters = Parameters::new(parameters, &resources.constants);
623        #[cfg(feature = "rayon")]
624        {
625            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
626                .dataset
627                .events
628                .par_iter()
629                .zip(resources.caches.par_iter())
630                .map(|(event, cache)| {
631                    let mut gradient_values =
632                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
633                    self.amplitudes
634                        .iter()
635                        .zip(resources.active.iter())
636                        .zip(gradient_values.iter_mut())
637                        .for_each(|((amp, active), grad)| {
638                            if *active {
639                                amp.compute_gradient(&parameters, event, cache, grad)
640                            }
641                        });
642                    (
643                        AmplitudeValues(
644                            self.amplitudes
645                                .iter()
646                                .zip(resources.active.iter())
647                                .map(|(amp, active)| {
648                                    if *active {
649                                        amp.compute(&parameters, event, cache)
650                                    } else {
651                                        Complex::new(0.0, 0.0)
652                                    }
653                                })
654                                .collect(),
655                        ),
656                        GradientValues(gradient_values),
657                    )
658                })
659                .collect();
660            amplitude_values_and_gradient_vec
661                .par_iter()
662                .map(|(amplitude_values, gradient_values)| {
663                    self.expression
664                        .evaluate_gradient(amplitude_values, gradient_values)
665                })
666                .collect()
667        }
668        #[cfg(not(feature = "rayon"))]
669        {
670            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
671                .dataset
672                .events
673                .iter()
674                .zip(resources.caches.iter())
675                .map(|(event, cache)| {
676                    let mut gradient_values =
677                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
678                    self.amplitudes
679                        .iter()
680                        .zip(resources.active.iter())
681                        .zip(gradient_values.iter_mut())
682                        .for_each(|((amp, active), grad)| {
683                            if *active {
684                                amp.compute_gradient(&parameters, event, cache, grad)
685                            }
686                        });
687                    (
688                        AmplitudeValues(
689                            self.amplitudes
690                                .iter()
691                                .zip(resources.active.iter())
692                                .map(|(amp, active)| {
693                                    if *active {
694                                        amp.compute(&parameters, event, cache)
695                                    } else {
696                                        Complex::new(0.0, 0.0)
697                                    }
698                                })
699                                .collect(),
700                        ),
701                        GradientValues(gradient_values),
702                    )
703                })
704                .collect();
705
706            amplitude_values_and_gradient_vec
707                .iter()
708                .map(|(amplitude_values, gradient_values)| {
709                    self.expression
710                        .evaluate_gradient(amplitude_values, gradient_values)
711                })
712                .collect()
713        }
714    }
715
716    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
717    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
718    ///
719    /// # Notes
720    ///
721    /// This method is not intended to be called in analyses but rather in writing methods
722    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
723    #[cfg(feature = "mpi")]
724    fn evaluate_gradient_mpi(
725        &self,
726        parameters: &[Float],
727        world: &SimpleCommunicator,
728    ) -> Vec<DVector<Complex<Float>>> {
729        let flattened_local_evaluation = self
730            .evaluate_gradient_local(parameters)
731            .iter()
732            .flat_map(|g| g.data.as_vec().to_vec())
733            .collect::<Vec<Complex<Float>>>();
734        let n_events = self.dataset.n_events();
735        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
736        let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
737        let mut partitioned_flattened_result_buffer =
738            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
739        world.all_gather_varcount_into(
740            &flattened_local_evaluation,
741            &mut partitioned_flattened_result_buffer,
742        );
743        flattened_result_buffer
744            .chunks(parameters.len())
745            .map(DVector::from_row_slice)
746            .collect()
747    }
748
749    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
750    /// [`Evaluator`] with the given values for free parameters.
751    pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
752        #[cfg(feature = "mpi")]
753        {
754            if let Some(world) = crate::mpi::get_world() {
755                return self.evaluate_gradient_mpi(parameters, &world);
756            }
757        }
758        self.evaluate_gradient_local(parameters)
759    }
760}
761
762#[cfg(test)]
763mod tests {
764    use crate::data::{test_dataset, test_event};
765
766    use super::*;
767    use crate::{
768        data::Event,
769        resources::{Cache, ParameterID, Parameters, Resources},
770        Complex, DVector, Float, LadduError,
771    };
772    use approx::assert_relative_eq;
773    use serde::{Deserialize, Serialize};
774
775    #[derive(Clone, Serialize, Deserialize)]
776    pub struct ComplexScalar {
777        name: String,
778        re: ParameterLike,
779        pid_re: ParameterID,
780        im: ParameterLike,
781        pid_im: ParameterID,
782    }
783
784    impl ComplexScalar {
785        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
786            Self {
787                name: name.to_string(),
788                re,
789                pid_re: Default::default(),
790                im,
791                pid_im: Default::default(),
792            }
793            .into()
794        }
795    }
796
797    #[typetag::serde]
798    impl Amplitude for ComplexScalar {
799        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
800            self.pid_re = resources.register_parameter(&self.re);
801            self.pid_im = resources.register_parameter(&self.im);
802            resources.register_amplitude(&self.name)
803        }
804
805        fn compute(
806            &self,
807            parameters: &Parameters,
808            _event: &Event,
809            _cache: &Cache,
810        ) -> Complex<Float> {
811            Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
812        }
813
814        fn compute_gradient(
815            &self,
816            _parameters: &Parameters,
817            _event: &Event,
818            _cache: &Cache,
819            gradient: &mut DVector<Complex<Float>>,
820        ) {
821            if let ParameterID::Parameter(ind) = self.pid_re {
822                gradient[ind] = Complex::ONE;
823            }
824            if let ParameterID::Parameter(ind) = self.pid_im {
825                gradient[ind] = Complex::I;
826            }
827        }
828    }
829
830    #[test]
831    fn test_constant_amplitude() {
832        let mut manager = Manager::default();
833        let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
834        let aid = manager.register(amp).unwrap();
835        let dataset = Arc::new(Dataset {
836            events: vec![Arc::new(test_event())],
837        });
838        let expr = Expression::Amp(aid);
839        let model = manager.model(&expr);
840        let evaluator = model.load(&dataset);
841        let result = evaluator.evaluate(&[]);
842        assert_eq!(result[0], Complex::new(2.0, 3.0));
843    }
844
845    #[test]
846    fn test_parametric_amplitude() {
847        let mut manager = Manager::default();
848        let amp = ComplexScalar::new(
849            "parametric",
850            parameter("test_param_re"),
851            parameter("test_param_im"),
852        );
853        let aid = manager.register(amp).unwrap();
854        let dataset = Arc::new(test_dataset());
855        let expr = Expression::Amp(aid);
856        let model = manager.model(&expr);
857        let evaluator = model.load(&dataset);
858        let result = evaluator.evaluate(&[2.0, 3.0]);
859        assert_eq!(result[0], Complex::new(2.0, 3.0));
860    }
861
862    #[test]
863    fn test_expression_operations() {
864        let mut manager = Manager::default();
865        let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
866        let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
867        let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));
868
869        let aid1 = manager.register(amp1).unwrap();
870        let aid2 = manager.register(amp2).unwrap();
871        let aid3 = manager.register(amp3).unwrap();
872
873        let dataset = Arc::new(test_dataset());
874
875        // Test (amp) addition
876        let expr_add = &aid1 + &aid2;
877        let model_add = manager.model(&expr_add);
878        let eval_add = model_add.load(&dataset);
879        let result_add = eval_add.evaluate(&[]);
880        assert_eq!(result_add[0], Complex::new(2.0, 1.0));
881
882        // Test (amp) multiplication
883        let expr_mul = &aid1 * &aid2;
884        let model_mul = manager.model(&expr_mul);
885        let eval_mul = model_mul.load(&dataset);
886        let result_mul = eval_mul.evaluate(&[]);
887        assert_eq!(result_mul[0], Complex::new(0.0, 2.0));
888
889        // Test (expr) addition
890        let expr_add2 = &expr_add + &expr_mul;
891        let model_add2 = manager.model(&expr_add2);
892        let eval_add2 = model_add2.load(&dataset);
893        let result_add2 = eval_add2.evaluate(&[]);
894        assert_eq!(result_add2[0], Complex::new(2.0, 3.0));
895
896        // Test (expr) multiplication
897        let expr_mul2 = &expr_add * &expr_mul;
898        let model_mul2 = manager.model(&expr_mul2);
899        let eval_mul2 = model_mul2.load(&dataset);
900        let result_mul2 = eval_mul2.evaluate(&[]);
901        assert_eq!(result_mul2[0], Complex::new(-2.0, 4.0));
902
903        // Test (amp) real
904        let expr_real = aid3.real();
905        let model_real = manager.model(&expr_real);
906        let eval_real = model_real.load(&dataset);
907        let result_real = eval_real.evaluate(&[]);
908        assert_eq!(result_real[0], Complex::new(3.0, 0.0));
909
910        // Test (expr) real
911        let expr_mul2_real = expr_mul2.real();
912        let model_mul2_real = manager.model(&expr_mul2_real);
913        let eval_mul2_real = model_mul2_real.load(&dataset);
914        let result_mul2_real = eval_mul2_real.evaluate(&[]);
915        assert_eq!(result_mul2_real[0], Complex::new(-2.0, 0.0));
916
917        // Test (expr) imag
918        let expr_mul2_imag = expr_mul2.imag();
919        let model_mul2_imag = manager.model(&expr_mul2_imag);
920        let eval_mul2_imag = model_mul2_imag.load(&dataset);
921        let result_mul2_imag = eval_mul2_imag.evaluate(&[]);
922        assert_eq!(result_mul2_imag[0], Complex::new(4.0, 0.0));
923
924        // Test (amp) imag
925        let expr_imag = aid3.imag();
926        let model_imag = manager.model(&expr_imag);
927        let eval_imag = model_imag.load(&dataset);
928        let result_imag = eval_imag.evaluate(&[]);
929        assert_eq!(result_imag[0], Complex::new(4.0, 0.0));
930
931        // Test (amp) norm_sqr
932        let expr_norm = aid1.norm_sqr();
933        let model_norm = manager.model(&expr_norm);
934        let eval_norm = model_norm.load(&dataset);
935        let result_norm = eval_norm.evaluate(&[]);
936        assert_eq!(result_norm[0], Complex::new(4.0, 0.0));
937
938        // Test (expr) norm_sqr
939        let expr_mul2_norm = expr_mul2.norm_sqr();
940        let model_mul2_norm = manager.model(&expr_mul2_norm);
941        let eval_mul2_norm = model_mul2_norm.load(&dataset);
942        let result_mul2_norm = eval_mul2_norm.evaluate(&[]);
943        assert_eq!(result_mul2_norm[0], Complex::new(20.0, 0.0));
944    }
945
946    #[test]
947    fn test_amplitude_activation() {
948        let mut manager = Manager::default();
949        let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
950        let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));
951
952        let aid1 = manager.register(amp1).unwrap();
953        let aid2 = manager.register(amp2).unwrap();
954
955        let dataset = Arc::new(test_dataset());
956        let expr = &aid1 + &aid2;
957        let model = manager.model(&expr);
958        let evaluator = model.load(&dataset);
959
960        // Test initial state (all active)
961        let result = evaluator.evaluate(&[]);
962        assert_eq!(result[0], Complex::new(3.0, 0.0));
963
964        // Test deactivation
965        evaluator.deactivate("const1").unwrap();
966        let result = evaluator.evaluate(&[]);
967        assert_eq!(result[0], Complex::new(2.0, 0.0));
968
969        // Test isolation
970        evaluator.isolate("const1").unwrap();
971        let result = evaluator.evaluate(&[]);
972        assert_eq!(result[0], Complex::new(1.0, 0.0));
973
974        // Test reactivation
975        evaluator.activate_all();
976        let result = evaluator.evaluate(&[]);
977        assert_eq!(result[0], Complex::new(3.0, 0.0));
978    }
979
980    #[test]
981    fn test_gradient() {
982        let mut manager = Manager::default();
983        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
984
985        let aid = manager.register(amp).unwrap();
986        let dataset = Arc::new(test_dataset());
987        let expr = aid.norm_sqr();
988        let model = manager.model(&expr);
989        let evaluator = model.load(&dataset);
990
991        let params = vec![2.0];
992        let gradient = evaluator.evaluate_gradient(&params);
993
994        // For |f(x)|^2 where f(x) = x, the derivative should be 2x
995        assert_relative_eq!(gradient[0][0].re, 4.0);
996        assert_relative_eq!(gradient[0][0].im, 0.0);
997    }
998
999    #[test]
1000    fn test_parameter_registration() {
1001        let mut manager = Manager::default();
1002        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1003
1004        let aid = manager.register(amp).unwrap();
1005        let parameters = manager.parameters();
1006        let model = manager.model(&aid.into());
1007        let model_parameters = model.parameters();
1008        assert_eq!(parameters.len(), 1);
1009        assert_eq!(parameters[0], "test_param_re");
1010        assert_eq!(model_parameters.len(), 1);
1011        assert_eq!(model_parameters[0], "test_param_re");
1012    }
1013
1014    #[test]
1015    fn test_duplicate_amplitude_registration() {
1016        let mut manager = Manager::default();
1017        let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
1018        let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
1019        manager.register(amp1).unwrap();
1020        assert!(manager.register(amp2).is_err());
1021    }
1022}