laddu_core/
amplitudes.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17    data::{Dataset, Event},
18    resources::{Cache, Parameters, Resources},
19    Float, LadduError,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
28/// An enum containing either a named free parameter or a constant value.
29#[derive(Clone, Default, Serialize, Deserialize)]
30pub enum ParameterLike {
31    /// A named free parameter.
32    Parameter(String),
33    /// A constant value.
34    Constant(Float),
35    /// An uninitialized parameter-like structure (typically used as the value given in an
36    /// [`Amplitude`] constructor before the [`Amplitude::register`] method is called).
37    #[default]
38    Uninit,
39}
40
41/// Shorthand for generating a named free parameter.
42pub fn parameter(name: &str) -> ParameterLike {
43    ParameterLike::Parameter(name.to_string())
44}
45
46/// Shorthand for generating a constant value (which acts like a fixed parameter).
47pub fn constant(value: Float) -> ParameterLike {
48    ParameterLike::Constant(value)
49}
50
51/// This is the only required trait for writing new amplitude-like structures for this
52/// crate. Users need only implement the [`register`](Amplitude::register)
53/// method to register parameters, cached values, and the amplitude itself with an input
54/// [`Resources`] struct and the [`compute`](Amplitude::compute) method to actually carry
55/// out the calculation. [`Amplitude`]-implementors are required to implement [`Clone`] and can
56/// optionally implement a [`precompute`](Amplitude::precompute) method to calculate and
57/// cache values which do not depend on free parameters.
58#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60    /// This method should be used to tell the [`Resources`] manager about all of
61    /// the free parameters and cached values used by this [`Amplitude`]. It should end by
62    /// returning an [`AmplitudeID`], which can be obtained from the
63    /// [`Resources::register_amplitude`] method.
64    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65    /// This method can be used to do some critical calculations ahead of time and
66    /// store them in a [`Cache`]. These values can only depend on the data in an [`Event`],
67    /// not on any free parameters in the fit. This method is opt-in since it is not required
68    /// to make a functioning [`Amplitude`].
69    #[allow(unused_variables)]
70    fn precompute(&self, event: &Event, cache: &mut Cache) {}
71    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
72    #[cfg(feature = "rayon")]
73    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74        dataset
75            .events
76            .par_iter()
77            .zip(resources.caches.par_iter_mut())
78            .for_each(|(event, cache)| {
79                self.precompute(event, cache);
80            })
81    }
82    /// Evaluates [`Amplitude::precompute`] over ever [`Event`] in a [`Dataset`].
83    #[cfg(not(feature = "rayon"))]
84    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85        dataset
86            .events
87            .iter()
88            .zip(resources.caches.iter_mut())
89            .for_each(|(event, cache)| self.precompute(event, cache))
90    }
91    /// This method constitutes the main machinery of an [`Amplitude`], returning the actual
92    /// calculated value for a particular [`Event`] and set of [`Parameters`]. See those
93    /// structs, as well as [`Cache`], for documentation on their available methods. For the
94    /// most part, [`Event`]s can be interacted with via
95    /// [`Variable`](crate::utils::variables::Variable)s, while [`Parameters`] and the
96    /// [`Cache`] are more like key-value storage accessed by
97    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
98    /// IDs.
99    fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101    /// This method yields the gradient of a particular [`Amplitude`] at a point specified
102    /// by a particular [`Event`] and set of [`Parameters`]. See those structs, as well as
103    /// [`Cache`], for documentation on their available methods. For the most part,
104    /// [`Event`]s can be interacted with via [`Variable`](crate::utils::variables::Variable)s,
105    /// while [`Parameters`] and the [`Cache`] are more like key-value storage accessed by
106    /// [`ParameterID`](crate::resources::ParameterID)s and several different types of cache
107    /// IDs. If the analytic version of the gradient is known, this method can be overwritten to
108    /// improve performance for some derivative-using methods of minimization. The default
109    /// implementation calculates a central finite difference across all parameters, regardless of
110    /// whether or not they are used in the [`Amplitude`].
111    ///
112    /// In the future, it may be possible to automatically implement this with the indices of
113    /// registered free parameters, but until then, the [`Amplitude::central_difference_with_indices`]
114    /// method can be used to conveniently only calculate central differences for the parameters
115    /// which are used by the [`Amplitude`].
116    fn compute_gradient(
117        &self,
118        parameters: &Parameters,
119        event: &Event,
120        cache: &Cache,
121        gradient: &mut DVector<Complex<Float>>,
122    ) {
123        self.central_difference_with_indices(
124            &Vec::from_iter(0..parameters.len()),
125            parameters,
126            event,
127            cache,
128            gradient,
129        )
130    }
131
132    /// A helper function to implement a central difference only on indices which correspond to
133    /// free parameters in the [`Amplitude`]. For example, if an [`Amplitude`] contains free
134    /// parameters registered to indices 1, 3, and 5 of the its internal parameters array, then
135    /// running this with those indices will compute a central finite difference derivative for
136    /// those coordinates only, since the rest can be safely assumed to be zero.
137    fn central_difference_with_indices(
138        &self,
139        indices: &[usize],
140        parameters: &Parameters,
141        event: &Event,
142        cache: &Cache,
143        gradient: &mut DVector<Complex<Float>>,
144    ) {
145        let x = parameters.parameters.to_owned();
146        let constants = parameters.constants.to_owned();
147        let h: DVector<Float> = x
148            .iter()
149            .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150            .collect::<Vec<_>>()
151            .into();
152        for i in indices {
153            let mut x_plus = x.clone();
154            let mut x_minus = x.clone();
155            x_plus[*i] += h[*i];
156            x_minus[*i] -= h[*i];
157            let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158            let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159            gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160        }
161    }
162}
163
164/// Utility function to calculate a central finite difference gradient.
165pub fn central_difference<F: Fn(&[Float]) -> Float>(
166    parameters: &[Float],
167    func: F,
168) -> DVector<Float> {
169    let mut gradient = DVector::zeros(parameters.len());
170    let x = parameters.to_owned();
171    let h: DVector<Float> = x
172        .iter()
173        .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174        .collect::<Vec<_>>()
175        .into();
176    for i in 0..parameters.len() {
177        let mut x_plus = x.clone();
178        let mut x_minus = x.clone();
179        x_plus[i] += h[i];
180        x_minus[i] -= h[i];
181        let f_plus = func(&x_plus);
182        let f_minus = func(&x_minus);
183        gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184    }
185    gradient
186}
187
188dyn_clone::clone_trait_object!(Amplitude);
189
190/// A helper struct that contains the value of each amplitude for a particular event
191#[derive(Debug)]
192pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
194/// A helper struct that contains the gradient of each amplitude for a particular event
195#[derive(Debug)]
196pub struct GradientValues(pub usize, pub Vec<DVector<Complex<Float>>>);
197
198/// A tag which refers to a registered [`Amplitude`]. This is the base object which can be used to
199/// build [`Expression`]s and should be obtained from the [`Manager::register`] method.
200#[derive(Clone, Default, Debug, Serialize, Deserialize)]
201pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        write!(f, "{}(id={})", self.0, self.1)
206    }
207}
208
209impl From<AmplitudeID> for Expression {
210    fn from(value: AmplitudeID) -> Self {
211        Self::Amp(value)
212    }
213}
214
215/// An expression tree which contains [`AmplitudeID`]s and operators over them.
216#[derive(Clone, Serialize, Deserialize, Default)]
217pub enum Expression {
218    #[default]
219    /// A expression equal to zero.
220    Zero,
221    /// A expression equal to one.
222    One,
223    /// A registered [`Amplitude`] referenced by an [`AmplitudeID`].
224    Amp(AmplitudeID),
225    /// The sum of two [`Expression`]s.
226    Add(Box<Expression>, Box<Expression>),
227    /// The difference of two [`Expression`]s.
228    Sub(Box<Expression>, Box<Expression>),
229    /// The product of two [`Expression`]s.
230    Mul(Box<Expression>, Box<Expression>),
231    /// The division of two [`Expression`]s.
232    Div(Box<Expression>, Box<Expression>),
233    /// The additive inverse of an [`Expression`].
234    Neg(Box<Expression>),
235    /// The real part of an [`Expression`].
236    Real(Box<Expression>),
237    /// The imaginary part of an [`Expression`].
238    Imag(Box<Expression>),
239    /// The complex conjugate of an [`Expression`].
240    Conj(Box<Expression>),
241    /// The absolute square of an [`Expression`].
242    NormSqr(Box<Expression>),
243}
244
245impl Debug for Expression {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        self.write_tree(f, "", "", "")
248    }
249}
250
251impl Display for Expression {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        write!(f, "{:?}", self)
254    }
255}
256
257#[rustfmt::skip]
258impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
259    Expression::Add(Box::new(a.clone()), Box::new(b.clone()))
260});
261#[rustfmt::skip]
262impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
263    Expression::Sub(Box::new(a.clone()), Box::new(b.clone()))
264});
265#[rustfmt::skip]
266impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
267    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
268});
269#[rustfmt::skip]
270impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
271    Expression::Div(Box::new(a.clone()), Box::new(b.clone()))
272});
273#[rustfmt::skip]
274impl_op_ex!(- |a: &Expression| -> Expression {
275    Expression::Neg(Box::new(a.clone()))
276});
277
278#[rustfmt::skip]
279impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression {
280    Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
281});
282#[rustfmt::skip]
283impl_op_ex_commutative!(- |a: &AmplitudeID, b: &Expression| -> Expression {
284    Expression::Sub(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
285});
286#[rustfmt::skip]
287impl_op_ex_commutative!(* |a: &AmplitudeID, b: &Expression| -> Expression {
288    Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
289});
290#[rustfmt::skip]
291impl_op_ex_commutative!(/ |a: &AmplitudeID, b: &Expression| -> Expression {
292    Expression::Div(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
293});
294
295#[rustfmt::skip]
296impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
297    Expression::Add(
298        Box::new(Expression::Amp(a.clone())),
299        Box::new(Expression::Amp(b.clone()))
300    )
301});
302#[rustfmt::skip]
303impl_op_ex!(- |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
304    Expression::Sub(
305        Box::new(Expression::Amp(a.clone())),
306        Box::new(Expression::Amp(b.clone()))
307    )
308});
309#[rustfmt::skip]
310impl_op_ex!(* |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
311    Expression::Mul(
312        Box::new(Expression::Amp(a.clone())),
313        Box::new(Expression::Amp(b.clone())),
314    )
315});
316#[rustfmt::skip]
317impl_op_ex!(/ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
318    Expression::Div(
319        Box::new(Expression::Amp(a.clone())),
320        Box::new(Expression::Amp(b.clone())),
321    )
322});
323#[rustfmt::skip]
324impl_op_ex!(- |a: &AmplitudeID| -> Expression {
325    Expression::Neg(
326        Box::new(Expression::Amp(a.clone())),
327    )
328});
329
330impl AmplitudeID {
331    /// Takes the real part of the given [`Amplitude`].
332    pub fn real(&self) -> Expression {
333        Expression::Real(Box::new(Expression::Amp(self.clone())))
334    }
335    /// Takes the imaginary part of the given [`Amplitude`].
336    pub fn imag(&self) -> Expression {
337        Expression::Imag(Box::new(Expression::Amp(self.clone())))
338    }
339    /// Takes the complex conjugate of the given [`Amplitude`].
340    pub fn conj(&self) -> Expression {
341        Expression::Conj(Box::new(Expression::Amp(self.clone())))
342    }
343    /// Takes the absolute square of the given [`Amplitude`].
344    pub fn norm_sqr(&self) -> Expression {
345        Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
346    }
347}
348
349impl Expression {
350    /// Evaluate an [`Expression`] over a single event using calculated [`AmplitudeValues`]
351    ///
352    /// This method parses the underlying [`Expression`] but doesn't actually calculate the values
353    /// from the [`Amplitude`]s themselves.
354    pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
355        match self {
356            Expression::Amp(aid) => amplitude_values.0[aid.1],
357            Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
358            Expression::Sub(a, b) => a.evaluate(amplitude_values) - b.evaluate(amplitude_values),
359            Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
360            Expression::Div(a, b) => a.evaluate(amplitude_values) / b.evaluate(amplitude_values),
361            Expression::Neg(a) => -a.evaluate(amplitude_values),
362            Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
363            Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
364            Expression::Conj(a) => a.evaluate(amplitude_values).conj(),
365            Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
366            Expression::Zero => Complex::ZERO,
367            Expression::One => Complex::ONE,
368        }
369    }
370    /// Evaluate the gradient of an [`Expression`] over a single event using calculated [`AmplitudeValues`]
371    ///
372    /// This method parses the underlying [`Expression`] but doesn't actually calculate the
373    /// gradient from the [`Amplitude`]s themselves.
374    pub fn evaluate_gradient(
375        &self,
376        amplitude_values: &AmplitudeValues,
377        gradient_values: &GradientValues,
378    ) -> DVector<Complex<Float>> {
379        match self {
380            Expression::Amp(aid) => gradient_values.1[aid.1].clone(),
381            Expression::Add(a, b) => {
382                a.evaluate_gradient(amplitude_values, gradient_values)
383                    + b.evaluate_gradient(amplitude_values, gradient_values)
384            }
385            Expression::Sub(a, b) => {
386                a.evaluate_gradient(amplitude_values, gradient_values)
387                    - b.evaluate_gradient(amplitude_values, gradient_values)
388            }
389            Expression::Mul(a, b) => {
390                let f_a = a.evaluate(amplitude_values);
391                let f_b = b.evaluate(amplitude_values);
392                b.evaluate_gradient(amplitude_values, gradient_values)
393                    .map(|g| g * f_a)
394                    + a.evaluate_gradient(amplitude_values, gradient_values)
395                        .map(|g| g * f_b)
396            }
397            Expression::Div(a, b) => {
398                let f_a = a.evaluate(amplitude_values);
399                let f_b = b.evaluate(amplitude_values);
400                (a.evaluate_gradient(amplitude_values, gradient_values)
401                    .map(|g| g * f_b)
402                    - b.evaluate_gradient(amplitude_values, gradient_values)
403                        .map(|g| g * f_a))
404                    / (f_b * f_b)
405            }
406            Expression::Neg(a) => -a.evaluate_gradient(amplitude_values, gradient_values),
407            Expression::Real(a) => a
408                .evaluate_gradient(amplitude_values, gradient_values)
409                .map(|g| Complex::new(g.re, 0.0)),
410            Expression::Imag(a) => a
411                .evaluate_gradient(amplitude_values, gradient_values)
412                .map(|g| Complex::new(g.im, 0.0)),
413            Expression::Conj(a) => a
414                .evaluate_gradient(amplitude_values, gradient_values)
415                .map(|g| g.conj()),
416            Expression::NormSqr(a) => {
417                let conj_f_a = a.evaluate(amplitude_values).conjugate();
418                a.evaluate_gradient(amplitude_values, gradient_values)
419                    .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
420            }
421            Expression::Zero | Expression::One => DVector::zeros(gradient_values.0),
422        }
423    }
424    /// Takes the real part of the given [`Expression`].
425    pub fn real(&self) -> Self {
426        Self::Real(Box::new(self.clone()))
427    }
428    /// Takes the imaginary part of the given [`Expression`].
429    pub fn imag(&self) -> Self {
430        Self::Imag(Box::new(self.clone()))
431    }
432    /// Takes the complex conjugate of the given [`Expression`].
433    pub fn conj(&self) -> Self {
434        Self::Conj(Box::new(self.clone()))
435    }
436    /// Takes the absolute square of the given [`Expression`].
437    pub fn norm_sqr(&self) -> Self {
438        Self::NormSqr(Box::new(self.clone()))
439    }
440
441    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
442    fn write_tree(
443        &self,
444        f: &mut std::fmt::Formatter<'_>,
445        parent_prefix: &str,
446        immediate_prefix: &str,
447        parent_suffix: &str,
448    ) -> std::fmt::Result {
449        let display_string = match self {
450            Self::Amp(aid) => aid.to_string(),
451            Self::Add(_, _) => "+".to_string(),
452            Self::Sub(_, _) => "-".to_string(),
453            Self::Mul(_, _) => "×".to_string(),
454            Self::Div(_, _) => "÷".to_string(),
455            Self::Neg(_) => "-".to_string(),
456            Self::Real(_) => "Re".to_string(),
457            Self::Imag(_) => "Im".to_string(),
458            Self::Conj(_) => "*".to_string(),
459            Self::NormSqr(_) => "NormSqr".to_string(),
460            Self::Zero => "0".to_string(),
461            Self::One => "1".to_string(),
462        };
463        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
464        match self {
465            Self::Amp(_) | Self::Zero | Self::One => {}
466            Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) | Self::Div(a, b) => {
467                let terms = [a, b];
468                let mut it = terms.iter().peekable();
469                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
470                while let Some(child) = it.next() {
471                    match it.peek() {
472                        Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│  "),
473                        None => child.write_tree(f, &child_prefix, "└─ ", "   "),
474                    }?;
475                }
476            }
477            Self::Neg(a) | Self::Real(a) | Self::Imag(a) | Self::Conj(a) | Self::NormSqr(a) => {
478                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
479                a.write_tree(f, &child_prefix, "└─ ", "   ")?;
480            }
481        }
482        Ok(())
483    }
484}
485
486/// A manager which can be used to register [`Amplitude`]s with [`Resources`]. This structure is
487/// essential to any analysis and should be constructed using the [`Manager::default()`] method.
488#[derive(Default, Clone, Serialize, Deserialize)]
489pub struct Manager {
490    amplitudes: Vec<Box<dyn Amplitude>>,
491    resources: Resources,
492}
493
494impl Manager {
495    /// Get the list of parameter names in the order they appear in the [`Manager`]'s [`Resources`] field.
496    pub fn parameters(&self) -> Vec<String> {
497        self.resources.parameters.iter().cloned().collect()
498    }
499    /// Register the given [`Amplitude`] and return an [`AmplitudeID`] that can be used to build
500    /// [`Expression`]s.
501    ///
502    /// # Errors
503    ///
504    /// The [`Amplitude`]'s name must be unique and not already
505    /// registered, else this will return a [`RegistrationError`][LadduError::RegistrationError].
506    pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
507        let mut amp = amplitude.clone();
508        let aid = amp.register(&mut self.resources)?;
509        self.amplitudes.push(amp);
510        Ok(aid)
511    }
512    /// Turns an [`Expression`] made from registered [`Amplitude`]s into a [`Model`].
513    pub fn model(&self, expression: &Expression) -> Model {
514        Model {
515            manager: self.clone(),
516            expression: expression.clone(),
517        }
518    }
519}
520
521/// A struct which contains a set of registerd [`Amplitude`]s (inside a [`Manager`])
522/// and an [`Expression`].
523///
524/// This struct implements [`serde::Serialize`] and [`serde::Deserialize`] and is intended
525/// to be used to store models to disk.
526#[derive(Clone, Serialize, Deserialize)]
527pub struct Model {
528    pub(crate) manager: Manager,
529    pub(crate) expression: Expression,
530}
531
532impl Model {
533    /// Get the list of parameter names in the order they appear in the [`Model`]'s [`Manager`] field.
534    pub fn parameters(&self) -> Vec<String> {
535        self.manager.parameters()
536    }
537    /// Create an [`Evaluator`] which can compute the result of the internal [`Expression`] built on
538    /// registered [`Amplitude`]s over the given [`Dataset`]. This method precomputes any relevant
539    /// information over the [`Event`]s in the [`Dataset`].
540    pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
541        let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
542        loaded_resources.write().reserve_cache(dataset.n_events());
543        for amplitude in &self.manager.amplitudes {
544            amplitude.precompute_all(dataset, &mut loaded_resources.write());
545        }
546        Evaluator {
547            amplitudes: self.manager.amplitudes.clone(),
548            resources: loaded_resources.clone(),
549            dataset: dataset.clone(),
550            expression: self.expression.clone(),
551        }
552    }
553}
554
555/// A structure which can be used to evaluate the stored [`Expression`] built on registered
556/// [`Amplitude`]s. This contains a [`Resources`] struct which already contains cached values for
557/// precomputed [`Amplitude`]s and any relevant free parameters and constants.
558#[derive(Clone)]
559pub struct Evaluator {
560    /// A list of [`Amplitude`]s which were registered with the [`Manager`] used to create the
561    /// internal [`Expression`]. This includes but is not limited to those which are actually used
562    /// in the [`Expression`].
563    pub amplitudes: Vec<Box<dyn Amplitude>>,
564    /// The internal [`Resources`] where precalculated values are stored
565    pub resources: Arc<RwLock<Resources>>,
566    /// The internal [`Dataset`]
567    pub dataset: Arc<Dataset>,
568    /// The internal [`Expression`]
569    pub expression: Expression,
570}
571
572impl Evaluator {
573    /// Get the list of parameter names in the order they appear in the [`Evaluator::evaluate`]
574    /// method.
575    pub fn parameters(&self) -> Vec<String> {
576        self.resources.read().parameters.iter().cloned().collect()
577    }
578    /// Activate an [`Amplitude`] by name.
579    pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
580        self.resources.write().activate(name)
581    }
582    /// Activate several [`Amplitude`]s by name.
583    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
584        self.resources.write().activate_many(names)
585    }
586    /// Activate all registered [`Amplitude`]s.
587    pub fn activate_all(&self) {
588        self.resources.write().activate_all();
589    }
590    /// Dectivate an [`Amplitude`] by name.
591    pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
592        self.resources.write().deactivate(name)
593    }
594    /// Deactivate several [`Amplitude`]s by name.
595    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
596        self.resources.write().deactivate_many(names)
597    }
598    /// Deactivate all registered [`Amplitude`]s.
599    pub fn deactivate_all(&self) {
600        self.resources.write().deactivate_all();
601    }
602    /// Isolate an [`Amplitude`] by name (deactivate the rest).
603    pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
604        self.resources.write().isolate(name)
605    }
606    /// Isolate several [`Amplitude`]s by name (deactivate the rest).
607    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
608        self.resources.write().isolate_many(names)
609    }
610
611    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
612    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
613    ///
614    /// # Notes
615    ///
616    /// This method is not intended to be called in analyses but rather in writing methods
617    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
618    pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
619        let resources = self.resources.read();
620        let parameters = Parameters::new(parameters, &resources.constants);
621        #[cfg(feature = "rayon")]
622        {
623            let amplitude_values_vec: Vec<AmplitudeValues> = self
624                .dataset
625                .events
626                .par_iter()
627                .zip(resources.caches.par_iter())
628                .map(|(event, cache)| {
629                    AmplitudeValues(
630                        self.amplitudes
631                            .iter()
632                            .zip(resources.active.iter())
633                            .map(|(amp, active)| {
634                                if *active {
635                                    amp.compute(&parameters, event, cache)
636                                } else {
637                                    Complex::new(0.0, 0.0)
638                                }
639                            })
640                            .collect(),
641                    )
642                })
643                .collect();
644            amplitude_values_vec
645                .par_iter()
646                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
647                .collect()
648        }
649        #[cfg(not(feature = "rayon"))]
650        {
651            let amplitude_values_vec: Vec<AmplitudeValues> = self
652                .dataset
653                .events
654                .iter()
655                .zip(resources.caches.iter())
656                .map(|(event, cache)| {
657                    AmplitudeValues(
658                        self.amplitudes
659                            .iter()
660                            .zip(resources.active.iter())
661                            .map(|(amp, active)| {
662                                if *active {
663                                    amp.compute(&parameters, event, cache)
664                                } else {
665                                    Complex::new(0.0, 0.0)
666                                }
667                            })
668                            .collect(),
669                    )
670                })
671                .collect();
672            amplitude_values_vec
673                .iter()
674                .map(|amplitude_values| self.expression.evaluate(amplitude_values))
675                .collect()
676        }
677    }
678
679    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
680    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
681    ///
682    /// # Notes
683    ///
684    /// This method is not intended to be called in analyses but rather in writing methods
685    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
686    #[cfg(feature = "mpi")]
687    fn evaluate_mpi(
688        &self,
689        parameters: &[Float],
690        world: &SimpleCommunicator,
691    ) -> Vec<Complex<Float>> {
692        let local_evaluation = self.evaluate_local(parameters);
693        let n_events = self.dataset.n_events();
694        let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
695        let (counts, displs) = world.get_counts_displs(n_events);
696        {
697            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
698            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
699        }
700        buffer
701    }
702
703    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
704    /// [`Evaluator`] with the given values for free parameters.
705    pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
706        #[cfg(feature = "mpi")]
707        {
708            if let Some(world) = crate::mpi::get_world() {
709                return self.evaluate_mpi(parameters, &world);
710            }
711        }
712        self.evaluate_local(parameters)
713    }
714
715    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
716    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
717    ///
718    /// # Notes
719    ///
720    /// This method is not intended to be called in analyses but rather in writing methods
721    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
722    pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
723        let resources = self.resources.read();
724        let parameters = Parameters::new(parameters, &resources.constants);
725        #[cfg(feature = "rayon")]
726        {
727            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
728                .dataset
729                .events
730                .par_iter()
731                .zip(resources.caches.par_iter())
732                .map(|(event, cache)| {
733                    let mut gradient_values =
734                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
735                    self.amplitudes
736                        .iter()
737                        .zip(resources.active.iter())
738                        .zip(gradient_values.iter_mut())
739                        .for_each(|((amp, active), grad)| {
740                            if *active {
741                                amp.compute_gradient(&parameters, event, cache, grad)
742                            }
743                        });
744                    (
745                        AmplitudeValues(
746                            self.amplitudes
747                                .iter()
748                                .zip(resources.active.iter())
749                                .map(|(amp, active)| {
750                                    if *active {
751                                        amp.compute(&parameters, event, cache)
752                                    } else {
753                                        Complex::new(0.0, 0.0)
754                                    }
755                                })
756                                .collect(),
757                        ),
758                        GradientValues(parameters.len(), gradient_values),
759                    )
760                })
761                .collect();
762            amplitude_values_and_gradient_vec
763                .par_iter()
764                .map(|(amplitude_values, gradient_values)| {
765                    self.expression
766                        .evaluate_gradient(amplitude_values, gradient_values)
767                })
768                .collect()
769        }
770        #[cfg(not(feature = "rayon"))]
771        {
772            let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
773                .dataset
774                .events
775                .iter()
776                .zip(resources.caches.iter())
777                .map(|(event, cache)| {
778                    let mut gradient_values =
779                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
780                    self.amplitudes
781                        .iter()
782                        .zip(resources.active.iter())
783                        .zip(gradient_values.iter_mut())
784                        .for_each(|((amp, active), grad)| {
785                            if *active {
786                                amp.compute_gradient(&parameters, event, cache, grad)
787                            }
788                        });
789                    (
790                        AmplitudeValues(
791                            self.amplitudes
792                                .iter()
793                                .zip(resources.active.iter())
794                                .map(|(amp, active)| {
795                                    if *active {
796                                        amp.compute(&parameters, event, cache)
797                                    } else {
798                                        Complex::new(0.0, 0.0)
799                                    }
800                                })
801                                .collect(),
802                        ),
803                        GradientValues(parameters.len(), gradient_values),
804                    )
805                })
806                .collect();
807
808            amplitude_values_and_gradient_vec
809                .iter()
810                .map(|(amplitude_values, gradient_values)| {
811                    self.expression
812                        .evaluate_gradient(amplitude_values, gradient_values)
813                })
814                .collect()
815        }
816    }
817
818    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
819    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
820    ///
821    /// # Notes
822    ///
823    /// This method is not intended to be called in analyses but rather in writing methods
824    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
825    #[cfg(feature = "mpi")]
826    fn evaluate_gradient_mpi(
827        &self,
828        parameters: &[Float],
829        world: &SimpleCommunicator,
830    ) -> Vec<DVector<Complex<Float>>> {
831        let flattened_local_evaluation = self
832            .evaluate_gradient_local(parameters)
833            .iter()
834            .flat_map(|g| g.data.as_vec().to_vec())
835            .collect::<Vec<Complex<Float>>>();
836        let n_events = self.dataset.n_events();
837        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
838        let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
839        let mut partitioned_flattened_result_buffer =
840            PartitionMut::new(&mut flattened_result_buffer, counts, displs);
841        world.all_gather_varcount_into(
842            &flattened_local_evaluation,
843            &mut partitioned_flattened_result_buffer,
844        );
845        flattened_result_buffer
846            .chunks(parameters.len())
847            .map(DVector::from_row_slice)
848            .collect()
849    }
850
851    /// Evaluate gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
852    /// [`Evaluator`] with the given values for free parameters.
853    pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
854        #[cfg(feature = "mpi")]
855        {
856            if let Some(world) = crate::mpi::get_world() {
857                return self.evaluate_gradient_mpi(parameters, &world);
858            }
859        }
860        self.evaluate_gradient_local(parameters)
861    }
862}
863
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        data::{test_dataset, test_event, Event},
        resources::{Cache, ParameterID, Parameters, Resources},
        Complex, DVector, Float, LadduError,
    };
    use approx::assert_relative_eq;
    use serde::{Deserialize, Serialize};

    /// Test amplitude whose value is `re + i·im`, where each component is either a free
    /// parameter or a constant.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        pid_re: ParameterID,
        im: ParameterLike,
        pid_im: ParameterID,
    }

    impl ComplexScalar {
        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
            Box::new(Self {
                name: name.to_owned(),
                re,
                pid_re: ParameterID::default(),
                im,
                pid_im: ParameterID::default(),
            })
        }
    }

    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
            self.pid_re = resources.register_parameter(&self.re);
            self.pid_im = resources.register_parameter(&self.im);
            resources.register_amplitude(&self.name)
        }

        fn compute(
            &self,
            parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
        ) -> Complex<Float> {
            let real_part = parameters.get(self.pid_re);
            let imag_part = parameters.get(self.pid_im);
            Complex::new(real_part, imag_part)
        }

        fn compute_gradient(
            &self,
            _parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
            gradient: &mut DVector<Complex<Float>>,
        ) {
            // d(re + i·im)/d(re) = 1 and d(re + i·im)/d(im) = i; constants get no entry.
            if let ParameterID::Parameter(index) = self.pid_re {
                gradient[index] = Complex::ONE;
            }
            if let ParameterID::Parameter(index) = self.pid_im {
                gradient[index] = Complex::I;
            }
        }
    }

    /// Compares each expected `(re, im)` pair against the gradient component at the same index.
    fn assert_gradient(gradient: &DVector<Complex<Float>>, expected: &[(Float, Float)]) {
        for (index, &(re, im)) in expected.iter().enumerate() {
            assert_relative_eq!(gradient[index].re, re);
            assert_relative_eq!(gradient[index].im, im);
        }
    }

    #[test]
    fn test_constant_amplitude() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new("constant", constant(2.0), constant(3.0)))
            .unwrap();
        let dataset = Arc::new(Dataset {
            events: vec![Arc::new(test_event())],
        });
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let values = model.load(&dataset).evaluate(&[]);
        assert_eq!(values[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_parametric_amplitude() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new(
                "parametric",
                parameter("test_param_re"),
                parameter("test_param_im"),
            ))
            .unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let values = model.load(&dataset).evaluate(&[2.0, 3.0]);
        assert_eq!(values[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_expression_operations() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new("const1", constant(2.0), constant(0.0)))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new("const2", constant(0.0), constant(1.0)))
            .unwrap();
        let aid3 = manager
            .register(ComplexScalar::new("const3", constant(3.0), constant(4.0)))
            .unwrap();
        let dataset = Arc::new(test_dataset());

        // const1 = 2, const2 = i, const3 = 3 + 4i
        let sum = &aid1 + &aid2; // 2 + i
        let product = &aid1 * &aid2; // 2i
        let cases: Vec<(Expression, Complex<Float>)> = vec![
            // amplitude-level binary and unary operators
            (&aid1 + &aid2, Complex::new(2.0, 1.0)),
            (&aid1 - &aid2, Complex::new(2.0, -1.0)),
            (&aid1 * &aid2, Complex::new(0.0, 2.0)),
            (&aid1 / &aid3, Complex::new(6.0 / 25.0, -8.0 / 25.0)),
            (-&aid3, Complex::new(-3.0, -4.0)),
            // expression-level binary and unary operators
            (&sum + &product, Complex::new(2.0, 3.0)),
            (&sum - &product, Complex::new(2.0, -1.0)),
            (&sum * &product, Complex::new(-2.0, 4.0)),
            (&sum / &(&sum + &product), Complex::new(7.0 / 13.0, -4.0 / 13.0)),
            (-&(&sum * &product), Complex::new(2.0, -4.0)),
            // real / imag / conj / norm_sqr on amplitudes and expressions
            (aid3.real(), Complex::new(3.0, 0.0)),
            ((&sum * &product).real(), Complex::new(-2.0, 0.0)),
            (aid3.imag(), Complex::new(4.0, 0.0)),
            ((&sum * &product).imag(), Complex::new(4.0, 0.0)),
            (aid3.conj(), Complex::new(3.0, -4.0)),
            ((&sum * &product).conj(), Complex::new(-2.0, -4.0)),
            (aid1.norm_sqr(), Complex::new(4.0, 0.0)),
            ((&sum * &product).norm_sqr(), Complex::new(20.0, 0.0)),
        ];

        for (expr, expected) in &cases {
            let model = manager.model(expr);
            let values = model.load(&dataset).evaluate(&[]);
            assert_eq!(values[0], *expected);
        }
    }

    #[test]
    fn test_amplitude_activation() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new("const1", constant(1.0), constant(0.0)))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new("const2", constant(2.0), constant(0.0)))
            .unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = &aid1 + &aid2;
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let first_value = || evaluator.evaluate(&[])[0];

        // Every amplitude starts out active.
        assert_eq!(first_value(), Complex::new(3.0, 0.0));

        // Deactivating const1 leaves only const2.
        evaluator.deactivate("const1").unwrap();
        assert_eq!(first_value(), Complex::new(2.0, 0.0));

        // Isolating const1 leaves only const1.
        evaluator.isolate("const1").unwrap();
        assert_eq!(first_value(), Complex::new(1.0, 0.0));

        // Reactivating everything restores the full sum.
        evaluator.activate_all();
        assert_eq!(first_value(), Complex::new(3.0, 0.0));
    }

    #[test]
    fn test_gradient() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new(
                "parametric_1",
                parameter("test_param_re_1"),
                parameter("test_param_im_1"),
            ))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new(
                "parametric_2",
                parameter("test_param_re_2"),
                parameter("test_param_im_2"),
            ))
            .unwrap();
        let dataset = Arc::new(test_dataset());
        // f1 = 2 + 3i, f2 = 4 + 5i
        let params: [Float; 4] = [2.0, 3.0, 4.0, 5.0];

        let cases: Vec<(Expression, [(Float, Float); 4])> = vec![
            (
                &aid1 + &aid2,
                [(1.0, 0.0), (0.0, 1.0), (1.0, 0.0), (0.0, 1.0)],
            ),
            (
                &aid1 - &aid2,
                [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0), (0.0, -1.0)],
            ),
            (
                &aid1 * &aid2,
                [(4.0, 5.0), (-5.0, 4.0), (2.0, 3.0), (-3.0, 2.0)],
            ),
            (
                &aid1 / &aid2,
                [
                    (4.0 / 41.0, -5.0 / 41.0),
                    (5.0 / 41.0, 4.0 / 41.0),
                    (-102.0 / 1681.0, 107.0 / 1681.0),
                    (-107.0 / 1681.0, -102.0 / 1681.0),
                ],
            ),
            (
                -(&aid1 * &aid2),
                [(-4.0, -5.0), (5.0, -4.0), (-2.0, -3.0), (3.0, -2.0)],
            ),
            (
                (&aid1 * &aid2).real(),
                [(4.0, 0.0), (-5.0, 0.0), (2.0, 0.0), (-3.0, 0.0)],
            ),
            (
                (&aid1 * &aid2).imag(),
                [(5.0, 0.0), (4.0, 0.0), (3.0, 0.0), (2.0, 0.0)],
            ),
            (
                (&aid1 * &aid2).conj(),
                [(4.0, -5.0), (-5.0, -4.0), (2.0, -3.0), (-3.0, -2.0)],
            ),
            (
                (&aid1 * &aid2).norm_sqr(),
                [(164.0, 0.0), (246.0, 0.0), (104.0, 0.0), (130.0, 0.0)],
            ),
        ];

        for (expr, expected) in &cases {
            let model = manager.model(expr);
            let gradient = model.load(&dataset).evaluate_gradient(&params);
            assert_gradient(&gradient[0], expected);
        }
    }

    #[test]
    fn test_zeros_and_ones() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new(
                "parametric",
                parameter("test_param_re"),
                constant(2.0),
            ))
            .unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = (aid * Expression::One + Expression::Zero).norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params: [Float; 1] = [2.0];
        let value = evaluator.evaluate(&params);
        let gradient = evaluator.evaluate_gradient(&params);

        // |f(x) * 1 + 0|^2 with f(x) = x + 2i equals x^2 + 4, i.e. 8 at x = 2.
        assert_relative_eq!(value[0].re, 8.0);
        assert_relative_eq!(value[0].im, 0.0);

        // Its derivative is 2x, i.e. 4 at x = 2.
        assert_gradient(&gradient[0], &[(4.0, 0.0)]);
    }

    #[test]
    fn test_parameter_registration() {
        let mut manager = Manager::default();
        let aid = manager
            .register(ComplexScalar::new(
                "parametric",
                parameter("test_param_re"),
                constant(2.0),
            ))
            .unwrap();
        let manager_parameters = manager.parameters();
        let model = manager.model(&aid.into());
        let model_parameters = model.parameters();
        // Only the free real part is a parameter; the constant imaginary part is not.
        assert_eq!(manager_parameters, ["test_param_re"]);
        assert_eq!(model_parameters, ["test_param_re"]);
    }

    #[test]
    fn test_duplicate_amplitude_registration() {
        let mut manager = Manager::default();
        let first = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
        let second = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
        manager
            .register(first)
            .expect("the first amplitude with a given name must register");
        assert!(manager.register(second).is_err());
    }

    #[test]
    fn test_tree_printing() {
        let mut manager = Manager::default();
        let aid1 = manager
            .register(ComplexScalar::new(
                "parametric_1",
                parameter("test_param_re_1"),
                parameter("test_param_im_1"),
            ))
            .unwrap();
        let aid2 = manager
            .register(ComplexScalar::new(
                "parametric_2",
                parameter("test_param_re_2"),
                parameter("test_param_im_2"),
            ))
            .unwrap();
        let expr = &aid1.real() + &aid2.conj().imag() + Expression::One * -Expression::Zero
            - Expression::Zero / Expression::One
            + (&aid1 * &aid2).norm_sqr();
        // The trailing empty entry yields the final newline of the rendered tree.
        let expected = [
            "+",
            "├─ -",
            "│  ├─ +",
            "│  │  ├─ +",
            "│  │  │  ├─ Re",
            "│  │  │  │  └─ parametric_1(id=0)",
            "│  │  │  └─ Im",
            "│  │  │     └─ *",
            "│  │  │        └─ parametric_2(id=1)",
            "│  │  └─ ×",
            "│  │     ├─ 1",
            "│  │     └─ -",
            "│  │        └─ 0",
            "│  └─ ÷",
            "│     ├─ 0",
            "│     └─ 1",
            "└─ NormSqr",
            "   └─ ×",
            "      ├─ parametric_1(id=0)",
            "      └─ parametric_2(id=1)",
            "",
        ]
        .join("\n");
        assert_eq!(expr.to_string(), expected);
    }
}