laddu_core/
amplitudes.rs

1use std::{
2    fmt::{Debug, Display},
3    sync::{
4        atomic::{AtomicU64, Ordering},
5        Arc,
6    },
7};
8
9use auto_ops::*;
10use dyn_clone::DynClone;
11use nalgebra::{ComplexField, DVector};
12use num::complex::Complex64;
13
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
/// Process-wide counter backing [`next_amplitude_id`].
///
/// Every amplitude wrapped into an expression registry receives the next value, which lets two
/// registries tell apart different amplitudes that happen to share a name.
static AMPLITUDE_INSTANCE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Hand out a fresh instance identifier, distinct from every earlier value returned in this
/// process (until the 64-bit counter wraps around).
fn next_amplitude_id() -> u64 {
    AtomicU64::fetch_add(&AMPLITUDE_INSTANCE_COUNTER, 1, Ordering::Relaxed)
}
24
25use crate::{
26    data::{Dataset, DatasetMetadata, EventData},
27    resources::{Cache, ParameterTransform, Parameters, Resources},
28    LadduError, LadduResult, ParameterID, ReadWrite,
29};
30
31#[cfg(feature = "mpi")]
32use crate::mpi::LadduMPI;
33
34#[cfg(feature = "mpi")]
35use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
36
37/// An enum containing either a named free parameter or a constant value.
38#[derive(Clone, Default, Serialize, Deserialize)]
39pub struct Parameter {
40    /// The name of the parameter.
41    pub name: String,
42    /// If `Some`, this parameter is fixed to the given value. If `None`, it is free.
43    pub fixed: Option<f64>,
44}
45
46impl Parameter {
47    /// Create a free (floating) parameter with the given name.
48    pub fn free(name: impl Into<String>) -> Self {
49        Self {
50            name: name.into(),
51            fixed: None,
52        }
53    }
54
55    /// Create a fixed parameter with the given name and value.
56    pub fn fixed(name: impl Into<String>, value: f64) -> Self {
57        Self {
58            name: name.into(),
59            fixed: Some(value),
60        }
61    }
62
63    /// An uninitialized parameter placeholder.
64    pub fn uninit() -> Self {
65        Self {
66            name: String::new(),
67            fixed: None,
68        }
69    }
70
71    /// Is this parameter free?
72    pub fn is_free(&self) -> bool {
73        self.fixed.is_none()
74    }
75
76    /// Is this parameter fixed?
77    pub fn is_fixed(&self) -> bool {
78        self.fixed.is_some()
79    }
80
81    /// Get the parameter name.
82    pub fn name(&self) -> &str {
83        &self.name
84    }
85}
86
87/// Maintains naming used across the crate.
88pub type ParameterLike = Parameter;
89
90/// Shorthand for generating a named free parameter.
91pub fn parameter(name: &str) -> Parameter {
92    Parameter::free(name)
93}
94
95/// Shorthand for generating a fixed parameter with the given name and value.
96pub fn constant(name: &str, value: f64) -> Parameter {
97    Parameter::fixed(name, value)
98}
99
/// Convenience macro for creating parameters.
///
/// - `parameter!("name")` creates a free parameter (same as `Parameter::free("name")`).
/// - `parameter!("name", 1.0)` creates a parameter fixed to `1.0`
///   (same as `Parameter::fixed("name", 1.0)`).
#[macro_export]
macro_rules! parameter {
    ($name:expr) => {
        $crate::amplitudes::Parameter::free($name)
    };
    ($name:expr, $value:expr) => {
        $crate::amplitudes::Parameter::fixed($name, $value)
    };
}
111
112/// This is the only required trait for writing new amplitude-like structures for this
113/// crate. Users need only implement the [`register`](Amplitude::register)
114/// method to register parameters, cached values, and the amplitude itself with an input
115/// [`Resources`] struct and the [`compute`](Amplitude::compute) method to actually carry
116/// out the calculation. [`Amplitude`]-implementors are required to implement [`Clone`] and can
117/// optionally implement a [`precompute`](Amplitude::precompute) method to calculate and
118/// cache values which do not depend on free parameters.
119#[typetag::serde(tag = "type")]
120pub trait Amplitude: DynClone + Send + Sync {
121    /// This method should be used to tell the [`Resources`] manager about all of
122    /// the free parameters and cached values used by this [`Amplitude`]. It should end by
123    /// returning an [`AmplitudeID`], which can be obtained from the
124    /// [`Resources::register_amplitude`] method.
125    ///
126    /// [`register`](Amplitude::register) is invoked once when an amplitude is first added to a
127    /// [`Manager`]. Use it to allocate parameter/cache state within [`Resources`] without assuming
128    /// any dataset context.
129    fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID>;
130    /// Bind this [`Amplitude`] to a concrete [`Dataset`] by using the provided metadata to wire up
131    /// [`Variable`](crate::utils::variables::Variable)s or other dataset-specific state. This will
132    /// be invoked when a [`Model`] is loaded with data, after [`register`](Amplitude::register)
133    /// has already succeeded. The default implementation is a no-op for amplitudes that do not
134    /// depend on metadata.
135    fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
136        Ok(())
137    }
138    /// This method can be used to do some critical calculations ahead of time and
139    /// store them in a [`Cache`]. These values can only depend on the data in an [`EventData`],
140    /// not on any free parameters in the fit. This method is opt-in since it is not required
141    /// to make a functioning [`Amplitude`].
142    #[allow(unused_variables)]
143    fn precompute(&self, event: &EventData, cache: &mut Cache) {}
144    /// Evaluates [`Amplitude::precompute`] over ever [`EventData`] in a [`Dataset`].
145    #[cfg(feature = "rayon")]
146    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
147        dataset
148            .events
149            .par_iter()
150            .zip(resources.caches.par_iter_mut())
151            .for_each(|(event, cache)| {
152                self.precompute(event, cache);
153            })
154    }
155    /// Evaluates [`Amplitude::precompute`] over ever [`EventData`] in a [`Dataset`].
156    #[cfg(not(feature = "rayon"))]
157    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
158        dataset
159            .events
160            .iter()
161            .zip(resources.caches.iter_mut())
162            .for_each(|(event, cache)| self.precompute(event, cache))
163    }
164    /// This method constitutes the main machinery of an [`Amplitude`], returning the actual
165    /// calculated value for a particular [`EventData`] and set of [`Parameters`]. See those
166    /// structs, as well as [`Cache`], for documentation on their available methods. For the
167    /// most part, [`EventData`]s can be interacted with via
168    /// [`Variable`](crate::utils::variables::Variable)s, while [`Parameters`] and the
169    /// [`Cache`] are more like key-value storage accessed by
170    /// [`ParameterID`]s and several different types of cache
171    /// IDs.
172    fn compute(&self, parameters: &Parameters, event: &EventData, cache: &Cache) -> Complex64;
173
174    /// This method yields the gradient of a particular [`Amplitude`] at a point specified
175    /// by a particular [`EventData`] and set of [`Parameters`]. See those structs, as well as
176    /// [`Cache`], for documentation on their available methods. For the most part,
177    /// [`EventData`]s can be interacted with via [`Variable`](crate::utils::variables::Variable)s,
178    /// while [`Parameters`] and the [`Cache`] are more like key-value storage accessed by
179    /// [`ParameterID`]s and several different types of cache
180    /// IDs. If the analytic version of the gradient is known, this method can be overwritten to
181    /// improve performance for some derivative-using methods of minimization. The default
182    /// implementation calculates a central finite difference across all parameters, regardless of
183    /// whether or not they are used in the [`Amplitude`].
184    ///
185    /// In the future, it may be possible to automatically implement this with the indices of
186    /// registered free parameters, but until then, the [`Amplitude::central_difference_with_indices`]
187    /// method can be used to conveniently only calculate central differences for the parameters
188    /// which are used by the [`Amplitude`].
189    fn compute_gradient(
190        &self,
191        parameters: &Parameters,
192        event: &EventData,
193        cache: &Cache,
194        gradient: &mut DVector<Complex64>,
195    ) {
196        self.central_difference_with_indices(
197            &Vec::from_iter(0..parameters.len()),
198            parameters,
199            event,
200            cache,
201            gradient,
202        )
203    }
204
205    /// A helper function to implement a central difference only on indices which correspond to
206    /// free parameters in the [`Amplitude`]. For example, if an [`Amplitude`] contains free
207    /// parameters registered to indices 1, 3, and 5 of the its internal parameters array, then
208    /// running this with those indices will compute a central finite difference derivative for
209    /// those coordinates only, since the rest can be safely assumed to be zero.
210    fn central_difference_with_indices(
211        &self,
212        indices: &[usize],
213        parameters: &Parameters,
214        event: &EventData,
215        cache: &Cache,
216        gradient: &mut DVector<Complex64>,
217    ) {
218        let x = parameters.parameters.to_owned();
219        let constants = parameters.constants.to_owned();
220        let h: DVector<f64> = x
221            .iter()
222            .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
223            .collect::<Vec<_>>()
224            .into();
225        for i in indices {
226            let mut x_plus = x.clone();
227            let mut x_minus = x.clone();
228            x_plus[*i] += h[*i];
229            x_minus[*i] -= h[*i];
230            let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
231            let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
232            gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
233        }
234    }
235
236    /// Convenience helper to wrap an amplitude into an [`Expression`].
237    ///
238    /// This allows amplitude constructors to return `LadduResult<Expression>` without duplicating
239    /// boxing/registration boilerplate.
240    fn into_expression(self) -> LadduResult<Expression>
241    where
242        Self: Sized + 'static,
243    {
244        Expression::from_amplitude(Box::new(self))
245    }
246}
247
248/// Utility function to calculate a central finite difference gradient.
249pub fn central_difference<F: Fn(&[f64]) -> f64>(parameters: &[f64], func: F) -> DVector<f64> {
250    let mut gradient = DVector::zeros(parameters.len());
251    let x = parameters.to_owned();
252    let h: DVector<f64> = x
253        .iter()
254        .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
255        .collect::<Vec<_>>()
256        .into();
257    for i in 0..parameters.len() {
258        let mut x_plus = x.clone();
259        let mut x_minus = x.clone();
260        x_plus[i] += h[i];
261        x_minus[i] -= h[i];
262        let f_plus = func(&x_plus);
263        let f_minus = func(&x_minus);
264        gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
265    }
266    gradient
267}
268
// Implements `Clone` for `Box<dyn Amplitude>` via `dyn_clone`, which is what lets
// `ExpressionRegistry` derive `Clone` while holding `Vec<Box<dyn Amplitude>>`.
dyn_clone::clone_trait_object!(Amplitude);
270
271/// A helper struct that contains the value of each amplitude for a particular event
272#[derive(Debug)]
273pub struct AmplitudeValues(pub Vec<Complex64>);
274
275/// A helper struct that contains the gradient of each amplitude for a particular event
276#[derive(Debug)]
277pub struct GradientValues(pub usize, pub Vec<DVector<Complex64>>);
278
279/// A tag which refers to a registered [`Amplitude`]. This is the base object which can be used to
280/// build [`Expression`]s and should be obtained from the [`Resources::register`] method.
281#[derive(Clone, Default, Debug, Serialize, Deserialize)]
282pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
283
284impl Display for AmplitudeID {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        write!(f, "{}(id={})", self.0, self.1)
287    }
288}
289
290/// A holder struct that owns both an expression tree and the registered amplitudes.
291#[allow(missing_docs)]
292#[derive(Clone, Serialize, Deserialize)]
293pub struct Expression {
294    registry: ExpressionRegistry,
295    tree: ExpressionNode,
296}
297
298impl ReadWrite for Expression {
299    fn create_null() -> Self {
300        Self {
301            registry: ExpressionRegistry::default(),
302            tree: ExpressionNode::default(),
303        }
304    }
305}
306
307#[derive(Clone, Serialize, Deserialize)]
308#[allow(missing_docs)]
309#[derive(Default)]
310pub struct ExpressionRegistry {
311    amplitudes: Vec<Box<dyn Amplitude>>,
312    amplitude_names: Vec<String>,
313    amplitude_ids: Vec<u64>,
314    resources: Resources,
315}
316
317impl ExpressionRegistry {
318    fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
319        let mut resources = Resources::default();
320        let aid = amplitude.register(&mut resources)?;
321        let amp_id = next_amplitude_id();
322        Ok(Self {
323            amplitudes: vec![amplitude],
324            amplitude_names: vec![aid.0],
325            amplitude_ids: vec![amp_id],
326            resources,
327        })
328    }
329
330    fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
331        let mut resources = Resources::default();
332        let mut amplitudes = Vec::new();
333        let mut amplitude_names = Vec::new();
334        let mut amplitude_ids = Vec::new();
335        let mut name_to_index = std::collections::HashMap::new();
336
337        let mut left_map = Vec::with_capacity(self.amplitudes.len());
338        for ((amp, name), amp_id) in self
339            .amplitudes
340            .iter()
341            .zip(&self.amplitude_names)
342            .zip(&self.amplitude_ids)
343        {
344            let mut cloned_amp = dyn_clone::clone_box(&**amp);
345            let aid = cloned_amp.register(&mut resources)?;
346            amplitudes.push(cloned_amp);
347            amplitude_names.push(name.clone());
348            amplitude_ids.push(*amp_id);
349            name_to_index.insert(name.clone(), aid.1);
350            left_map.push(aid.1);
351        }
352
353        let mut right_map = Vec::with_capacity(other.amplitudes.len());
354        for ((amp, name), amp_id) in other
355            .amplitudes
356            .iter()
357            .zip(&other.amplitude_names)
358            .zip(&other.amplitude_ids)
359        {
360            if let Some(existing) = name_to_index.get(name) {
361                let existing_amp_id = amplitude_ids[*existing];
362                if existing_amp_id != *amp_id {
363                    return Err(LadduError::Custom(format!(
364                        "Amplitude name \"{name}\" refers to different underlying amplitudes; rename to avoid conflicts"
365                    )));
366                }
367                right_map.push(*existing);
368                continue;
369            }
370            let mut cloned_amp = dyn_clone::clone_box(&**amp);
371            let aid = cloned_amp.register(&mut resources)?;
372            amplitudes.push(cloned_amp);
373            amplitude_names.push(name.clone());
374            amplitude_ids.push(*amp_id);
375            name_to_index.insert(name.clone(), aid.1);
376            right_map.push(aid.1);
377        }
378
379        Ok((
380            Self {
381                amplitudes,
382                amplitude_names,
383                amplitude_ids,
384                resources,
385            },
386            left_map,
387            right_map,
388        ))
389    }
390
391    fn rebuild_with_transform(&self, transform: ParameterTransform) -> LadduResult<Self> {
392        let mut resources = Resources::with_transform(transform);
393        let mut amplitudes = Vec::new();
394        let mut amplitude_names = Vec::new();
395        let mut amplitude_ids = Vec::new();
396        for ((amp, name), amp_id) in self
397            .amplitudes
398            .iter()
399            .zip(&self.amplitude_names)
400            .zip(&self.amplitude_ids)
401        {
402            let mut cloned_amp = dyn_clone::clone_box(&**amp);
403            let aid = cloned_amp.register(&mut resources)?;
404            if aid.0 != *name {
405                return Err(LadduError::ParameterConflict {
406                    name: aid.0,
407                    reason: "amplitude renamed during rebuild".to_string(),
408                });
409            }
410            amplitudes.push(cloned_amp);
411            amplitude_names.push(name.clone());
412            amplitude_ids.push(*amp_id);
413        }
414        Ok(Self {
415            amplitudes,
416            amplitude_names,
417            amplitude_ids,
418            resources,
419        })
420    }
421}
422
423/// Expression tree used by [`Expression`].
424#[allow(missing_docs)]
425#[derive(Clone, Serialize, Deserialize, Default, Debug)]
426pub enum ExpressionNode {
427    #[default]
428    /// A expression equal to zero.
429    Zero,
430    /// A expression equal to one.
431    One,
432    /// A registered [`Amplitude`] referenced by index.
433    Amp(usize),
434    /// The sum of two [`ExpressionNode`]s.
435    Add(Box<ExpressionNode>, Box<ExpressionNode>),
436    /// The difference of two [`ExpressionNode`]s.
437    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
438    /// The product of two [`ExpressionNode`]s.
439    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
440    /// The division of two [`ExpressionNode`]s.
441    Div(Box<ExpressionNode>, Box<ExpressionNode>),
442    /// The additive inverse of an [`ExpressionNode`].
443    Neg(Box<ExpressionNode>),
444    /// The real part of an [`ExpressionNode`].
445    Real(Box<ExpressionNode>),
446    /// The imaginary part of an [`ExpressionNode`].
447    Imag(Box<ExpressionNode>),
448    /// The complex conjugate of an [`ExpressionNode`].
449    Conj(Box<ExpressionNode>),
450    /// The absolute square of an [`ExpressionNode`].
451    NormSqr(Box<ExpressionNode>),
452}
453
454impl ExpressionNode {
455    fn remap(&self, mapping: &[usize]) -> Self {
456        match self {
457            Self::Amp(idx) => Self::Amp(mapping[*idx]),
458            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
459            Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
460            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
461            Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
462            Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
463            Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
464            Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
465            Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
466            Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
467            Self::Zero => Self::Zero,
468            Self::One => Self::One,
469        }
470    }
471
472    /// Evaluate an [`ExpressionNode`] over a single event using calculated [`AmplitudeValues`]
473    ///
474    /// This method parses the underlying [`ExpressionNode`] but doesn't actually calculate the values
475    /// from the [`Amplitude`]s themselves.
476    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
477        match self {
478            ExpressionNode::Amp(idx) => amplitude_values[*idx],
479            ExpressionNode::Add(a, b) => {
480                a.evaluate(amplitude_values) + b.evaluate(amplitude_values)
481            }
482            ExpressionNode::Sub(a, b) => {
483                a.evaluate(amplitude_values) - b.evaluate(amplitude_values)
484            }
485            ExpressionNode::Mul(a, b) => {
486                a.evaluate(amplitude_values) * b.evaluate(amplitude_values)
487            }
488            ExpressionNode::Div(a, b) => {
489                a.evaluate(amplitude_values) / b.evaluate(amplitude_values)
490            }
491            ExpressionNode::Neg(a) => -a.evaluate(amplitude_values),
492            ExpressionNode::Real(a) => Complex64::new(a.evaluate(amplitude_values).re, 0.0),
493            ExpressionNode::Imag(a) => Complex64::new(a.evaluate(amplitude_values).im, 0.0),
494            ExpressionNode::Conj(a) => a.evaluate(amplitude_values).conj(),
495            ExpressionNode::NormSqr(a) => {
496                let value = a.evaluate(amplitude_values);
497                Complex64::new(value.norm_sqr(), 0.0)
498            }
499            ExpressionNode::Zero => Complex64::ZERO,
500            ExpressionNode::One => Complex64::ONE,
501        }
502    }
503
504    /// Evaluate the gradient of an [`ExpressionNode`] over a single event using calculated [`AmplitudeValues`]
505    ///
506    /// This method parses the underlying [`ExpressionNode`] but doesn't actually calculate the
507    /// gradient from the [`Amplitude`]s themselves.
508    pub fn evaluate_gradient(
509        &self,
510        amplitude_values: &[Complex64],
511        gradient_values: &[DVector<Complex64>],
512    ) -> DVector<Complex64> {
513        match self {
514            ExpressionNode::Amp(idx) => gradient_values[*idx].clone(),
515            ExpressionNode::Add(a, b) => {
516                a.evaluate_gradient(amplitude_values, gradient_values)
517                    + b.evaluate_gradient(amplitude_values, gradient_values)
518            }
519            ExpressionNode::Sub(a, b) => {
520                a.evaluate_gradient(amplitude_values, gradient_values)
521                    - b.evaluate_gradient(amplitude_values, gradient_values)
522            }
523            ExpressionNode::Mul(a, b) => {
524                let f_a = a.evaluate(amplitude_values);
525                let f_b = b.evaluate(amplitude_values);
526                b.evaluate_gradient(amplitude_values, gradient_values)
527                    .map(|g| g * f_a)
528                    + a.evaluate_gradient(amplitude_values, gradient_values)
529                        .map(|g| g * f_b)
530            }
531            ExpressionNode::Div(a, b) => {
532                let f_a = a.evaluate(amplitude_values);
533                let f_b = b.evaluate(amplitude_values);
534                (a.evaluate_gradient(amplitude_values, gradient_values)
535                    .map(|g| g * f_b)
536                    - b.evaluate_gradient(amplitude_values, gradient_values)
537                        .map(|g| g * f_a))
538                    / (f_b * f_b)
539            }
540            ExpressionNode::Neg(a) => -a.evaluate_gradient(amplitude_values, gradient_values),
541            ExpressionNode::Real(a) => a
542                .evaluate_gradient(amplitude_values, gradient_values)
543                .map(|g| Complex64::new(g.re, 0.0)),
544            ExpressionNode::Imag(a) => a
545                .evaluate_gradient(amplitude_values, gradient_values)
546                .map(|g| Complex64::new(g.im, 0.0)),
547            ExpressionNode::Conj(a) => a
548                .evaluate_gradient(amplitude_values, gradient_values)
549                .map(|g| g.conj()),
550            ExpressionNode::NormSqr(a) => {
551                let conj_f_a = a.evaluate(amplitude_values).conjugate();
552                a.evaluate_gradient(amplitude_values, gradient_values)
553                    .map(|g| Complex64::new(2.0 * (g * conj_f_a).re, 0.0))
554            }
555            ExpressionNode::Zero | ExpressionNode::One => {
556                let max_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
557                DVector::zeros(max_dim)
558            }
559        }
560    }
561}
562
563impl Expression {
564    /// Build an [`Expression`] from a single [`Amplitude`].
565    pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
566        let registry = ExpressionRegistry::singleton(amplitude)?;
567        Ok(Self {
568            tree: ExpressionNode::Amp(0),
569            registry,
570        })
571    }
572
573    /// Create an expression representing zero, the additive identity.
574    pub fn zero() -> Self {
575        Self {
576            registry: ExpressionRegistry::default(),
577            tree: ExpressionNode::Zero,
578        }
579    }
580
581    /// Create an expression representing one, the multiplicative identity.
582    pub fn one() -> Self {
583        Self {
584            registry: ExpressionRegistry::default(),
585            tree: ExpressionNode::One,
586        }
587    }
588
589    fn binary_op(
590        a: &Expression,
591        b: &Expression,
592        build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
593    ) -> Expression {
594        let (registry, left_map, right_map) = a
595            .registry
596            .merge(&b.registry)
597            .expect("merging expression registries should not fail");
598        let left_tree = a.tree.remap(&left_map);
599        let right_tree = b.tree.remap(&right_map);
600        Expression {
601            registry,
602            tree: build(Box::new(left_tree), Box::new(right_tree)),
603        }
604    }
605
606    fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
607        Expression {
608            registry: a.registry.clone(),
609            tree: build(Box::new(a.tree.clone())),
610        }
611    }
612
613    /// Get the list of parameter names in the order they appear in the underlying resources.
614    pub fn parameters(&self) -> Vec<String> {
615        self.registry.resources.parameter_names()
616    }
617
618    /// Get the list of free parameter names.
619    pub fn free_parameters(&self) -> Vec<String> {
620        self.registry.resources.free_parameter_names()
621    }
622
623    /// Get the list of fixed parameter names.
624    pub fn fixed_parameters(&self) -> Vec<String> {
625        self.registry.resources.fixed_parameter_names()
626    }
627
628    /// Number of free parameters.
629    pub fn n_free(&self) -> usize {
630        self.registry.resources.n_free_parameters()
631    }
632
633    /// Number of fixed parameters.
634    pub fn n_fixed(&self) -> usize {
635        self.registry.resources.n_fixed_parameters()
636    }
637
638    /// Total number of parameters.
639    pub fn n_parameters(&self) -> usize {
640        self.registry.resources.n_parameters()
641    }
642
643    fn with_transform(&self, transform: ParameterTransform) -> LadduResult<Self> {
644        let merged = self
645            .registry
646            .resources
647            .parameter_overrides
648            .merged(&transform);
649        let registry = self.registry.rebuild_with_transform(merged)?;
650        Ok(Self {
651            registry,
652            tree: self.tree.clone(),
653        })
654    }
655
656    fn assert_parameter_exists(&self, name: &str) -> LadduResult<()> {
657        if self.parameters().iter().any(|p| p == name) {
658            Ok(())
659        } else {
660            Err(LadduError::UnregisteredParameter {
661                name: name.to_string(),
662                reason: "parameter not found".to_string(),
663            })
664        }
665    }
666
667    /// Return a new [`Expression`] with the given parameter fixed to a value.
668    pub fn fix(&self, name: &str, value: f64) -> LadduResult<Self> {
669        self.assert_parameter_exists(name)?;
670        let mut transform = ParameterTransform::default();
671        transform.fixed.insert(name.to_string(), value);
672        self.with_transform(transform)
673    }
674
675    /// Return a new [`Expression`] with the given parameter freed.
676    pub fn free(&self, name: &str) -> LadduResult<Self> {
677        self.assert_parameter_exists(name)?;
678        let mut transform = ParameterTransform::default();
679        transform.freed.insert(name.to_string());
680        self.with_transform(transform)
681    }
682
683    /// Return a new [`Expression`] with a single parameter renamed.
684    pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<Self> {
685        self.assert_parameter_exists(old)?;
686        if old == new {
687            return Ok(self.clone());
688        }
689        if self.parameters().iter().any(|p| p == new) {
690            return Err(LadduError::ParameterConflict {
691                name: new.to_string(),
692                reason: "rename target already exists".to_string(),
693            });
694        }
695        let mut transform = ParameterTransform::default();
696        transform.renames.insert(old.to_string(), new.to_string());
697        self.with_transform(transform)
698    }
699
700    /// Return a new [`Expression`] with several parameters renamed.
701    pub fn rename_parameters(
702        &self,
703        mapping: &std::collections::HashMap<String, String>,
704    ) -> LadduResult<Self> {
705        for old in mapping.keys() {
706            self.assert_parameter_exists(old)?;
707        }
708        let mut final_names: std::collections::HashSet<String> =
709            self.parameters().into_iter().collect();
710        for (old, new) in mapping {
711            if old == new {
712                continue;
713            }
714            final_names.remove(old);
715            if final_names.contains(new) {
716                return Err(LadduError::ParameterConflict {
717                    name: new.clone(),
718                    reason: "rename target already exists".to_string(),
719                });
720            }
721            final_names.insert(new.clone());
722        }
723        let mut transform = ParameterTransform::default();
724        for (old, new) in mapping {
725            transform.renames.insert(old.clone(), new.clone());
726        }
727        self.with_transform(transform)
728    }
729
730    /// Load an [`Expression`] against a dataset, binding amplitudes and reserving caches.
731    pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
732        let mut resources = self.registry.resources.clone();
733        let metadata = dataset.metadata();
734        resources.reserve_cache(dataset.n_events());
735        let mut amplitudes: Vec<Box<dyn Amplitude>> = self
736            .registry
737            .amplitudes
738            .iter()
739            .map(|amp| dyn_clone::clone_box(&**amp))
740            .collect();
741        {
742            for amplitude in amplitudes.iter_mut() {
743                amplitude.bind(metadata)?;
744                amplitude.precompute_all(dataset, &mut resources);
745            }
746        }
747        Ok(Evaluator {
748            amplitudes,
749            resources: Arc::new(RwLock::new(resources)),
750            dataset: dataset.clone(),
751            expression: self.tree.clone(),
752            registry: self.registry.clone(),
753        })
754    }
755
756    /// Takes the real part of the given [`Expression`].
757    pub fn real(&self) -> Self {
758        Self::unary_op(self, ExpressionNode::Real)
759    }
760    /// Takes the imaginary part of the given [`Expression`].
761    pub fn imag(&self) -> Self {
762        Self::unary_op(self, ExpressionNode::Imag)
763    }
764    /// Takes the complex conjugate of the given [`Expression`].
765    pub fn conj(&self) -> Self {
766        Self::unary_op(self, ExpressionNode::Conj)
767    }
768    /// Takes the absolute square of the given [`Expression`].
769    pub fn norm_sqr(&self) -> Self {
770        Self::unary_op(self, ExpressionNode::NormSqr)
771    }
772
773    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
774    fn write_tree(
775        &self,
776        t: &ExpressionNode,
777        f: &mut std::fmt::Formatter<'_>,
778        parent_prefix: &str,
779        immediate_prefix: &str,
780        parent_suffix: &str,
781    ) -> std::fmt::Result {
782        let display_string = match t {
783            ExpressionNode::Amp(idx) => {
784                let name = self
785                    .registry
786                    .amplitude_names
787                    .get(*idx)
788                    .cloned()
789                    .unwrap_or_else(|| "<unregistered>".to_string());
790                format!("{name}(id={idx})")
791            }
792            ExpressionNode::Add(_, _) => "+".to_string(),
793            ExpressionNode::Sub(_, _) => "-".to_string(),
794            ExpressionNode::Mul(_, _) => "×".to_string(),
795            ExpressionNode::Div(_, _) => "÷".to_string(),
796            ExpressionNode::Neg(_) => "-".to_string(),
797            ExpressionNode::Real(_) => "Re".to_string(),
798            ExpressionNode::Imag(_) => "Im".to_string(),
799            ExpressionNode::Conj(_) => "*".to_string(),
800            ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
801            ExpressionNode::Zero => "0".to_string(),
802            ExpressionNode::One => "1".to_string(),
803        };
804        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
805        match t {
806            ExpressionNode::Amp(_) | ExpressionNode::Zero | ExpressionNode::One => {}
807            ExpressionNode::Add(a, b)
808            | ExpressionNode::Sub(a, b)
809            | ExpressionNode::Mul(a, b)
810            | ExpressionNode::Div(a, b) => {
811                let terms = [a, b];
812                let mut it = terms.iter().peekable();
813                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
814                while let Some(child) = it.next() {
815                    match it.peek() {
816                        Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│  "),
817                        None => self.write_tree(child, f, &child_prefix, "└─ ", "   "),
818                    }?;
819                }
820            }
821            ExpressionNode::Neg(a)
822            | ExpressionNode::Real(a)
823            | ExpressionNode::Imag(a)
824            | ExpressionNode::Conj(a)
825            | ExpressionNode::NormSqr(a) => {
826                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
827                self.write_tree(a, f, &child_prefix, "└─ ", "   ")?;
828            }
829        }
830        Ok(())
831    }
832}
833
834impl Debug for Expression {
835    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836        self.write_tree(&self.tree, f, "", "", "")
837    }
838}
839
840impl Display for Expression {
841    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
842        self.write_tree(&self.tree, f, "", "", "")
843    }
844}
845
846#[rustfmt::skip]
847impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
848    Expression::binary_op(a, b, ExpressionNode::Add)
849});
850#[rustfmt::skip]
851impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
852    Expression::binary_op(a, b, ExpressionNode::Sub)
853});
854#[rustfmt::skip]
855impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
856    Expression::binary_op(a, b, ExpressionNode::Mul)
857});
858#[rustfmt::skip]
859impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
860    Expression::binary_op(a, b, ExpressionNode::Div)
861});
862#[rustfmt::skip]
863impl_op_ex!(- |a: &Expression| -> Expression {
864    Expression::unary_op(a, ExpressionNode::Neg)
865});
866
867/// Evaluator for [`Expression`] that mirrors the existing evaluator behavior.
868#[allow(missing_docs)]
869#[derive(Clone)]
870pub struct Evaluator {
871    pub amplitudes: Vec<Box<dyn Amplitude>>,
872    pub resources: Arc<RwLock<Resources>>,
873    pub dataset: Arc<Dataset>,
874    pub expression: ExpressionNode,
875    registry: ExpressionRegistry,
876}
877
878#[allow(missing_docs)]
879impl Evaluator {
880    /// Get the list of parameter names in the order they appear in the [`Evaluator::evaluate`]
881    /// method.
882    pub fn parameters(&self) -> Vec<String> {
883        self.resources.read().parameter_names()
884    }
885
886    /// Get the list of free parameter names.
887    pub fn free_parameters(&self) -> Vec<String> {
888        self.resources.read().free_parameter_names()
889    }
890
891    /// Get the list of fixed parameter names.
892    pub fn fixed_parameters(&self) -> Vec<String> {
893        self.resources.read().fixed_parameter_names()
894    }
895
896    /// Number of free parameters.
897    pub fn n_free(&self) -> usize {
898        self.resources.read().n_free_parameters()
899    }
900
901    /// Number of fixed parameters.
902    pub fn n_fixed(&self) -> usize {
903        self.resources.read().n_fixed_parameters()
904    }
905
906    /// Total number of parameters.
907    pub fn n_parameters(&self) -> usize {
908        self.resources.read().n_parameters()
909    }
910
911    fn as_expression(&self) -> Expression {
912        Expression {
913            registry: self.registry.clone(),
914            tree: self.expression.clone(),
915        }
916    }
917
918    /// Return a new [`Evaluator`] with the given parameter fixed to a value.
919    pub fn fix(&self, name: &str, value: f64) -> LadduResult<Self> {
920        self.as_expression().fix(name, value)?.load(&self.dataset)
921    }
922
923    /// Return a new [`Evaluator`] with the given parameter freed.
924    pub fn free(&self, name: &str) -> LadduResult<Self> {
925        self.as_expression().free(name)?.load(&self.dataset)
926    }
927
928    /// Return a new [`Evaluator`] with a single parameter renamed.
929    pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<Self> {
930        self.as_expression()
931            .rename_parameter(old, new)?
932            .load(&self.dataset)
933    }
934
935    /// Return a new [`Evaluator`] with several parameters renamed.
936    pub fn rename_parameters(
937        &self,
938        mapping: &std::collections::HashMap<String, String>,
939    ) -> LadduResult<Self> {
940        self.as_expression()
941            .rename_parameters(mapping)?
942            .load(&self.dataset)
943    }
944
945    /// Activate an [`Amplitude`] by name, skipping missing entries.
946    pub fn activate<T: AsRef<str>>(&self, name: T) {
947        self.resources.write().activate(name);
948    }
949    /// Activate an [`Amplitude`] by name and return an error if it is missing.
950    pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
951        self.resources.write().activate_strict(name)
952    }
953
954    /// Activate several [`Amplitude`]s by name, skipping missing entries.
955    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
956        self.resources.write().activate_many(names);
957    }
958    /// Activate several [`Amplitude`]s by name and return an error if any are missing.
959    pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
960        self.resources.write().activate_many_strict(names)
961    }
962
963    /// Activate all registered [`Amplitude`]s.
964    pub fn activate_all(&self) {
965        self.resources.write().activate_all();
966    }
967
968    /// Dectivate an [`Amplitude`] by name, skipping missing entries.
969    pub fn deactivate<T: AsRef<str>>(&self, name: T) {
970        self.resources.write().deactivate(name);
971    }
972
973    /// Dectivate an [`Amplitude`] by name and return an error if it is missing.
974    pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
975        self.resources.write().deactivate_strict(name)
976    }
977
978    /// Deactivate several [`Amplitude`]s by name, skipping missing entries.
979    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
980        self.resources.write().deactivate_many(names);
981    }
982    /// Dectivate several [`Amplitude`]s by name and return an error if any are missing.
983    pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
984        self.resources.write().deactivate_many_strict(names)
985    }
986
987    /// Deactivate all registered [`Amplitude`]s.
988    pub fn deactivate_all(&self) {
989        self.resources.write().deactivate_all();
990    }
991
992    /// Isolate an [`Amplitude`] by name (deactivate the rest), skipping missing entries.
993    pub fn isolate<T: AsRef<str>>(&self, name: T) {
994        self.resources.write().isolate(name);
995    }
996
997    /// Isolate an [`Amplitude`] by name (deactivate the rest) and return an error if it is missing.
998    pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
999        self.resources.write().isolate_strict(name)
1000    }
1001
1002    /// Isolate several [`Amplitude`]s by name (deactivate the rest), skipping missing entries.
1003    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
1004        self.resources.write().isolate_many(names);
1005    }
1006
1007    /// Isolate several [`Amplitude`]s by name (deactivate the rest) and return an error if any are missing.
1008    pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
1009        self.resources.write().isolate_many_strict(names)
1010    }
1011
1012    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
1013    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
1014    ///
1015    /// # Notes
1016    ///
1017    /// This method is not intended to be called in analyses but rather in writing methods
1018    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
1019    pub fn evaluate_local(&self, parameters: &[f64]) -> Vec<Complex64> {
1020        let resources = self.resources.read();
1021        let parameters = Parameters::new(parameters, &resources.constants);
1022        #[cfg(feature = "rayon")]
1023        {
1024            self.dataset
1025                .events
1026                .par_iter()
1027                .zip(resources.caches.par_iter())
1028                .map(|(event, cache)| {
1029                    let amplitude_values: Vec<Complex64> = self
1030                        .amplitudes
1031                        .iter()
1032                        .zip(resources.active.iter())
1033                        .map(|(amp, active)| {
1034                            if *active {
1035                                amp.compute(&parameters, event, cache)
1036                            } else {
1037                                Complex64::ZERO
1038                            }
1039                        })
1040                        .collect();
1041                    self.expression.evaluate(&amplitude_values)
1042                })
1043                .collect()
1044        }
1045        #[cfg(not(feature = "rayon"))]
1046        {
1047            self.dataset
1048                .events
1049                .iter()
1050                .zip(resources.caches.iter())
1051                .map(|(event, cache)| {
1052                    let amplitude_values: Vec<Complex64> = self
1053                        .amplitudes
1054                        .iter()
1055                        .zip(resources.active.iter())
1056                        .map(|(amp, active)| {
1057                            if *active {
1058                                amp.compute(&parameters, event, cache)
1059                            } else {
1060                                Complex64::ZERO
1061                            }
1062                        })
1063                        .collect();
1064                    self.expression.evaluate(&amplitude_values)
1065                })
1066                .collect()
1067        }
1068    }
1069
1070    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
1071    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
1072    ///
1073    /// # Notes
1074    ///
1075    /// This method is not intended to be called in analyses but rather in writing methods
1076    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
1077    #[cfg(feature = "mpi")]
1078    fn evaluate_mpi(&self, parameters: &[f64], world: &SimpleCommunicator) -> Vec<Complex64> {
1079        let local_evaluation = self.evaluate_local(parameters);
1080        let n_events = self.dataset.n_events();
1081        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
1082        let (counts, displs) = world.get_counts_displs(n_events);
1083        {
1084            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1085            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
1086        }
1087        buffer
1088    }
1089
1090    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
1091    /// [`Evaluator`] with the given values for free parameters.
1092    pub fn evaluate(&self, parameters: &[f64]) -> Vec<Complex64> {
1093        #[cfg(feature = "mpi")]
1094        {
1095            if let Some(world) = crate::mpi::get_world() {
1096                return self.evaluate_mpi(parameters, &world);
1097            }
1098        }
1099        self.evaluate_local(parameters)
1100    }
1101
1102    /// See [`Evaluator::evaluate_local`]. This method evaluates over a subset of events rather
1103    /// than all events in the total dataset.
1104    pub fn evaluate_batch_local(&self, parameters: &[f64], indices: &[usize]) -> Vec<Complex64> {
1105        let resources = self.resources.read();
1106        let parameters = Parameters::new(parameters, &resources.constants);
1107        #[cfg(feature = "rayon")]
1108        {
1109            indices
1110                .par_iter()
1111                .map(|&idx| {
1112                    let event = &self.dataset.events[idx];
1113                    let cache = &resources.caches[idx];
1114                    let amplitude_values: Vec<Complex64> = self
1115                        .amplitudes
1116                        .iter()
1117                        .zip(resources.active.iter())
1118                        .map(|(amp, active)| {
1119                            if *active {
1120                                amp.compute(&parameters, event, cache)
1121                            } else {
1122                                Complex64::ZERO
1123                            }
1124                        })
1125                        .collect();
1126                    self.expression.evaluate(&amplitude_values)
1127                })
1128                .collect()
1129        }
1130        #[cfg(not(feature = "rayon"))]
1131        {
1132            indices
1133                .iter()
1134                .map(|&idx| {
1135                    let event = &self.dataset.events[idx];
1136                    let cache = &resources.caches[idx];
1137                    let amplitude_values: Vec<Complex64> = self
1138                        .amplitudes
1139                        .iter()
1140                        .zip(resources.active.iter())
1141                        .map(|(amp, active)| {
1142                            if *active {
1143                                amp.compute(&parameters, event, cache)
1144                            } else {
1145                                Complex64::ZERO
1146                            }
1147                        })
1148                        .collect();
1149                    self.expression.evaluate(&amplitude_values)
1150                })
1151                .collect()
1152        }
1153    }
1154
1155    /// See [`Evaluator::evaluate_mpi`]. This method evaluates over a subset of events rather
1156    /// than all events in the total dataset.
1157    #[cfg(feature = "mpi")]
1158    fn evaluate_batch_mpi(
1159        &self,
1160        parameters: &[f64],
1161        indices: &[usize],
1162        world: &SimpleCommunicator,
1163    ) -> Vec<Complex64> {
1164        let total = self.dataset.n_events();
1165        let locals = world.locals_from_globals(indices, total);
1166        let local_evaluation = self.evaluate_batch_local(parameters, &locals);
1167        world.all_gather_batched_partitioned(&local_evaluation, indices, total, None)
1168    }
1169
1170    /// Evaluate the stored [`Expression`] over a subset of events in the [`Dataset`] stored by the
1171    /// [`Evaluator`] with the given values for free parameters. See also [`Expression::evaluate`].
1172    pub fn evaluate_batch(&self, parameters: &[f64], indices: &[usize]) -> Vec<Complex64> {
1173        #[cfg(feature = "mpi")]
1174        {
1175            if let Some(world) = crate::mpi::get_world() {
1176                return self.evaluate_batch_mpi(parameters, indices, &world);
1177            }
1178        }
1179        self.evaluate_batch_local(parameters, indices)
1180    }
1181
1182    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
1183    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
1184    ///
1185    /// # Notes
1186    ///
1187    /// This method is not intended to be called in analyses but rather in writing methods
1188    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
1189    pub fn evaluate_gradient_local(&self, parameters: &[f64]) -> Vec<DVector<Complex64>> {
1190        let resources = self.resources.read();
1191        let parameters = Parameters::new(parameters, &resources.constants);
1192        #[cfg(feature = "rayon")]
1193        {
1194            self.dataset
1195                .events
1196                .par_iter()
1197                .zip(resources.caches.par_iter())
1198                .map(|(event, cache)| {
1199                    let mut gradient_values =
1200                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1201                    self.amplitudes
1202                        .iter()
1203                        .zip(resources.active.iter())
1204                        .zip(gradient_values.iter_mut())
1205                        .for_each(|((amp, active), grad)| {
1206                            if *active {
1207                                amp.compute_gradient(&parameters, event, cache, grad)
1208                            }
1209                        });
1210                    let amplitude_values: Vec<Complex64> = self
1211                        .amplitudes
1212                        .iter()
1213                        .zip(resources.active.iter())
1214                        .map(|(amp, active)| {
1215                            if *active {
1216                                amp.compute(&parameters, event, cache)
1217                            } else {
1218                                Complex64::ZERO
1219                            }
1220                        })
1221                        .collect();
1222                    self.expression
1223                        .evaluate_gradient(&amplitude_values, &gradient_values)
1224                })
1225                .collect()
1226        }
1227        #[cfg(not(feature = "rayon"))]
1228        {
1229            self.dataset
1230                .events
1231                .iter()
1232                .zip(resources.caches.iter())
1233                .map(|(event, cache)| {
1234                    let mut gradient_values =
1235                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1236                    self.amplitudes
1237                        .iter()
1238                        .zip(resources.active.iter())
1239                        .zip(gradient_values.iter_mut())
1240                        .for_each(|((amp, active), grad)| {
1241                            if *active {
1242                                amp.compute_gradient(&parameters, event, cache, grad)
1243                            }
1244                        });
1245                    let amplitude_values: Vec<Complex64> = self
1246                        .amplitudes
1247                        .iter()
1248                        .zip(resources.active.iter())
1249                        .map(|(amp, active)| {
1250                            if *active {
1251                                amp.compute(&parameters, event, cache)
1252                            } else {
1253                                Complex64::ZERO
1254                            }
1255                        })
1256                        .collect();
1257
1258                    self.expression
1259                        .evaluate_gradient(&amplitude_values, &gradient_values)
1260                })
1261                .collect()
1262        }
1263    }
1264
1265    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
1266    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
1267    ///
1268    /// # Notes
1269    ///
1270    /// This method is not intended to be called in analyses but rather in writing methods
1271    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
1272    #[cfg(feature = "mpi")]
1273    fn evaluate_gradient_mpi(
1274        &self,
1275        parameters: &[f64],
1276        world: &SimpleCommunicator,
1277    ) -> Vec<DVector<Complex64>> {
1278        let local_evaluation = self.evaluate_gradient_local(parameters);
1279        let n_events = self.dataset.n_events();
1280        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
1281        let (counts, displs) = world.get_counts_displs(n_events);
1282        {
1283            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1284            world.all_gather_varcount_into(
1285                &local_evaluation
1286                    .iter()
1287                    .flat_map(|v| v.data.as_vec())
1288                    .copied()
1289                    .collect::<Vec<_>>(),
1290                &mut partitioned_buffer,
1291            );
1292        }
1293        buffer
1294            .chunks(parameters.len())
1295            .map(|chunk| DVector::from_row_slice(chunk))
1296            .collect()
1297    }
1298
1299    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
1300    /// [`Evaluator`] with the given values for free parameters.
1301    pub fn evaluate_gradient(&self, parameters: &[f64]) -> Vec<DVector<Complex64>> {
1302        #[cfg(feature = "mpi")]
1303        {
1304            if let Some(world) = crate::mpi::get_world() {
1305                return self.evaluate_gradient_mpi(parameters, &world);
1306            }
1307        }
1308        self.evaluate_gradient_local(parameters)
1309    }
1310
1311    /// See [`Evaluator::evaluate_gradient_local`]. This method evaluates over a subset
1312    /// of events rather than all events in the total dataset.
1313    pub fn evaluate_gradient_batch_local(
1314        &self,
1315        parameters: &[f64],
1316        indices: &[usize],
1317    ) -> Vec<DVector<Complex64>> {
1318        let resources = self.resources.read();
1319        let parameters = Parameters::new(parameters, &resources.constants);
1320        #[cfg(feature = "rayon")]
1321        {
1322            indices
1323                .par_iter()
1324                .map(|&idx| {
1325                    let event = &self.dataset.events[idx];
1326                    let cache = &resources.caches[idx];
1327                    let mut gradient_values =
1328                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1329                    self.amplitudes
1330                        .iter()
1331                        .zip(resources.active.iter())
1332                        .zip(gradient_values.iter_mut())
1333                        .for_each(|((amp, active), grad)| {
1334                            if *active {
1335                                amp.compute_gradient(&parameters, event, cache, grad)
1336                            }
1337                        });
1338                    let amplitude_values: Vec<Complex64> = self
1339                        .amplitudes
1340                        .iter()
1341                        .zip(resources.active.iter())
1342                        .map(|(amp, active)| {
1343                            if *active {
1344                                amp.compute(&parameters, event, cache)
1345                            } else {
1346                                Complex64::ZERO
1347                            }
1348                        })
1349                        .collect();
1350                    self.expression
1351                        .evaluate_gradient(&amplitude_values, &gradient_values)
1352                })
1353                .collect()
1354        }
1355        #[cfg(not(feature = "rayon"))]
1356        {
1357            indices
1358                .iter()
1359                .map(|&idx| {
1360                    let event = &self.dataset.events[idx];
1361                    let cache = &resources.caches[idx];
1362                    let mut gradient_values =
1363                        vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1364                    self.amplitudes
1365                        .iter()
1366                        .zip(resources.active.iter())
1367                        .zip(gradient_values.iter_mut())
1368                        .for_each(|((amp, active), grad)| {
1369                            if *active {
1370                                amp.compute_gradient(&parameters, event, cache, grad)
1371                            }
1372                        });
1373                    let amplitude_values: Vec<Complex64> = self
1374                        .amplitudes
1375                        .iter()
1376                        .zip(resources.active.iter())
1377                        .map(|(amp, active)| {
1378                            if *active {
1379                                amp.compute(&parameters, event, cache)
1380                            } else {
1381                                Complex64::ZERO
1382                            }
1383                        })
1384                        .collect();
1385
1386                    self.expression
1387                        .evaluate_gradient(&amplitude_values, &gradient_values)
1388                })
1389                .collect()
1390        }
1391    }
1392
1393    /// See [`Evaluator::evaluate_gradient_mpi`]. This method evaluates over a subset
1394    /// of events rather than all events in the total dataset.
1395    #[cfg(feature = "mpi")]
1396    fn evaluate_gradient_batch_mpi(
1397        &self,
1398        parameters: &[f64],
1399        indices: &[usize],
1400        world: &SimpleCommunicator,
1401    ) -> Vec<DVector<Complex64>> {
1402        let total = self.dataset.n_events();
1403        let locals = world.locals_from_globals(indices, total);
1404        let flattened_local_evaluation = self
1405            .evaluate_gradient_batch_local(parameters, &locals)
1406            .iter()
1407            .flat_map(|g| g.data.as_vec().to_vec())
1408            .collect::<Vec<Complex64>>();
1409        world
1410            .all_gather_batched_partitioned(
1411                &flattened_local_evaluation,
1412                indices,
1413                total,
1414                Some(parameters.len()),
1415            )
1416            .chunks(parameters.len())
1417            .map(DVector::from_row_slice)
1418            .collect()
1419    }
1420
1421    /// Evaluate the gradient of the stored [`Expression`] over a subset of the
1422    /// events in the [`Dataset`] stored by the [`Evaluator`] with the given values
1423    /// for free parameters. See also [`Expression::evaluate_gradient`].
1424    pub fn evaluate_gradient_batch(
1425        &self,
1426        parameters: &[f64],
1427        indices: &[usize],
1428    ) -> Vec<DVector<Complex64>> {
1429        #[cfg(feature = "mpi")]
1430        {
1431            if let Some(world) = crate::mpi::get_world() {
1432                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
1433            }
1434        }
1435        self.evaluate_gradient_batch_local(parameters, indices)
1436    }
1437}
1438
/// A testing [`Amplitude`].
///
/// Computes `(re + i·im) · e`, where `e` is `event.p4s[0].e()` (the first four-momentum's
/// energy component).
#[derive(Clone, Serialize, Deserialize)]
pub struct TestAmplitude {
    /// Name under which the amplitude is registered.
    name: String,
    /// Parameter supplying the real part of the coefficient.
    re: ParameterLike,
    /// ID assigned to `re` by `register`; holds the default value until then.
    pid_re: ParameterID,
    /// Parameter supplying the imaginary part of the coefficient.
    im: ParameterLike,
    /// ID assigned to `im` by `register`; holds the default value until then.
    pid_im: ParameterID,
}
1448
1449impl TestAmplitude {
1450    /// Create a new testing [`Amplitude`].
1451    #[allow(clippy::new_ret_no_self)]
1452    pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> LadduResult<Expression> {
1453        Self {
1454            name: name.to_string(),
1455            re,
1456            pid_re: Default::default(),
1457            im,
1458            pid_im: Default::default(),
1459        }
1460        .into_expression()
1461    }
1462}
1463
1464#[typetag::serde]
1465impl Amplitude for TestAmplitude {
1466    fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
1467        self.pid_re = resources.register_parameter(&self.re)?;
1468        self.pid_im = resources.register_parameter(&self.im)?;
1469        resources.register_amplitude(&self.name)
1470    }
1471
1472    fn compute(&self, parameters: &Parameters, event: &EventData, _cache: &Cache) -> Complex64 {
1473        Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im)) * event.p4s[0].e()
1474    }
1475
1476    fn compute_gradient(
1477        &self,
1478        _parameters: &Parameters,
1479        event: &EventData,
1480        _cache: &Cache,
1481        gradient: &mut DVector<Complex64>,
1482    ) {
1483        if let ParameterID::Parameter(ind) = self.pid_re {
1484            gradient[ind] = Complex64::ONE * event.p4s[0].e();
1485        }
1486        if let ParameterID::Parameter(ind) = self.pid_im {
1487            gradient[ind] = Complex64::I * event.p4s[0].e();
1488        }
1489    }
1490}
1491
#[cfg(test)]
mod tests {
    // Unit tests for expression construction, evaluation, activation, gradients, and printing.
    use crate::data::{test_dataset, test_event, DatasetMetadata};

    use super::*;
    use crate::{
        data::EventData,
        resources::{Cache, ParameterID, Parameters, Resources},
    };
    use approx::assert_relative_eq;
    use serde::{Deserialize, Serialize};

    /// Minimal test amplitude whose value is the event-independent complex number `re + i*im`.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        pid_re: ParameterID,
        im: ParameterLike,
        pid_im: ParameterID,
    }

    impl ComplexScalar {
        /// Build the amplitude and wrap it in an [`Expression`]. The parameter IDs are filled
        /// in later by `register`.
        #[allow(clippy::new_ret_no_self)]
        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> LadduResult<Expression> {
            Self {
                name: name.to_string(),
                re,
                pid_re: Default::default(),
                im,
                pid_im: Default::default(),
            }
            .into_expression()
        }
    }

    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
            self.pid_re = resources.register_parameter(&self.re)?;
            self.pid_im = resources.register_parameter(&self.im)?;
            resources.register_amplitude(&self.name)
        }

        fn compute(
            &self,
            parameters: &Parameters,
            _event: &EventData,
            _cache: &Cache,
        ) -> Complex64 {
            Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
        }

        /// d(re + i*im)/d(re) = 1 and d(re + i*im)/d(im) = i. Contributions are accumulated so
        /// that sharing one free parameter between `re` and `im` yields `1 + i`, not just `i`.
        fn compute_gradient(
            &self,
            _parameters: &Parameters,
            _event: &EventData,
            _cache: &Cache,
            gradient: &mut DVector<Complex64>,
        ) {
            if let ParameterID::Parameter(ind) = self.pid_re {
                gradient[ind] += Complex64::ONE;
            }
            if let ParameterID::Parameter(ind) = self.pid_im {
                gradient[ind] += Complex64::I;
            }
        }
    }

    #[test]
    fn test_batch_evaluation() {
        let expr = TestAmplitude::new("test", parameter("real"), parameter("imag")).unwrap();
        let mut event1 = test_event();
        event1.p4s[0].t = 10.0;
        let mut event2 = test_event();
        event2.p4s[0].t = 11.0;
        let mut event3 = test_event();
        event3.p4s[0].t = 12.0;
        let dataset = Arc::new(Dataset::new_with_metadata(
            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
            Arc::new(DatasetMetadata::default()),
        ));
        let evaluator = expr.load(&dataset).unwrap();
        let result = evaluator.evaluate_batch(&[1.1, 2.2], &[0, 2]);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
        assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
        let result_grad = evaluator.evaluate_gradient_batch(&[1.1, 2.2], &[0, 2]);
        assert_eq!(result_grad.len(), 2);
        assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
        assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
        assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
        assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
    }

    #[test]
    fn test_constant_amplitude() {
        let expr = ComplexScalar::new(
            "constant",
            constant("const_re", 2.0),
            constant("const_im", 3.0),
        )
        .unwrap();
        let dataset = Arc::new(Dataset::new_with_metadata(
            vec![Arc::new(test_event())],
            Arc::new(DatasetMetadata::default()),
        ));
        let evaluator = expr.load(&dataset).unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex64::new(2.0, 3.0));
    }

    #[test]
    fn test_parametric_amplitude() {
        let expr = ComplexScalar::new(
            "parametric",
            parameter("test_param_re"),
            parameter("test_param_im"),
        )
        .unwrap();
        let dataset = Arc::new(test_dataset());
        let evaluator = expr.load(&dataset).unwrap();
        let result = evaluator.evaluate(&[2.0, 3.0]);
        assert_eq!(result[0], Complex64::new(2.0, 3.0));
    }

    #[test]
    fn test_expression_operations() {
        let expr1 = ComplexScalar::new(
            "const1",
            constant("const1_re", 2.0),
            constant("const1_im", 0.0),
        )
        .unwrap();
        let expr2 = ComplexScalar::new(
            "const2",
            constant("const2_re", 0.0),
            constant("const2_im", 1.0),
        )
        .unwrap();
        let expr3 = ComplexScalar::new(
            "const3",
            constant("const3_re", 3.0),
            constant("const3_im", 4.0),
        )
        .unwrap();

        let dataset = Arc::new(test_dataset());

        // Test (amp) addition
        let expr_add = &expr1 + &expr2;
        let result_add = expr_add.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_add[0], Complex64::new(2.0, 1.0));

        // Test (amp) subtraction
        let expr_sub = &expr1 - &expr2;
        let result_sub = expr_sub.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));

        // Test (amp) multiplication
        let expr_mul = &expr1 * &expr2;
        let result_mul = expr_mul.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));

        // Test (amp) division
        let expr_div = &expr1 / &expr3;
        let result_div = expr_div.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));

        // Test (amp) neg
        let expr_neg = -&expr3;
        let result_neg = expr_neg.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));

        // Test (expr) addition
        let expr_add2 = &expr_add + &expr_mul;
        let result_add2 = expr_add2.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));

        // Test (expr) subtraction
        let expr_sub2 = &expr_add - &expr_mul;
        let result_sub2 = expr_sub2.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));

        // Test (expr) multiplication
        let expr_mul2 = &expr_add * &expr_mul;
        let result_mul2 = expr_mul2.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));

        // Test (expr) division
        let expr_div2 = &expr_add / &expr_add2;
        let result_div2 = expr_div2.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));

        // Test (expr) neg
        let expr_neg2 = -&expr_mul2;
        let result_neg2 = expr_neg2.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));

        // Test (amp) real
        let expr_real = expr3.real();
        let result_real = expr_real.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_real[0], Complex64::new(3.0, 0.0));

        // Test (expr) real
        let expr_mul2_real = expr_mul2.real();
        let result_mul2_real = expr_mul2_real.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));

        // Test (amp) imag
        let expr_imag = expr3.imag();
        let result_imag = expr_imag.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));

        // Test (expr) imag
        let expr_mul2_imag = expr_mul2.imag();
        let result_mul2_imag = expr_mul2_imag.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));

        // Test (amp) conj
        let expr_conj = expr3.conj();
        let result_conj = expr_conj.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));

        // Test (expr) conj
        let expr_mul2_conj = expr_mul2.conj();
        let result_mul2_conj = expr_mul2_conj.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));

        // Test (amp) norm_sqr
        let expr_norm = expr1.norm_sqr();
        let result_norm = expr_norm.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));

        // Test (expr) norm_sqr
        let expr_mul2_norm = expr_mul2.norm_sqr();
        let result_mul2_norm = expr_mul2_norm.load(&dataset).unwrap().evaluate(&[]);
        assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
    }

    #[test]
    fn test_amplitude_activation() {
        let expr1 = ComplexScalar::new(
            "const1",
            constant("const1_re_act", 1.0),
            constant("const1_im_act", 0.0),
        )
        .unwrap();
        let expr2 = ComplexScalar::new(
            "const2",
            constant("const2_re_act", 2.0),
            constant("const2_im_act", 0.0),
        )
        .unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = &expr1 + &expr2;
        let evaluator = expr.load(&dataset).unwrap();

        // Test initial state (all active)
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex64::new(3.0, 0.0));

        // Test deactivation
        evaluator.deactivate_strict("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex64::new(2.0, 0.0));

        // Test isolation
        evaluator.isolate_strict("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex64::new(1.0, 0.0));

        // Test reactivation
        evaluator.activate_all();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex64::new(3.0, 0.0));
    }

    #[test]
    fn test_gradient() {
        let expr1 = ComplexScalar::new(
            "parametric_1",
            parameter("test_param_re_1"),
            parameter("test_param_im_1"),
        )
        .unwrap();
        let expr2 = ComplexScalar::new(
            "parametric_2",
            parameter("test_param_re_2"),
            parameter("test_param_im_2"),
        )
        .unwrap();

        let dataset = Arc::new(test_dataset());
        let params = vec![2.0, 3.0, 4.0, 5.0];

        let expr = &expr1 + &expr2;
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 1.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 0.0);
        assert_relative_eq!(gradient[0][1].im, 1.0);
        assert_relative_eq!(gradient[0][2].re, 1.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 0.0);
        assert_relative_eq!(gradient[0][3].im, 1.0);

        let expr = &expr1 - &expr2;
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 1.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 0.0);
        assert_relative_eq!(gradient[0][1].im, 1.0);
        assert_relative_eq!(gradient[0][2].re, -1.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 0.0);
        assert_relative_eq!(gradient[0][3].im, -1.0);

        let expr = &expr1 * &expr2;
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 5.0);
        assert_relative_eq!(gradient[0][1].re, -5.0);
        assert_relative_eq!(gradient[0][1].im, 4.0);
        assert_relative_eq!(gradient[0][2].re, 2.0);
        assert_relative_eq!(gradient[0][2].im, 3.0);
        assert_relative_eq!(gradient[0][3].re, -3.0);
        assert_relative_eq!(gradient[0][3].im, 2.0);

        let expr = &expr1 / &expr2;
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
        assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
        assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
        assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
        assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
        assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
        assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
        assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);

        let expr = -(&expr1 * &expr2);
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, -4.0);
        assert_relative_eq!(gradient[0][0].im, -5.0);
        assert_relative_eq!(gradient[0][1].re, 5.0);
        assert_relative_eq!(gradient[0][1].im, -4.0);
        assert_relative_eq!(gradient[0][2].re, -2.0);
        assert_relative_eq!(gradient[0][2].im, -3.0);
        assert_relative_eq!(gradient[0][3].re, 3.0);
        assert_relative_eq!(gradient[0][3].im, -2.0);

        let expr = (&expr1 * &expr2).real();
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, -5.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 2.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, -3.0);
        assert_relative_eq!(gradient[0][3].im, 0.0);

        let expr = (&expr1 * &expr2).imag();
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 5.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 4.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 3.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 2.0);
        assert_relative_eq!(gradient[0][3].im, 0.0);

        let expr = (&expr1 * &expr2).conj();
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, -5.0);
        assert_relative_eq!(gradient[0][1].re, -5.0);
        assert_relative_eq!(gradient[0][1].im, -4.0);
        assert_relative_eq!(gradient[0][2].re, 2.0);
        assert_relative_eq!(gradient[0][2].im, -3.0);
        assert_relative_eq!(gradient[0][3].re, -3.0);
        assert_relative_eq!(gradient[0][3].im, -2.0);

        let expr = (&expr1 * &expr2).norm_sqr();
        let evaluator = expr.load(&dataset).unwrap();

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 164.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 246.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 104.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 130.0);
        assert_relative_eq!(gradient[0][3].im, 0.0);
    }

    #[test]
    fn test_shared_parameter_gradient() {
        // The real and imaginary parts are driven by the same free parameter, so
        // f(x) = x + i*x and df/dx = 1 + i. Both partial contributions must be accumulated
        // into the single gradient slot rather than the second overwriting the first.
        let expr = ComplexScalar::new(
            "shared",
            parameter("shared_param"),
            parameter("shared_param"),
        )
        .unwrap();
        assert_eq!(expr.free_parameters().len(), 1);
        let dataset = Arc::new(test_dataset());
        let evaluator = expr.load(&dataset).unwrap();

        let value = evaluator.evaluate(&[2.0]);
        assert_relative_eq!(value[0].re, 2.0);
        assert_relative_eq!(value[0].im, 2.0);

        let gradient = evaluator.evaluate_gradient(&[2.0]);
        assert_relative_eq!(gradient[0][0].re, 1.0);
        assert_relative_eq!(gradient[0][0].im, 1.0);
    }

    #[test]
    fn test_zeros_and_ones() {
        let amp = ComplexScalar::new(
            "parametric",
            parameter("test_param_re"),
            constant("fixed_two", 2.0),
        )
        .unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
        let evaluator = expr.load(&dataset).unwrap();

        let params = vec![2.0];
        let value = evaluator.evaluate(&params);
        let gradient = evaluator.evaluate_gradient(&params);

        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the value should be x^2 + 4
        assert_relative_eq!(value[0].re, 8.0);
        assert_relative_eq!(value[0].im, 0.0);

        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the derivative should be 2x
        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    #[test]
    fn test_parameter_registration() {
        let expr = ComplexScalar::new(
            "parametric",
            parameter("test_param_re"),
            constant("fixed_two", 2.0),
        )
        .unwrap();
        let parameters = expr.free_parameters();
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0], "test_param_re");
    }

    #[test]
    #[should_panic(expected = "refers to different underlying amplitudes")]
    fn test_duplicate_amplitude_registration() {
        let amp1 = ComplexScalar::new(
            "same_name",
            constant("dup_re1", 1.0),
            constant("dup_im1", 0.0),
        )
        .unwrap();
        let amp2 = ComplexScalar::new(
            "same_name",
            constant("dup_re2", 2.0),
            constant("dup_im2", 0.0),
        )
        .unwrap();
        let _expr = amp1 + amp2;
    }

    #[test]
    fn test_tree_printing() {
        let amp1 = ComplexScalar::new(
            "parametric_1",
            parameter("test_param_re_1"),
            parameter("test_param_im_1"),
        )
        .unwrap();
        let amp2 = ComplexScalar::new(
            "parametric_2",
            parameter("test_param_re_2"),
            parameter("test_param_im_2"),
        )
        .unwrap();
        let expr = &amp1.real() + &amp2.conj().imag() + Expression::one() * -Expression::zero()
            - Expression::zero() / Expression::one()
            + (&amp1 * &amp2).norm_sqr();
        assert_eq!(
            expr.to_string(),
            "+
├─ -
│  ├─ +
│  │  ├─ +
│  │  │  ├─ Re
│  │  │  │  └─ parametric_1(id=0)
│  │  │  └─ Im
│  │  │     └─ *
│  │  │        └─ parametric_2(id=1)
│  │  └─ ×
│  │     ├─ 1
│  │     └─ -
│  │        └─ 0
│  └─ ÷
│     ├─ 0
│     └─ 1
└─ NormSqr
   └─ ×
      ├─ parametric_1(id=0)
      └─ parametric_2(id=1)
"
        );
    }
}