laddu_core/
amplitudes.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17    data::{Dataset, Event},
18    resources::{Cache, Parameters, Resources},
19    Float, LadduError, ParameterID, ReadWrite,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
28/// An enum containing either a named free parameter or a constant value.
29#[derive(Clone, Default, Serialize, Deserialize)]
30pub enum ParameterLike {
31    /// A named free parameter.
32    Parameter(String),
33    /// A constant value.
34    Constant(Float),
35    /// An uninitialized parameter-like structure (typically used as the value given in an
36    /// [`Amplitude`] constructor before the [`Amplitude::register`] method is called).
37    #[default]
38    Uninit,
39}
40
41/// Shorthand for generating a named free parameter.
42pub fn parameter(name: &str) -> ParameterLike {
43    ParameterLike::Parameter(name.to_string())
44}
45
46/// Shorthand for generating a constant value (which acts like a fixed parameter).
47pub fn constant(value: Float) -> ParameterLike {
48    ParameterLike::Constant(value)
49}
50
51/// This is the only required trait for writing new amplitude-like structures for this
52/// crate. Users need only implement the [`register`](Amplitude::register)
53/// method to register parameters, cached values, and the amplitude itself with an input
54/// [`Resources`] struct and the [`compute`](Amplitude::compute) method to actually carry
55/// out the calculation. [`Amplitude`]-implementors are required to implement [`Clone`] and can
56/// optionally implement a [`precompute`](Amplitude::precompute) method to calculate and
57/// cache values which do not depend on free parameters.
58#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60    /// This method should be used to tell the [`Resources`] manager about all of
61    /// the free parameters and cached values used by this [`Amplitude`]. It should end by
62    /// returning an [`AmplitudeID`], which can be obtained from the
63    /// [`Resources::register_amplitude`] method.
64    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65    /// This method can be used to do some critical calculations ahead of time and
66    /// store them in a [`Cache`]. These values can only depend on the data in an [`Event`],
67    /// not on any free parameters in the fit. This method is opt-in since it is not required
68    /// to make a functioning [`Amplitude`].
69    #[allow(unused_variables)]
70    fn precompute(&self, event: &Event, cache: &mut Cache) {}
71    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
72    #[cfg(feature = "rayon")]
73    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74        dataset
75            .events
76            .par_iter()
77            .zip(resources.caches.par_iter_mut())
78            .for_each(|(event, cache)| {
79                self.precompute(event, cache);
80            })
81    }
82    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
83    #[cfg(not(feature = "rayon"))]
84    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85        dataset
86            .events
87            .iter()
88            .zip(resources.caches.iter_mut())
89            .for_each(|(event, cache)| self.precompute(event, cache))
90    }
91    /// This method constitutes the main machinery of an [`Amplitude`], returning the actual
92    /// calculated value for a particular [`Event`] and set of [`Parameters`]. See those
93    /// structs, as well as [`Cache`], for documentation on their available methods. For the
94    /// most part, [`Event`]s can be interacted with via
95    /// [`Variable`](crate::utils::variables::Variable)s, while [`Parameters`] and the
96    /// [`Cache`] are more like key-value storage accessed by
97    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
98    /// IDs.
99    fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101    /// This method yields the gradient of a particular [`Amplitude`] at a point specified
102    /// by a particular [`Event`] and set of [`Parameters`]. See those structs, as well as
103    /// [`Cache`], for documentation on their available methods. For the most part,
104    /// [`Event`]s can be interacted with via [`Variable`](crate::utils::variables::Variable)s,
105    /// while [`Parameters`] and the [`Cache`] are more like key-value storage accessed by
106    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
107    /// IDs. If the analytic version of the gradient is known, this method can be overwritten to
108    /// improve performance for some derivative-using methods of minimization. The default
109    /// implementation calculates a central finite difference across all parameters, regardless of
110    /// whether or not they are used in the [`Amplitude`].
111    ///
112    /// In the future, it may be possible to automatically implement this with the indices of
113    /// registered free parameters, but until then, the [`Amplitude::central_difference_with_indices`]
114    /// method can be used to conveniently only calculate central differences for the parameters
115    /// which are used by the [`Amplitude`].
116    fn compute_gradient(
117        &self,
118        parameters: &Parameters,
119        event: &Event,
120        cache: &Cache,
121        gradient: &mut DVector<Complex<Float>>,
122    ) {
123        self.central_difference_with_indices(
124            &Vec::from_iter(0..parameters.len()),
125            parameters,
126            event,
127            cache,
128            gradient,
129        )
130    }
131
132    /// A helper function to implement a central difference only on indices which correspond to
133    /// free parameters in the [`Amplitude`]. For example, if an [`Amplitude`] contains free
134    /// parameters registered to indices 1, 3, and 5 of the its internal parameters array, then
135    /// running this with those indices will compute a central finite difference derivative for
136    /// those coordinates only, since the rest can be safely assumed to be zero.
137    fn central_difference_with_indices(
138        &self,
139        indices: &[usize],
140        parameters: &Parameters,
141        event: &Event,
142        cache: &Cache,
143        gradient: &mut DVector<Complex<Float>>,
144    ) {
145        let x = parameters.parameters.to_owned();
146        let constants = parameters.constants.to_owned();
147        let h: DVector<Float> = x
148            .iter()
149            .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150            .collect::<Vec<_>>()
151            .into();
152        for i in indices {
153            let mut x_plus = x.clone();
154            let mut x_minus = x.clone();
155            x_plus[*i] += h[*i];
156            x_minus[*i] -= h[*i];
157            let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158            let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159            gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160        }
161    }
162}
163
164/// Utility function to calculate a central finite difference gradient.
165pub fn central_difference<F: Fn(&[Float]) -> Float>(
166    parameters: &[Float],
167    func: F,
168) -> DVector<Float> {
169    let mut gradient = DVector::zeros(parameters.len());
170    let x = parameters.to_owned();
171    let h: DVector<Float> = x
172        .iter()
173        .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174        .collect::<Vec<_>>()
175        .into();
176    for i in 0..parameters.len() {
177        let mut x_plus = x.clone();
178        let mut x_minus = x.clone();
179        x_plus[i] += h[i];
180        x_minus[i] -= h[i];
181        let f_plus = func(&x_plus);
182        let f_minus = func(&x_minus);
183        gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184    }
185    gradient
186}
187
// Implements `Clone` for `Box<dyn Amplitude>` by delegating to `DynClone` (a supertrait of
// `Amplitude`); this is what allows types holding `Vec<Box<dyn Amplitude>>` (such as `Manager`
// and `Evaluator`) to derive `Clone`.
dyn_clone::clone_trait_object!(Amplitude);
189
190/// A helper struct that contains the value of each amplitude for a particular event
191#[derive(Debug)]
192pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
194/// A helper struct that contains the gradient of each amplitude for a particular event
195#[derive(Debug)]
196pub struct GradientValues(pub usize, pub Vec<DVector<Complex<Float>>>);
197
198/// A tag which refers to a registered [`Amplitude`]. This is the base object which can be used to
199/// build [`Expression`]s and should be obtained from the [`Manager::register`] method.
200#[derive(Clone, Default, Debug, Serialize, Deserialize)]
201pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        write!(f, "{}(id={})", self.0, self.1)
206    }
207}
208
209impl From<AmplitudeID> for Expression {
210    fn from(value: AmplitudeID) -> Self {
211        Self::Amp(value)
212    }
213}
214
215/// An expression tree which contains [`AmplitudeID`]s and operators over them.
216#[derive(Clone, Serialize, Deserialize, Default)]
217pub enum Expression {
218    #[default]
219    /// A expression equal to zero.
220    Zero,
221    /// A expression equal to one.
222    One,
223    /// A registered [`Amplitude`] referenced by an [`AmplitudeID`].
224    Amp(AmplitudeID),
225    /// The sum of two [`Expression`]s.
226    Add(Box<Expression>, Box<Expression>),
227    /// The difference of two [`Expression`]s.
228    Sub(Box<Expression>, Box<Expression>),
229    /// The product of two [`Expression`]s.
230    Mul(Box<Expression>, Box<Expression>),
231    /// The division of two [`Expression`]s.
232    Div(Box<Expression>, Box<Expression>),
233    /// The additive inverse of an [`Expression`].
234    Neg(Box<Expression>),
235    /// The real part of an [`Expression`].
236    Real(Box<Expression>),
237    /// The imaginary part of an [`Expression`].
238    Imag(Box<Expression>),
239    /// The complex conjugate of an [`Expression`].
240    Conj(Box<Expression>),
241    /// The absolute square of an [`Expression`].
242    NormSqr(Box<Expression>),
243}
244
245impl Debug for Expression {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        self.write_tree(f, "", "", "")
248    }
249}
250
251impl Display for Expression {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        write!(f, "{:?}", self)
254    }
255}
256
257#[rustfmt::skip]
258impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
259    Expression::Add(Box::new(a.clone()), Box::new(b.clone()))
260});
261#[rustfmt::skip]
262impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
263    Expression::Sub(Box::new(a.clone()), Box::new(b.clone()))
264});
265#[rustfmt::skip]
266impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
267    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
268});
269#[rustfmt::skip]
270impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
271    Expression::Div(Box::new(a.clone()), Box::new(b.clone()))
272});
273#[rustfmt::skip]
274impl_op_ex!(- |a: &Expression| -> Expression {
275    Expression::Neg(Box::new(a.clone()))
276});
277
278#[rustfmt::skip]
279impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression {
280    Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
281});
282#[rustfmt::skip]
283impl_op_ex_commutative!(- |a: &AmplitudeID, b: &Expression| -> Expression {
284    Expression::Sub(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
285});
286#[rustfmt::skip]
287impl_op_ex_commutative!(* |a: &AmplitudeID, b: &Expression| -> Expression {
288    Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
289});
290#[rustfmt::skip]
291impl_op_ex_commutative!(/ |a: &AmplitudeID, b: &Expression| -> Expression {
292    Expression::Div(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
293});
294
295#[rustfmt::skip]
296impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
297    Expression::Add(
298        Box::new(Expression::Amp(a.clone())),
299        Box::new(Expression::Amp(b.clone()))
300    )
301});
302#[rustfmt::skip]
303impl_op_ex!(- |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
304    Expression::Sub(
305        Box::new(Expression::Amp(a.clone())),
306        Box::new(Expression::Amp(b.clone()))
307    )
308});
309#[rustfmt::skip]
310impl_op_ex!(* |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
311    Expression::Mul(
312        Box::new(Expression::Amp(a.clone())),
313        Box::new(Expression::Amp(b.clone())),
314    )
315});
316#[rustfmt::skip]
317impl_op_ex!(/ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
318    Expression::Div(
319        Box::new(Expression::Amp(a.clone())),
320        Box::new(Expression::Amp(b.clone())),
321    )
322});
323#[rustfmt::skip]
324impl_op_ex!(- |a: &AmplitudeID| -> Expression {
325    Expression::Neg(
326        Box::new(Expression::Amp(a.clone())),
327    )
328});
329
330impl AmplitudeID {
331    /// Takes the real part of the given [`Amplitude`].
332    pub fn real(&self) -> Expression {
333        Expression::Real(Box::new(Expression::Amp(self.clone())))
334    }
335    /// Takes the imaginary part of the given [`Amplitude`].
336    pub fn imag(&self) -> Expression {
337        Expression::Imag(Box::new(Expression::Amp(self.clone())))
338    }
339    /// Takes the complex conjugate of the given [`Amplitude`].
340    pub fn conj(&self) -> Expression {
341        Expression::Conj(Box::new(Expression::Amp(self.clone())))
342    }
343    /// Takes the absolute square of the given [`Amplitude`].
344    pub fn norm_sqr(&self) -> Expression {
345        Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
346    }
347}
348
349impl Expression {
350    /// Evaluate an [`Expression`] over a single event using calculated [`AmplitudeValues`]
351    ///
352    /// This method parses the underlying [`Expression`] but doesn't actually calculate the values
353    /// from the [`Amplitude`]s themselves.
354    pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
355        match self {
356            Expression::Amp(aid) => amplitude_values.0[aid.1],
357            Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
358            Expression::Sub(a, b) => a.evaluate(amplitude_values) - b.evaluate(amplitude_values),
359            Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
360            Expression::Div(a, b) => a.evaluate(amplitude_values) / b.evaluate(amplitude_values),
361            Expression::Neg(a) => -a.evaluate(amplitude_values),
362            Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
363            Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
364            Expression::Conj(a) => a.evaluate(amplitude_values).conj(),
365            Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
366            Expression::Zero => Complex::ZERO,
367            Expression::One => Complex::ONE,
368        }
369    }
370    /// Evaluate the gradient of an [`Expression`] over a single event using calculated [`AmplitudeValues`]
371    ///
372    /// This method parses the underlying [`Expression`] but doesn't actually calculate the
373    /// gradient from the [`Amplitude`]s themselves.
374    pub fn evaluate_gradient(
375        &self,
376        amplitude_values: &AmplitudeValues,
377        gradient_values: &GradientValues,
378    ) -> DVector<Complex<Float>> {
379        match self {
380            Expression::Amp(aid) => gradient_values.1[aid.1].clone(),
381            Expression::Add(a, b) => {
382                a.evaluate_gradient(amplitude_values, gradient_values)
383                    + b.evaluate_gradient(amplitude_values, gradient_values)
384            }
385            Expression::Sub(a, b) => {
386                a.evaluate_gradient(amplitude_values, gradient_values)
387                    - b.evaluate_gradient(amplitude_values, gradient_values)
388            }
389            Expression::Mul(a, b) => {
390                let f_a = a.evaluate(amplitude_values);
391                let f_b = b.evaluate(amplitude_values);
392                b.evaluate_gradient(amplitude_values, gradient_values)
393                    .map(|g| g * f_a)
394                    + a.evaluate_gradient(amplitude_values, gradient_values)
395                        .map(|g| g * f_b)
396            }
397            Expression::Div(a, b) => {
398                let f_a = a.evaluate(amplitude_values);
399                let f_b = b.evaluate(amplitude_values);
400                (a.evaluate_gradient(amplitude_values, gradient_values)
401                    .map(|g| g * f_b)
402                    - b.evaluate_gradient(amplitude_values, gradient_values)
403                        .map(|g| g * f_a))
404                    / (f_b * f_b)
405            }
406            Expression::Neg(a) => -a.evaluate_gradient(amplitude_values, gradient_values),
407            Expression::Real(a) => a
408                .evaluate_gradient(amplitude_values, gradient_values)
409                .map(|g| Complex::new(g.re, 0.0)),
410            Expression::Imag(a) => a
411                .evaluate_gradient(amplitude_values, gradient_values)
412                .map(|g| Complex::new(g.im, 0.0)),
413            Expression::Conj(a) => a
414                .evaluate_gradient(amplitude_values, gradient_values)
415                .map(|g| g.conj()),
416            Expression::NormSqr(a) => {
417                let conj_f_a = a.evaluate(amplitude_values).conjugate();
418                a.evaluate_gradient(amplitude_values, gradient_values)
419                    .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
420            }
421            Expression::Zero | Expression::One => DVector::zeros(gradient_values.0),
422        }
423    }
424    /// Takes the real part of the given [`Expression`].
425    pub fn real(&self) -> Self {
426        Self::Real(Box::new(self.clone()))
427    }
428    /// Takes the imaginary part of the given [`Expression`].
429    pub fn imag(&self) -> Self {
430        Self::Imag(Box::new(self.clone()))
431    }
432    /// Takes the complex conjugate of the given [`Expression`].
433    pub fn conj(&self) -> Self {
434        Self::Conj(Box::new(self.clone()))
435    }
436    /// Takes the absolute square of the given [`Expression`].
437    pub fn norm_sqr(&self) -> Self {
438        Self::NormSqr(Box::new(self.clone()))
439    }
440
441    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
442    fn write_tree(
443        &self,
444        f: &mut std::fmt::Formatter<'_>,
445        parent_prefix: &str,
446        immediate_prefix: &str,
447        parent_suffix: &str,
448    ) -> std::fmt::Result {
449        let display_string = match self {
450            Self::Amp(aid) => aid.to_string(),
451            Self::Add(_, _) => "+".to_string(),
452            Self::Sub(_, _) => "-".to_string(),
453            Self::Mul(_, _) => "×".to_string(),
454            Self::Div(_, _) => "÷".to_string(),
455            Self::Neg(_) => "-".to_string(),
456            Self::Real(_) => "Re".to_string(),
457            Self::Imag(_) => "Im".to_string(),
458            Self::Conj(_) => "*".to_string(),
459            Self::NormSqr(_) => "NormSqr".to_string(),
460            Self::Zero => "0".to_string(),
461            Self::One => "1".to_string(),
462        };
463        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
464        match self {
465            Self::Amp(_) | Self::Zero | Self::One => {}
466            Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) | Self::Div(a, b) => {
467                let terms = [a, b];
468                let mut it = terms.iter().peekable();
469                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
470                while let Some(child) = it.next() {
471                    match it.peek() {
472                        Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│  "),
473                        None => child.write_tree(f, &child_prefix, "└─ ", "   "),
474                    }?;
475                }
476            }
477            Self::Neg(a) | Self::Real(a) | Self::Imag(a) | Self::Conj(a) | Self::NormSqr(a) => {
478                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
479                a.write_tree(f, &child_prefix, "└─ ", "   ")?;
480            }
481        }
482        Ok(())
483    }
484}
485
486/// A manager which can be used to register [`Amplitude`]s with [`Resources`]. This structure is
487/// essential to any analysis and should be constructed using the [`Manager::default()`] method.
488#[derive(Default, Clone, Serialize, Deserialize)]
489pub struct Manager {
490    amplitudes: Vec<Box<dyn Amplitude>>,
491    resources: Resources,
492}
493
494impl Manager {
495    /// Get the list of parameter names in the order they appear in the [`Manager`]'s [`Resources`] field.
496    pub fn parameters(&self) -> Vec<String> {
497        self.resources.parameters.iter().cloned().collect()
498    }
499    /// Register the given [`Amplitude`] and return an [`AmplitudeID`] that can be used to build
500    /// [`Expression`]s.
501    ///
502    /// # Errors
503    ///
504    /// The [`Amplitude`]'s name must be unique and not already
505    /// registered, else this will return a [`RegistrationError`][LadduError::RegistrationError].
506    pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
507        let mut amp = amplitude.clone();
508        let aid = amp.register(&mut self.resources)?;
509        self.amplitudes.push(amp);
510        Ok(aid)
511    }
512    /// Turns an [`Expression`] made from registered [`Amplitude`]s into a [`Model`].
513    pub fn model(&self, expression: &Expression) -> Model {
514        Model {
515            manager: self.clone(),
516            expression: expression.clone(),
517        }
518    }
519}
520
521/// A struct which contains a set of registerd [`Amplitude`]s (inside a [`Manager`])
522/// and an [`Expression`].
523///
524/// This struct implements [`serde::Serialize`] and [`serde::Deserialize`] and is intended
525/// to be used to store models to disk.
526#[derive(Clone, Serialize, Deserialize)]
527pub struct Model {
528    pub(crate) manager: Manager,
529    pub(crate) expression: Expression,
530}
531
532impl ReadWrite for Model {
533    fn create_null() -> Self {
534        Model {
535            manager: Manager::default(),
536            expression: Expression::default(),
537        }
538    }
539}
540impl Model {
541    /// Get the list of parameter names in the order they appear in the [`Model`]'s [`Manager`] field.
542    pub fn parameters(&self) -> Vec<String> {
543        self.manager.parameters()
544    }
545    /// Create an [`Evaluator`] which can compute the result of the internal [`Expression`] built on
546    /// registered [`Amplitude`]s over the given [`Dataset`]. This method precomputes any relevant
547    /// information over the [`Event`]s in the [`Dataset`].
548    pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
549        let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
550        loaded_resources.write().reserve_cache(dataset.n_events());
551        for amplitude in &self.manager.amplitudes {
552            amplitude.precompute_all(dataset, &mut loaded_resources.write());
553        }
554        Evaluator {
555            amplitudes: self.manager.amplitudes.clone(),
556            resources: loaded_resources.clone(),
557            dataset: dataset.clone(),
558            expression: self.expression.clone(),
559        }
560    }
561}
562
563/// A structure which can be used to evaluate the stored [`Expression`] built on registered
564/// [`Amplitude`]s. This contains a [`Resources`] struct which already contains cached values for
565/// precomputed [`Amplitude`]s and any relevant free parameters and constants.
566#[derive(Clone)]
567pub struct Evaluator {
568    /// A list of [`Amplitude`]s which were registered with the [`Manager`] used to create the
569    /// internal [`Expression`]. This includes but is not limited to those which are actually used
570    /// in the [`Expression`].
571    pub amplitudes: Vec<Box<dyn Amplitude>>,
572    /// The internal [`Resources`] where precalculated values are stored
573    pub resources: Arc<RwLock<Resources>>,
574    /// The internal [`Dataset`]
575    pub dataset: Arc<Dataset>,
576    /// The internal [`Expression`]
577    pub expression: Expression,
578}
579
580impl Evaluator {
581    /// Get the list of parameter names in the order they appear in the [`Evaluator::evaluate`]
582    /// method.
583    pub fn parameters(&self) -> Vec<String> {
584        self.resources.read().parameters.iter().cloned().collect()
585    }
586    /// Activate an [`Amplitude`] by name.
587    pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
588        self.resources.write().activate(name)
589    }
590    /// Activate several [`Amplitude`]s by name.
591    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
592        self.resources.write().activate_many(names)
593    }
594    /// Activate all registered [`Amplitude`]s.
595    pub fn activate_all(&self) {
596        self.resources.write().activate_all();
597    }
598    /// Dectivate an [`Amplitude`] by name.
599    pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
600        self.resources.write().deactivate(name)
601    }
602    /// Deactivate several [`Amplitude`]s by name.
603    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
604        self.resources.write().deactivate_many(names)
605    }
606    /// Deactivate all registered [`Amplitude`]s.
607    pub fn deactivate_all(&self) {
608        self.resources.write().deactivate_all();
609    }
610    /// Isolate an [`Amplitude`] by name (deactivate the rest).
611    pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
612        self.resources.write().isolate(name)
613    }
614    /// Isolate several [`Amplitude`]s by name (deactivate the rest).
615    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
616        self.resources.write().isolate_many(names)
617    }
618
619    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
620    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
621    ///
622    /// # Notes
623    ///
624    /// This method is not intended to be called in analyses but rather in writing methods
625    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
626    pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
627        let resources = self.resources.read();
628        let parameters = Parameters::new(parameters, &resources.constants);
629        #[cfg(feature = "rayon")]
630        {
631            let amplitude_values_vec: Vec<AmplitudeValues> = self
632                .dataset
633                .events
634                .par_iter()
635                .zip(resources.caches.par_iter())
636                .map(|(event, cache)| {
637                    AmplitudeValues(
638                        self.amplitudes
639                            .iter()
640                            .zip(resources.active.iter())
641                            .map(|(amp, active)| {
642                                if *active {
643                                    amp.compute(&parameters, event, cache)
644                                } else {
645                                    Complex::new(0.0, 0.0)
646                                }
647                            })
648                            .collect(),
649                    )
650                })
651                .collect();
652            amplitude_values_vec
653                .par_iter()
654                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
655                .collect()
656        }
657        #[cfg(not(feature = "rayon"))]
658        {
659            let amplitude_values_vec: Vec<AmplitudeValues> = self
660                .dataset
661                .events
662                .iter()
663                .zip(resources.caches.iter())
664                .map(|(event, cache)| {
665                    AmplitudeValues(
666                        self.amplitudes
667                            .iter()
668                            .zip(resources.active.iter())
669                            .map(|(amp, active)| {
670                                if *active {
671                                    amp.compute(&parameters, event, cache)
672                                } else {
673                                    Complex::new(0.0, 0.0)
674                                }
675                            })
676                            .collect(),
677                    )
678                })
679                .collect();
680            amplitude_values_vec
681                .iter()
682                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
683                .collect()
684        }
685    }
686
687    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
688    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
689    ///
690    /// # Notes
691    ///
692    /// This method is not intended to be called in analyses but rather in writing methods
693    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
694    #[cfg(feature = "mpi")]
695    fn evaluate_mpi(
696        &self,
697        parameters: &[Float],
698        world: &SimpleCommunicator,
699    ) -> Vec<Complex<Float>> {
700        let local_evaluation = self.evaluate_local(parameters);
701        let n_events = self.dataset.n_events();
702        let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
703        let (counts, displs) = world.get_counts_displs(n_events);
704        {
705            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
706            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
707        }
708        buffer
709    }
710
711    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
712    /// [`Evaluator`] with the given values for free parameters.
713    pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
714        #[cfg(feature = "mpi")]
715        {
716            if let Some(world) = crate::mpi::get_world() {
717                return self.evaluate_mpi(parameters, &world);
718            }
719        }
720        self.evaluate_local(parameters)
721    }
722
723    /// See [`Evaluator::evaluate_local`]. This method evaluates over a subset of events rather
724    /// than all events in the total dataset.
725    pub fn evaluate_batch_local(
726        &self,
727        parameters: &[Float],
728        indices: &[usize],
729    ) -> Vec<Complex<Float>> {
730        let resources = self.resources.read();
731        let parameters = Parameters::new(parameters, &resources.constants);
732        #[cfg(feature = "rayon")]
733        {
734            let amplitude_values_vec: Vec<AmplitudeValues> = self
735                .dataset
736                .events
737                .par_iter()
738                .zip(resources.caches.par_iter())
739                .enumerate()
740                .filter_map(|(i, (event, cache))| {
741                    if indices.contains(&i) {
742                        Some((event, cache))
743                    } else {
744                        None
745                    }
746                })
747                .map(|(event, cache)| {
748                    AmplitudeValues(
749                        self.amplitudes
750                            .iter()
751                            .zip(resources.active.iter())
752                            .map(|(amp, active)| {
753                                if *active {
754                                    amp.compute(&parameters, event, cache)
755                                } else {
756                                    Complex::new(0.0, 0.0)
757                                }
758                            })
759                            .collect(),
760                    )
761                })
762                .collect();
763            amplitude_values_vec
764                .par_iter()
765                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
766                .collect()
767        }
768        #[cfg(not(feature = "rayon"))]
769        {
770            let amplitude_values_vec: Vec<AmplitudeValues> = self
771                .dataset
772                .events
773                .iter()
774                .zip(resources.caches.iter())
775                .enumerate()
776                .filter_map(|(i, (event, cache))| {
777                    if indices.contains(&i) {
778                        Some((event, cache))
779                    } else {
780                        None
781                    }
782                })
783                .map(|(event, cache)| {
784                    AmplitudeValues(
785                        self.amplitudes
786                            .iter()
787                            .zip(resources.active.iter())
788                            .map(|(amp, active)| {
789                                if *active {
790                                    amp.compute(&parameters, event, cache)
791                                } else {
792                                    Complex::new(0.0, 0.0)
793                                }
794                            })
795                            .collect(),
796                    )
797                })
798                .collect();
799            amplitude_values_vec
800                .iter()
801                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
802                .collect()
803        }
804    }
805
806    /// See [`Evaluator::evaluate_mpi`]. This method evaluates over a subset of events rather
807    /// than all events in the total dataset.
808    #[cfg(feature = "mpi")]
809    fn evaluate_batch_mpi(
810        &self,
811        parameters: &[Float],
812        indices: &[usize],
813        world: &SimpleCommunicator,
814    ) -> Vec<Complex<Float>> {
815        let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; indices.len()];
816        let (counts, displs, locals) = self
817            .dataset
818            .get_counts_displs_locals_from_indices(indices, world);
819        let local_evaluation = self.evaluate_batch_local(parameters, &locals);
820        {
821            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
822            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
823        }
824        buffer
825    }
826
827    /// Evaluate the stored [`Expression`] over a subset of events in the [`Dataset`] stored by the
828    /// [`Evaluator`] with the given values for free parameters. See also [`Expression::evaluate`].
829    pub fn evaluate_batch(&self, parameters: &[Float], indices: &[usize]) -> Vec<Complex<Float>> {
830        #[cfg(feature = "mpi")]
831        {
832            if let Some(world) = crate::mpi::get_world() {
833                return self.evaluate_batch_mpi(parameters, indices, &world);
834            }
835        }
836        self.evaluate_batch_local(parameters, indices)
837    }
838
839    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
840    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
841    ///
842    /// # Notes
843    ///
844    /// This method is not intended to be called in analyses but rather in writing methods
845    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
846    pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
847        let resources = self.resources.read();
848        let parameters = Parameters::new(parameters, &resources.constants);
849        #[cfg(feature = "rayon")]
850        {
851            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
852                .dataset
853                .events
854                .par_iter()
855                .zip(resources.caches.par_iter())
856                .map(|(event, cache)| {
857                    let mut gradient_values =
858                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
859                    self.amplitudes
860                        .iter()
861                        .zip(resources.active.iter())
862                        .zip(gradient_values.iter_mut())
863                        .for_each(|((amp, active), grad)| {
864                            if *active {
865                                amp.compute_gradient(&parameters, event, cache, grad)
866                            }
867                        });
868                    (
869                        AmplitudeValues(
870                            self.amplitudes
871                                .iter()
872                                .zip(resources.active.iter())
873                                .map(|(amp, active)| {
874                                    if *active {
875                                        amp.compute(&parameters, event, cache)
876                                    } else {
877                                        Complex::new(0.0, 0.0)
878                                    }
879                                })
880                                .collect(),
881                        ),
882                        GradientValues(parameters.len(), gradient_values),
883                    )
884                })
885                .collect();
886            amplitude_values_and_gradient_vec
887                .par_iter()
888                .map(|(amplitude_values, gradient_values)| {
889                    self.expression
890                        .evaluate_gradient(amplitude_values, gradient_values)
891                })
892                .collect()
893        }
894        #[cfg(not(feature = "rayon"))]
895        {
896            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
897                .dataset
898                .events
899                .iter()
900                .zip(resources.caches.iter())
901                .map(|(event, cache)| {
902                    let mut gradient_values =
903                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
904                    self.amplitudes
905                        .iter()
906                        .zip(resources.active.iter())
907                        .zip(gradient_values.iter_mut())
908                        .for_each(|((amp, active), grad)| {
909                            if *active {
910                                amp.compute_gradient(&parameters, event, cache, grad)
911                            }
912                        });
913                    (
914                        AmplitudeValues(
915                            self.amplitudes
916                                .iter()
917                                .zip(resources.active.iter())
918                                .map(|(amp, active)| {
919                                    if *active {
920                                        amp.compute(&parameters, event, cache)
921                                    } else {
922                                        Complex::new(0.0, 0.0)
923                                    }
924                                })
925                                .collect(),
926                        ),
927                        GradientValues(parameters.len(), gradient_values),
928                    )
929                })
930                .collect();
931
932            amplitude_values_and_gradient_vec
933                .iter()
934                .map(|(amplitude_values, gradient_values)| {
935                    self.expression
936                        .evaluate_gradient(amplitude_values, gradient_values)
937                })
938                .collect()
939        }
940    }
941
942    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
943    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
944    ///
945    /// # Notes
946    ///
947    /// This method is not intended to be called in analyses but rather in writing methods
948    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
949    #[cfg(feature = "mpi")]
950    fn evaluate_gradient_mpi(
951        &self,
952        parameters: &[Float],
953        world: &SimpleCommunicator,
954    ) -> Vec<DVector<Complex<Float>>> {
955        let flattened_local_evaluation = self
956            .evaluate_gradient_local(parameters)
957            .iter()
958            .flat_map(|g| g.data.as_vec().to_vec())
959            .collect::<Vec<Complex<Float>>>();
960        let n_events = self.dataset.n_events();
961        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
962        let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
963        let mut partitioned_flattened_result_buffer =
964            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
965        world.all_gather_varcount_into(
966            &flattened_local_evaluation,
967            &mut partitioned_flattened_result_buffer,
968        );
969        flattened_result_buffer
970            .chunks(parameters.len())
971            .map(DVector::from_row_slice)
972            .collect()
973    }
974
975    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
976    /// [`Evaluator`] with the given values for free parameters.
977    pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
978        #[cfg(feature = "mpi")]
979        {
980            if let Some(world) = crate::mpi::get_world() {
981                return self.evaluate_gradient_mpi(parameters, &world);
982            }
983        }
984        self.evaluate_gradient_local(parameters)
985    }
986
987    /// See [`Evaluator::evaluate_gradient_local`]. This method evaluates over a subset
988    /// of events rather than all events in the total dataset.
989    pub fn evaluate_gradient_batch_local(
990        &self,
991        parameters: &[Float],
992        indices: &[usize],
993    ) -> Vec<DVector<Complex<Float>>> {
994        let resources = self.resources.read();
995        let parameters = Parameters::new(parameters, &resources.constants);
996        #[cfg(feature = "rayon")]
997        {
998            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
999                .dataset
1000                .events
1001                .par_iter()
1002                .zip(resources.caches.par_iter())
1003                .enumerate()
1004                .filter_map(|(i, (event, cache))| {
1005                    if indices.contains(&i) {
1006                        Some((event, cache))
1007                    } else {
1008                        None
1009                    }
1010                })
1011                .map(|(event, cache)| {
1012                    let mut gradient_values =
1013                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1014                    self.amplitudes
1015                        .iter()
1016                        .zip(resources.active.iter())
1017                        .zip(gradient_values.iter_mut())
1018                        .for_each(|((amp, active), grad)| {
1019                            if *active {
1020                                amp.compute_gradient(&parameters, event, cache, grad)
1021                            }
1022                        });
1023                    (
1024                        AmplitudeValues(
1025                            self.amplitudes
1026                                .iter()
1027                                .zip(resources.active.iter())
1028                                .map(|(amp, active)| {
1029                                    if *active {
1030                                        amp.compute(&parameters, event, cache)
1031                                    } else {
1032                                        Complex::new(0.0, 0.0)
1033                                    }
1034                                })
1035                                .collect(),
1036                        ),
1037                        GradientValues(parameters.len(), gradient_values),
1038                    )
1039                })
1040                .collect();
1041            amplitude_values_and_gradient_vec
1042                .par_iter()
1043                .map(|(amplitude_values, gradient_values)| {
1044                    self.expression
1045                        .evaluate_gradient(amplitude_values, gradient_values)
1046                })
1047                .collect()
1048        }
1049        #[cfg(not(feature = "rayon"))]
1050        {
1051            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
1052                .dataset
1053                .events
1054                .iter()
1055                .zip(resources.caches.iter())
1056                .enumerate()
1057                .filter_map(|(i, (event, cache))| {
1058                    if indices.contains(&i) {
1059                        Some((event, cache))
1060                    } else {
1061                        None
1062                    }
1063                })
1064                .map(|(event, cache)| {
1065                    let mut gradient_values =
1066                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1067                    self.amplitudes
1068                        .iter()
1069                        .zip(resources.active.iter())
1070                        .zip(gradient_values.iter_mut())
1071                        .for_each(|((amp, active), grad)| {
1072                            if *active {
1073                                amp.compute_gradient(&parameters, event, cache, grad)
1074                            }
1075                        });
1076                    (
1077                        AmplitudeValues(
1078                            self.amplitudes
1079                                .iter()
1080                                .zip(resources.active.iter())
1081                                .map(|(amp, active)| {
1082                                    if *active {
1083                                        amp.compute(&parameters, event, cache)
1084                                    } else {
1085                                        Complex::new(0.0, 0.0)
1086                                    }
1087                                })
1088                                .collect(),
1089                        ),
1090                        GradientValues(parameters.len(), gradient_values),
1091                    )
1092                })
1093                .collect();
1094
1095            amplitude_values_and_gradient_vec
1096                .iter()
1097                .map(|(amplitude_values, gradient_values)| {
1098                    self.expression
1099                        .evaluate_gradient(amplitude_values, gradient_values)
1100                })
1101                .collect()
1102        }
1103    }
1104
1105    /// See [`Evaluator::evaluate_gradient_mpi`]. This method evaluates over a subset
1106    /// of events rather than all events in the total dataset.
1107    #[cfg(feature = "mpi")]
1108    fn evaluate_gradient_batch_mpi(
1109        &self,
1110        parameters: &[Float],
1111        indices: &[usize],
1112        world: &SimpleCommunicator,
1113    ) -> Vec<DVector<Complex<Float>>> {
1114        let (counts, displs, locals) = self
1115            .dataset
1116            .get_flattened_counts_displs_locals_from_indices(indices, parameters.len(), world);
1117        let flattened_local_evaluation = self
1118            .evaluate_gradient_batch_local(parameters, &locals)
1119            .iter()
1120            .flat_map(|g| g.data.as_vec().to_vec())
1121            .collect::<Vec<Complex<Float>>>();
1122        let mut flattened_result_buffer = vec![Complex::ZERO; indices.len() * parameters.len()];
1123        let mut partitioned_flattened_result_buffer =
1124            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
1125        world.all_gather_varcount_into(
1126            &flattened_local_evaluation,
1127            &mut partitioned_flattened_result_buffer,
1128        );
1129        flattened_result_buffer
1130            .chunks(parameters.len())
1131            .map(DVector::from_row_slice)
1132            .collect()
1133    }
1134
1135    /// Evaluate the gradient of the stored [`Expression`] over a subset of the
1136    /// events in the [`Dataset`] stored by the [`Evaluator`] with the given values
1137    /// for free parameters. See also [`Expression::evaluate_gradient`].
1138    pub fn evaluate_gradient_batch(
1139        &self,
1140        parameters: &[Float],
1141        indices: &[usize],
1142    ) -> Vec<DVector<Complex<Float>>> {
1143        #[cfg(feature = "mpi")]
1144        {
1145            if let Some(world) = crate::mpi::get_world() {
1146                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
1147            }
1148        }
1149        self.evaluate_gradient_batch_local(parameters, indices)
1150    }
1151}
1152
/// A testing [`Amplitude`].
///
/// Its value is the complex coefficient `re + i·im` multiplied by `event.p4s[0].e()`, the
/// `e()` component of the event's first four-momentum.
#[derive(Clone, Serialize, Deserialize)]
pub struct TestAmplitude {
    /// Name under which the amplitude is registered with [`Resources`].
    name: String,
    /// Real part of the coefficient (a named free parameter or a constant).
    re: ParameterLike,
    /// Identifier for `re`, assigned in [`Amplitude::register`] (default until then).
    pid_re: ParameterID,
    /// Imaginary part of the coefficient (a named free parameter or a constant).
    im: ParameterLike,
    /// Identifier for `im`, assigned in [`Amplitude::register`] (default until then).
    pid_im: ParameterID,
}
1162
1163impl TestAmplitude {
1164    /// Create a new testing [`Amplitude`].
1165    pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
1166        Self {
1167            name: name.to_string(),
1168            re,
1169            pid_re: Default::default(),
1170            im,
1171            pid_im: Default::default(),
1172        }
1173        .into()
1174    }
1175}
1176
1177#[typetag::serde]
1178impl Amplitude for TestAmplitude {
1179    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
1180        self.pid_re = resources.register_parameter(&self.re);
1181        self.pid_im = resources.register_parameter(&self.im);
1182        resources.register_amplitude(&self.name)
1183    }
1184
1185    fn compute(&self, parameters: &Parameters, event: &Event, _cache: &Cache) -> Complex<Float> {
1186        Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im)) * event.p4s[0].e()
1187    }
1188
1189    fn compute_gradient(
1190        &self,
1191        _parameters: &Parameters,
1192        event: &Event,
1193        _cache: &Cache,
1194        gradient: &mut DVector<Complex<Float>>,
1195    ) {
1196        if let ParameterID::Parameter(ind) = self.pid_re {
1197            gradient[ind] = Complex::ONE * event.p4s[0].e();
1198        }
1199        if let ParameterID::Parameter(ind) = self.pid_im {
1200            gradient[ind] = Complex::I * event.p4s[0].e();
1201        }
1202    }
1203}
1204
1205#[cfg(test)]
1206mod tests {
1207    use crate::data::{test_dataset, test_event};
1208
1209    use super::*;
1210    use crate::{
1211        data::Event,
1212        resources::{Cache, ParameterID, Parameters, Resources},
1213        Float, LadduError,
1214    };
1215    use approx::assert_relative_eq;
1216    use serde::{Deserialize, Serialize};
1217
1218    #[derive(Clone, Serialize, Deserialize)]
1219    pub struct ComplexScalar {
1220        name: String,
1221        re: ParameterLike,
1222        pid_re: ParameterID,
1223        im: ParameterLike,
1224        pid_im: ParameterID,
1225    }
1226
1227    impl ComplexScalar {
1228        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
1229            Self {
1230                name: name.to_string(),
1231                re,
1232                pid_re: Default::default(),
1233                im,
1234                pid_im: Default::default(),
1235            }
1236            .into()
1237        }
1238    }
1239
1240    #[typetag::serde]
1241    impl Amplitude for ComplexScalar {
1242        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
1243            self.pid_re = resources.register_parameter(&self.re);
1244            self.pid_im = resources.register_parameter(&self.im);
1245            resources.register_amplitude(&self.name)
1246        }
1247
1248        fn compute(
1249            &self,
1250            parameters: &Parameters,
1251            _event: &Event,
1252            _cache: &Cache,
1253        ) -> Complex<Float> {
1254            Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
1255        }
1256
1257        fn compute_gradient(
1258            &self,
1259            _parameters: &Parameters,
1260            _event: &Event,
1261            _cache: &Cache,
1262            gradient: &mut DVector<Complex<Float>>,
1263        ) {
1264            if let ParameterID::Parameter(ind) = self.pid_re {
1265                gradient[ind] = Complex::ONE;
1266            }
1267            if let ParameterID::Parameter(ind) = self.pid_im {
1268                gradient[ind] = Complex::I;
1269            }
1270        }
1271    }
1272
1273    #[test]
1274    fn test_batch_evaluation() {
1275        let mut manager = Manager::default();
1276        let amp = TestAmplitude::new("test", parameter("real"), parameter("imag"));
1277        let aid = manager.register(amp).unwrap();
1278        let mut event1 = test_event();
1279        event1.p4s[0].t = 10.0;
1280        let mut event2 = test_event();
1281        event2.p4s[0].t = 11.0;
1282        let mut event3 = test_event();
1283        event3.p4s[0].t = 12.0;
1284        let dataset = Arc::new(Dataset {
1285            events: vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
1286        });
1287        let expr = Expression::Amp(aid);
1288        let model = manager.model(&expr);
1289        let evaluator = model.load(&dataset);
1290        let result = evaluator.evaluate_batch(&[1.1, 2.2], &[0, 2]);
1291        assert_eq!(result.len(), 2);
1292        assert_eq!(result[0], Complex::new(1.1, 2.2) * 10.0);
1293        assert_eq!(result[1], Complex::new(1.1, 2.2) * 12.0);
1294        let result_grad = evaluator.evaluate_gradient_batch(&[1.1, 2.2], &[0, 2]);
1295        assert_eq!(result_grad.len(), 2);
1296        assert_eq!(result_grad[0][0], Complex::new(10.0, 0.0));
1297        assert_eq!(result_grad[0][1], Complex::new(0.0, 10.0));
1298        assert_eq!(result_grad[1][0], Complex::new(12.0, 0.0));
1299        assert_eq!(result_grad[1][1], Complex::new(0.0, 12.0));
1300    }
1301
1302    #[test]
1303    fn test_constant_amplitude() {
1304        let mut manager = Manager::default();
1305        let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
1306        let aid = manager.register(amp).unwrap();
1307        let dataset = Arc::new(Dataset {
1308            events: vec![Arc::new(test_event())],
1309        });
1310        let expr = Expression::Amp(aid);
1311        let model = manager.model(&expr);
1312        let evaluator = model.load(&dataset);
1313        let result = evaluator.evaluate(&[]);
1314        assert_eq!(result[0], Complex::new(2.0, 3.0));
1315    }
1316
1317    #[test]
1318    fn test_parametric_amplitude() {
1319        let mut manager = Manager::default();
1320        let amp = ComplexScalar::new(
1321            "parametric",
1322            parameter("test_param_re"),
1323            parameter("test_param_im"),
1324        );
1325        let aid = manager.register(amp).unwrap();
1326        let dataset = Arc::new(test_dataset());
1327        let expr = Expression::Amp(aid);
1328        let model = manager.model(&expr);
1329        let evaluator = model.load(&dataset);
1330        let result = evaluator.evaluate(&[2.0, 3.0]);
1331        assert_eq!(result[0], Complex::new(2.0, 3.0));
1332    }
1333
1334    #[test]
1335    fn test_expression_operations() {
1336        let mut manager = Manager::default();
1337        let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
1338        let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
1339        let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));
1340
1341        let aid1 = manager.register(amp1).unwrap();
1342        let aid2 = manager.register(amp2).unwrap();
1343        let aid3 = manager.register(amp3).unwrap();
1344
1345        let dataset = Arc::new(test_dataset());
1346
1347        // Test (amp) addition
1348        let expr_add = &aid1 + &aid2;
1349        let model_add = manager.model(&expr_add);
1350        let eval_add = model_add.load(&dataset);
1351        let result_add = eval_add.evaluate(&[]);
1352        assert_eq!(result_add[0], Complex::new(2.0, 1.0));
1353
1354        // Test (amp) subtraction
1355        let expr_sub = &aid1 - &aid2;
1356        let model_sub = manager.model(&expr_sub);
1357        let eval_sub = model_sub.load(&dataset);
1358        let result_sub = eval_sub.evaluate(&[]);
1359        assert_eq!(result_sub[0], Complex::new(2.0, -1.0));
1360
1361        // Test (amp) multiplication
1362        let expr_mul = &aid1 * &aid2;
1363        let model_mul = manager.model(&expr_mul);
1364        let eval_mul = model_mul.load(&dataset);
1365        let result_mul = eval_mul.evaluate(&[]);
1366        assert_eq!(result_mul[0], Complex::new(0.0, 2.0));
1367
1368        // Test (amp) division
1369        let expr_div = &aid1 / &aid3;
1370        let model_div = manager.model(&expr_div);
1371        let eval_div = model_div.load(&dataset);
1372        let result_div = eval_div.evaluate(&[]);
1373        assert_eq!(result_div[0], Complex::new(6.0 / 25.0, -8.0 / 25.0));
1374
1375        // Test (amp) neg
1376        let expr_neg = -&aid3;
1377        let model_neg = manager.model(&expr_neg);
1378        let eval_neg = model_neg.load(&dataset);
1379        let result_neg = eval_neg.evaluate(&[]);
1380        assert_eq!(result_neg[0], Complex::new(-3.0, -4.0));
1381
1382        // Test (expr) addition
1383        let expr_add2 = &expr_add + &expr_mul;
1384        let model_add2 = manager.model(&expr_add2);
1385        let eval_add2 = model_add2.load(&dataset);
1386        let result_add2 = eval_add2.evaluate(&[]);
1387        assert_eq!(result_add2[0], Complex::new(2.0, 3.0));
1388
1389        // Test (expr) subtraction
1390        let expr_sub2 = &expr_add - &expr_mul;
1391        let model_sub2 = manager.model(&expr_sub2);
1392        let eval_sub2 = model_sub2.load(&dataset);
1393        let result_sub2 = eval_sub2.evaluate(&[]);
1394        assert_eq!(result_sub2[0], Complex::new(2.0, -1.0));
1395
1396        // Test (expr) multiplication
1397        let expr_mul2 = &expr_add * &expr_mul;
1398        let model_mul2 = manager.model(&expr_mul2);
1399        let eval_mul2 = model_mul2.load(&dataset);
1400        let result_mul2 = eval_mul2.evaluate(&[]);
1401        assert_eq!(result_mul2[0], Complex::new(-2.0, 4.0));
1402
1403        // Test (expr) division
1404        let expr_div2 = &expr_add / &expr_add2;
1405        let model_div2 = manager.model(&expr_div2);
1406        let eval_div2 = model_div2.load(&dataset);
1407        let result_div2 = eval_div2.evaluate(&[]);
1408        assert_eq!(result_div2[0], Complex::new(7.0 / 13.0, -4.0 / 13.0));
1409
1410        // Test (expr) neg
1411        let expr_neg2 = -&expr_mul2;
1412        let model_neg2 = manager.model(&expr_neg2);
1413        let eval_neg2 = model_neg2.load(&dataset);
1414        let result_neg2 = eval_neg2.evaluate(&[]);
1415        assert_eq!(result_neg2[0], Complex::new(2.0, -4.0));
1416
1417        // Test (amp) real
1418        let expr_real = aid3.real();
1419        let model_real = manager.model(&expr_real);
1420        let eval_real = model_real.load(&dataset);
1421        let result_real = eval_real.evaluate(&[]);
1422        assert_eq!(result_real[0], Complex::new(3.0, 0.0));
1423
1424        // Test (expr) real
1425        let expr_mul2_real = expr_mul2.real();
1426        let model_mul2_real = manager.model(&expr_mul2_real);
1427        let eval_mul2_real = model_mul2_real.load(&dataset);
1428        let result_mul2_real = eval_mul2_real.evaluate(&[]);
1429        assert_eq!(result_mul2_real[0], Complex::new(-2.0, 0.0));
1430
1431        // Test (amp) imag
1432        let expr_imag = aid3.imag();
1433        let model_imag = manager.model(&expr_imag);
1434        let eval_imag = model_imag.load(&dataset);
1435        let result_imag = eval_imag.evaluate(&[]);
1436        assert_eq!(result_imag[0], Complex::new(4.0, 0.0));
1437
1438        // Test (expr) imag
1439        let expr_mul2_imag = expr_mul2.imag();
1440        let model_mul2_imag = manager.model(&expr_mul2_imag);
1441        let eval_mul2_imag = model_mul2_imag.load(&dataset);
1442        let result_mul2_imag = eval_mul2_imag.evaluate(&[]);
1443        assert_eq!(result_mul2_imag[0], Complex::new(4.0, 0.0));
1444
1445        // Test (amp) conj
1446        let expr_conj = aid3.conj();
1447        let model_conj = manager.model(&expr_conj);
1448        let eval_conj = model_conj.load(&dataset);
1449        let result_conj = eval_conj.evaluate(&[]);
1450        assert_eq!(result_conj[0], Complex::new(3.0, -4.0));
1451
1452        // Test (expr) conj
1453        let expr_mul2_conj = expr_mul2.conj();
1454        let model_mul2_conj = manager.model(&expr_mul2_conj);
1455        let eval_mul2_conj = model_mul2_conj.load(&dataset);
1456        let result_mul2_conj = eval_mul2_conj.evaluate(&[]);
1457        assert_eq!(result_mul2_conj[0], Complex::new(-2.0, -4.0));
1458
1459        // Test (amp) norm_sqr
1460        let expr_norm = aid1.norm_sqr();
1461        let model_norm = manager.model(&expr_norm);
1462        let eval_norm = model_norm.load(&dataset);
1463        let result_norm = eval_norm.evaluate(&[]);
1464        assert_eq!(result_norm[0], Complex::new(4.0, 0.0));
1465
1466        // Test (expr) norm_sqr
1467        let expr_mul2_norm = expr_mul2.norm_sqr();
1468        let model_mul2_norm = manager.model(&expr_mul2_norm);
1469        let eval_mul2_norm = model_mul2_norm.load(&dataset);
1470        let result_mul2_norm = eval_mul2_norm.evaluate(&[]);
1471        assert_eq!(result_mul2_norm[0], Complex::new(20.0, 0.0));
1472    }
1473
1474    #[test]
1475    fn test_amplitude_activation() {
1476        let mut manager = Manager::default();
1477        let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
1478        let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));
1479
1480        let aid1 = manager.register(amp1).unwrap();
1481        let aid2 = manager.register(amp2).unwrap();
1482
1483        let dataset = Arc::new(test_dataset());
1484        let expr = &aid1 + &aid2;
1485        let model = manager.model(&expr);
1486        let evaluator = model.load(&dataset);
1487
1488        // Test initial state (all active)
1489        let result = evaluator.evaluate(&[]);
1490        assert_eq!(result[0], Complex::new(3.0, 0.0));
1491
1492        // Test deactivation
1493        evaluator.deactivate("const1").unwrap();
1494        let result = evaluator.evaluate(&[]);
1495        assert_eq!(result[0], Complex::new(2.0, 0.0));
1496
1497        // Test isolation
1498        evaluator.isolate("const1").unwrap();
1499        let result = evaluator.evaluate(&[]);
1500        assert_eq!(result[0], Complex::new(1.0, 0.0));
1501
1502        // Test reactivation
1503        evaluator.activate_all();
1504        let result = evaluator.evaluate(&[]);
1505        assert_eq!(result[0], Complex::new(3.0, 0.0));
1506    }
1507
1508    #[test]
1509    fn test_gradient() {
1510        let mut manager = Manager::default();
1511        let amp1 = ComplexScalar::new(
1512            "parametric_1",
1513            parameter("test_param_re_1"),
1514            parameter("test_param_im_1"),
1515        );
1516        let amp2 = ComplexScalar::new(
1517            "parametric_2",
1518            parameter("test_param_re_2"),
1519            parameter("test_param_im_2"),
1520        );
1521
1522        let aid1 = manager.register(amp1).unwrap();
1523        let aid2 = manager.register(amp2).unwrap();
1524        let dataset = Arc::new(test_dataset());
1525        let params = vec![2.0, 3.0, 4.0, 5.0];
1526
1527        let expr = &aid1 + &aid2;
1528        let model = manager.model(&expr);
1529        let evaluator = model.load(&dataset);
1530
1531        let gradient = evaluator.evaluate_gradient(&params);
1532
1533        assert_relative_eq!(gradient[0][0].re, 1.0);
1534        assert_relative_eq!(gradient[0][0].im, 0.0);
1535        assert_relative_eq!(gradient[0][1].re, 0.0);
1536        assert_relative_eq!(gradient[0][1].im, 1.0);
1537        assert_relative_eq!(gradient[0][2].re, 1.0);
1538        assert_relative_eq!(gradient[0][2].im, 0.0);
1539        assert_relative_eq!(gradient[0][3].re, 0.0);
1540        assert_relative_eq!(gradient[0][3].im, 1.0);
1541
1542        let expr = &aid1 - &aid2;
1543        let model = manager.model(&expr);
1544        let evaluator = model.load(&dataset);
1545
1546        let gradient = evaluator.evaluate_gradient(&params);
1547
1548        assert_relative_eq!(gradient[0][0].re, 1.0);
1549        assert_relative_eq!(gradient[0][0].im, 0.0);
1550        assert_relative_eq!(gradient[0][1].re, 0.0);
1551        assert_relative_eq!(gradient[0][1].im, 1.0);
1552        assert_relative_eq!(gradient[0][2].re, -1.0);
1553        assert_relative_eq!(gradient[0][2].im, 0.0);
1554        assert_relative_eq!(gradient[0][3].re, 0.0);
1555        assert_relative_eq!(gradient[0][3].im, -1.0);
1556
1557        let expr = &aid1 * &aid2;
1558        let model = manager.model(&expr);
1559        let evaluator = model.load(&dataset);
1560
1561        let gradient = evaluator.evaluate_gradient(&params);
1562
1563        assert_relative_eq!(gradient[0][0].re, 4.0);
1564        assert_relative_eq!(gradient[0][0].im, 5.0);
1565        assert_relative_eq!(gradient[0][1].re, -5.0);
1566        assert_relative_eq!(gradient[0][1].im, 4.0);
1567        assert_relative_eq!(gradient[0][2].re, 2.0);
1568        assert_relative_eq!(gradient[0][2].im, 3.0);
1569        assert_relative_eq!(gradient[0][3].re, -3.0);
1570        assert_relative_eq!(gradient[0][3].im, 2.0);
1571
1572        let expr = &aid1 / &aid2;
1573        let model = manager.model(&expr);
1574        let evaluator = model.load(&dataset);
1575
1576        let gradient = evaluator.evaluate_gradient(&params);
1577
1578        assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
1579        assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
1580        assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
1581        assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
1582        assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
1583        assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
1584        assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
1585        assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
1586
1587        let expr = -(&aid1 * &aid2);
1588        let model = manager.model(&expr);
1589        let evaluator = model.load(&dataset);
1590
1591        let gradient = evaluator.evaluate_gradient(&params);
1592
1593        assert_relative_eq!(gradient[0][0].re, -4.0);
1594        assert_relative_eq!(gradient[0][0].im, -5.0);
1595        assert_relative_eq!(gradient[0][1].re, 5.0);
1596        assert_relative_eq!(gradient[0][1].im, -4.0);
1597        assert_relative_eq!(gradient[0][2].re, -2.0);
1598        assert_relative_eq!(gradient[0][2].im, -3.0);
1599        assert_relative_eq!(gradient[0][3].re, 3.0);
1600        assert_relative_eq!(gradient[0][3].im, -2.0);
1601
1602        let expr = (&aid1 * &aid2).real();
1603        let model = manager.model(&expr);
1604        let evaluator = model.load(&dataset);
1605
1606        let gradient = evaluator.evaluate_gradient(&params);
1607
1608        assert_relative_eq!(gradient[0][0].re, 4.0);
1609        assert_relative_eq!(gradient[0][0].im, 0.0);
1610        assert_relative_eq!(gradient[0][1].re, -5.0);
1611        assert_relative_eq!(gradient[0][1].im, 0.0);
1612        assert_relative_eq!(gradient[0][2].re, 2.0);
1613        assert_relative_eq!(gradient[0][2].im, 0.0);
1614        assert_relative_eq!(gradient[0][3].re, -3.0);
1615        assert_relative_eq!(gradient[0][3].im, 0.0);
1616
1617        let expr = (&aid1 * &aid2).imag();
1618        let model = manager.model(&expr);
1619        let evaluator = model.load(&dataset);
1620
1621        let gradient = evaluator.evaluate_gradient(&params);
1622
1623        assert_relative_eq!(gradient[0][0].re, 5.0);
1624        assert_relative_eq!(gradient[0][0].im, 0.0);
1625        assert_relative_eq!(gradient[0][1].re, 4.0);
1626        assert_relative_eq!(gradient[0][1].im, 0.0);
1627        assert_relative_eq!(gradient[0][2].re, 3.0);
1628        assert_relative_eq!(gradient[0][2].im, 0.0);
1629        assert_relative_eq!(gradient[0][3].re, 2.0);
1630        assert_relative_eq!(gradient[0][3].im, 0.0);
1631
1632        let expr = (&aid1 * &aid2).conj();
1633        let model = manager.model(&expr);
1634        let evaluator = model.load(&dataset);
1635
1636        let gradient = evaluator.evaluate_gradient(&params);
1637
1638        assert_relative_eq!(gradient[0][0].re, 4.0);
1639        assert_relative_eq!(gradient[0][0].im, -5.0);
1640        assert_relative_eq!(gradient[0][1].re, -5.0);
1641        assert_relative_eq!(gradient[0][1].im, -4.0);
1642        assert_relative_eq!(gradient[0][2].re, 2.0);
1643        assert_relative_eq!(gradient[0][2].im, -3.0);
1644        assert_relative_eq!(gradient[0][3].re, -3.0);
1645        assert_relative_eq!(gradient[0][3].im, -2.0);
1646
1647        let expr = (&aid1 * &aid2).norm_sqr();
1648        let model = manager.model(&expr);
1649        let evaluator = model.load(&dataset);
1650
1651        let gradient = evaluator.evaluate_gradient(&params);
1652
1653        assert_relative_eq!(gradient[0][0].re, 164.0);
1654        assert_relative_eq!(gradient[0][0].im, 0.0);
1655        assert_relative_eq!(gradient[0][1].re, 246.0);
1656        assert_relative_eq!(gradient[0][1].im, 0.0);
1657        assert_relative_eq!(gradient[0][2].re, 104.0);
1658        assert_relative_eq!(gradient[0][2].im, 0.0);
1659        assert_relative_eq!(gradient[0][3].re, 130.0);
1660        assert_relative_eq!(gradient[0][3].im, 0.0);
1661    }
1662
1663    #[test]
1664    fn test_zeros_and_ones() {
1665        let mut manager = Manager::default();
1666        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1667        let aid = manager.register(amp).unwrap();
1668        let dataset = Arc::new(test_dataset());
1669        let expr = (aid * Expression::One + Expression::Zero).norm_sqr();
1670        let model = manager.model(&expr);
1671        let evaluator = model.load(&dataset);
1672
1673        let params = vec![2.0];
1674        let value = evaluator.evaluate(&params);
1675        let gradient = evaluator.evaluate_gradient(&params);
1676
1677        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the value should be x^2 + 4
1678        assert_relative_eq!(value[0].re, 8.0);
1679        assert_relative_eq!(value[0].im, 0.0);
1680
1681        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the derivative should be 2x
1682        assert_relative_eq!(gradient[0][0].re, 4.0);
1683        assert_relative_eq!(gradient[0][0].im, 0.0);
1684    }
1685
1686    #[test]
1687    fn test_parameter_registration() {
1688        let mut manager = Manager::default();
1689        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1690
1691        let aid = manager.register(amp).unwrap();
1692        let parameters = manager.parameters();
1693        let model = manager.model(&aid.into());
1694        let model_parameters = model.parameters();
1695        assert_eq!(parameters.len(), 1);
1696        assert_eq!(parameters[0], "test_param_re");
1697        assert_eq!(model_parameters.len(), 1);
1698        assert_eq!(model_parameters[0], "test_param_re");
1699    }
1700
1701    #[test]
1702    fn test_duplicate_amplitude_registration() {
1703        let mut manager = Manager::default();
1704        let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
1705        let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
1706        manager.register(amp1).unwrap();
1707        assert!(manager.register(amp2).is_err());
1708    }
1709
1710    #[test]
1711    fn test_tree_printing() {
1712        let mut manager = Manager::default();
1713        let amp1 = ComplexScalar::new(
1714            "parametric_1",
1715            parameter("test_param_re_1"),
1716            parameter("test_param_im_1"),
1717        );
1718        let amp2 = ComplexScalar::new(
1719            "parametric_2",
1720            parameter("test_param_re_2"),
1721            parameter("test_param_im_2"),
1722        );
1723        let aid1 = manager.register(amp1).unwrap();
1724        let aid2 = manager.register(amp2).unwrap();
1725        let expr = &aid1.real() + &aid2.conj().imag() + Expression::One * -Expression::Zero
1726            - Expression::Zero / Expression::One
1727            + (&aid1 * &aid2).norm_sqr();
1728        assert_eq!(
1729            expr.to_string(),
1730            "+
1731├─ -
1732│  ├─ +
1733│  │  ├─ +
1734│  │  │  ├─ Re
1735│  │  │  │  └─ parametric_1(id=0)
1736│  │  │  └─ Im
1737│  │  │     └─ *
1738│  │  │        └─ parametric_2(id=1)
1739│  │  └─ ×
1740│  │     ├─ 1
1741│  │     └─ -
1742│  │        └─ 0
1743│  └─ ÷
1744│     ├─ 0
1745│     └─ 1
1746└─ NormSqr
1747   └─ ×
1748      ├─ parametric_1(id=0)
1749      └─ parametric_2(id=1)
1750"
1751        );
1752    }
1753}