Skip to main content

laddu_core/expression/
evaluator.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    sync::{
5        atomic::{AtomicU64, Ordering},
6        Arc,
7    },
8    time::Instant,
9};
10
11use auto_ops::*;
12use nalgebra::DVector;
13use num::complex::Complex64;
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
19use super::{ir, lowered};
20
// Process-wide counters backing `next_compute_amplitude_id` / `next_amplitude_use_site_id`.
// Both start at 0 every time the process starts.
static COMPUTE_AMPLITUDE_COUNTER: AtomicU64 = AtomicU64::new(0);
static AMPLITUDE_USE_SITE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Identity of a registered compute amplitude; `ExpressionRegistry::merge` uses it to decide
/// whether two registries hold the same amplitude.
///
/// NOTE(review): this id is serialized, but the counter above restarts at 0 in each process. A
/// registry deserialized from another session and merged with freshly built amplitudes could
/// therefore see colliding ids and wrongly deduplicate distinct amplitudes. Confirm this is
/// handled elsewhere.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct ComputeAmplitudeId(u64);

/// Identity of a single amplitude use site; `ExpressionRegistry::merge` uses it to deduplicate
/// use sites shared by both operands. Carries the same serialization caveat as
/// `ComputeAmplitudeId`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct AmplitudeUseSiteId(u64);
29
30fn next_compute_amplitude_id() -> ComputeAmplitudeId {
31    ComputeAmplitudeId(COMPUTE_AMPLITUDE_COUNTER.fetch_add(1, Ordering::Relaxed))
32}
33
34fn next_amplitude_use_site_id() -> AmplitudeUseSiteId {
35    AmplitudeUseSiteId(AMPLITUDE_USE_SITE_COUNTER.fetch_add(1, Ordering::Relaxed))
36}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Dependence classification used by expression-IR diagnostics.
///
/// Public, variant-for-variant mirror of the crate-internal `ir::DependenceClass`; `From`
/// conversions are provided in both directions.
pub enum ExpressionDependence {
    /// Depends only on fixed/free parameter values.
    ParameterOnly,
    /// Depends only on event-local cached values.
    CacheOnly,
    /// Depends on both parameter values and cached event values.
    Mixed,
}
47impl From<ir::DependenceClass> for ExpressionDependence {
48    fn from(value: ir::DependenceClass) -> Self {
49        match value {
50            ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
51            ir::DependenceClass::CacheOnly => Self::CacheOnly,
52            ir::DependenceClass::Mixed => Self::Mixed,
53        }
54    }
55}
#[derive(Clone, Debug, PartialEq, Eq)]
/// Explain/debug view of the IR normalization planning decomposition.
///
/// Built from `ir::NormalizationPlanExplain`; separable multiplication candidates are reduced to
/// their node indices.
pub struct NormalizationPlanExplain {
    /// Dependence classification at the expression root.
    pub root_dependence: ExpressionDependence,
    /// Warning-level diagnostics collected during planning.
    pub warnings: Vec<String>,
    /// Node indices of multiplication nodes identified as separable candidates.
    pub separable_mul_candidate_nodes: Vec<usize>,
    /// Candidate separable node indices selected for caching.
    pub cached_separable_nodes: Vec<usize>,
    /// Node indices planned for residual per-event evaluation.
    pub residual_terms: Vec<usize>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
/// Explain/debug view of amplitude execution sets used by normalization evaluation.
///
/// Copied field-for-field from `ir::NormalizationExecutionSets`.
pub struct NormalizationExecutionSetsExplain {
    /// Amplitudes required to evaluate parameter factors for cached separable terms.
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitudes required to evaluate cache factors for cached separable terms.
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitudes required for residual (non-cached) normalization evaluation.
    pub residual_amplitudes: Vec<usize>,
}
#[derive(Clone, Debug, PartialEq)]
/// Load-time precomputed integral metadata for a separable cached term.
///
/// `Expression::load` produces one of these per cached separable term and stores them with the
/// compiled IR for the active-amplitude mask they were computed under.
pub struct PrecomputedCachedIntegral {
    /// Node index of the separable multiplication term.
    pub mul_node_index: usize,
    /// Node index of the parameter-dependent factor.
    pub parameter_node_index: usize,
    /// Node index of the cache-dependent factor.
    pub cache_node_index: usize,
    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
    pub coefficient: i32,
    /// Weighted sum over local events of the cache-dependent factor.
    pub weighted_cache_sum: Complex64,
}
#[derive(Clone, Debug, PartialEq)]
/// Parameter-gradient contribution for a load-time precomputed cached integral term.
///
/// Shares its node-index and coefficient fields with [`PrecomputedCachedIntegral`].
pub struct PrecomputedCachedIntegralGradientTerm {
    /// Node index of the separable multiplication term.
    pub mul_node_index: usize,
    /// Node index of the parameter-dependent factor.
    pub parameter_node_index: usize,
    /// Node index of the cache-dependent factor.
    pub cache_node_index: usize,
    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
    pub coefficient: i32,
    /// Gradient contribution `(d/dp parameter_factor) * weighted_cache_sum`.
    pub weighted_gradient: DVector<Complex64>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
/// Hash-map key identifying one cached-integral specialization.
///
/// Built by `Evaluator::cached_integral_cache_key` from the resources' active mask and the
/// dataset. The dataset-derived fields are presumably a fingerprint of the local events and
/// weights — TODO confirm against that constructor.
struct CachedIntegralCacheKey {
    /// Active flags taken from `Resources::active`.
    active_mask: Vec<bool>,
    n_events_local: usize,
    weights_local_len: usize,
    /// Presumably the `f64::to_bits` pattern of the weighted sum, so the key can be `Eq`/`Hash`.
    weighted_sum_bits: u64,
    /// NOTE(review): looks like the weights buffer's address used as an identity; if so, a buffer
    /// freed and reallocated at the same address with matching length/sum would hit a stale entry.
    weights_ptr: usize,
}
#[derive(Clone, Debug)]
/// Compiled IR plus the cached-integral artifacts derived from it for one specialization key.
///
/// Shared behind an `Arc` between the evaluator's current state and its specialization cache.
struct CachedIntegralCacheState {
    key: CachedIntegralCacheKey,
    expression_ir: ir::ExpressionIR,
    values: Vec<PrecomputedCachedIntegral>,
    execution_sets: ir::NormalizationExecutionSets,
}
#[derive(Clone, Debug)]
/// Lowered runtime artifacts produced by `Evaluator::lower_expression_runtime_artifacts`.
struct LoweredArtifactCacheState {
    parameter_node_indices: Vec<usize>,
    mul_node_indices: Vec<usize>,
    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
    lowered_runtime: lowered::LoweredExpressionRuntime,
}
#[derive(Clone)]
/// One entry of the specialization cache: cached-integral state paired with its lowered artifacts.
struct ExpressionSpecializationState {
    cached_integrals: Arc<CachedIntegralCacheState>,
    lowered_artifacts: Arc<LoweredArtifactCacheState>,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
/// Debug/benchmark counters for active-mask specialization reuse under `expression-ir`.
///
/// `Expression::load` seeds these with `cache_misses: 1` to account for the initial compile.
pub struct ExpressionSpecializationMetrics {
    /// Number of specialization cache hits served without recompilation.
    pub cache_hits: usize,
    /// Number of specialization cache misses that required a fresh compile/lower pass.
    pub cache_misses: usize,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
/// Staged compile/lowering metrics for expression-IR construction and specialization refreshes.
///
/// `Expression::load` fills the `initial_*` timings and sets
/// `specialization_lowering_cache_misses` to 1; every other field starts at its default.
/// NOTE(review): `specialization_cache_misses` is left at 0 on load, while
/// `ExpressionSpecializationMetrics::cache_misses` starts at 1 — confirm this asymmetry is
/// intended.
pub struct ExpressionCompileMetrics {
    /// Nanoseconds spent compiling the semantic expression tree into IR during initial load.
    pub initial_ir_compile_nanos: u64,
    /// Nanoseconds spent precomputing cached-integral planning artifacts during initial load.
    pub initial_cached_integrals_nanos: u64,
    /// Nanoseconds spent lowering IR-derived runtimes during initial load.
    pub initial_lowering_nanos: u64,
    /// Number of specialization cache hits restored without recompilation.
    pub specialization_cache_hits: usize,
    /// Number of specialization cache misses that required recompilation.
    pub specialization_cache_misses: usize,
    /// Accumulated nanoseconds spent compiling active-mask-specialized IR after initial load.
    pub specialization_ir_compile_nanos: u64,
    /// Accumulated nanoseconds spent recomputing cached-integral planning artifacts after load.
    pub specialization_cached_integrals_nanos: u64,
    /// Accumulated nanoseconds spent lowering specialized runtimes after load.
    pub specialization_lowering_nanos: u64,
    /// Number of specialization rebuilds that reused cached lowered artifacts.
    pub specialization_lowering_cache_hits: usize,
    /// Number of specialization rebuilds that had to lower fresh artifacts.
    pub specialization_lowering_cache_misses: usize,
    /// Accumulated nanoseconds spent restoring specializations from cache.
    pub specialization_cache_restore_nanos: u64,
}
170impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
171    fn from(value: ir::NormalizationPlanExplain) -> Self {
172        Self {
173            root_dependence: value.root_dependence.into(),
174            warnings: value.warnings,
175            separable_mul_candidate_nodes: value
176                .separable_mul_candidates
177                .into_iter()
178                .map(|candidate| candidate.node_index)
179                .collect(),
180            cached_separable_nodes: value.cached_separable_nodes,
181            residual_terms: value.residual_terms,
182        }
183    }
184}
185impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
186    fn from(value: ir::NormalizationExecutionSets) -> Self {
187        Self {
188            cached_parameter_amplitudes: value.cached_parameter_amplitudes,
189            cached_cache_amplitudes: value.cached_cache_amplitudes,
190            residual_amplitudes: value.residual_amplitudes,
191        }
192    }
193}
194impl From<ExpressionDependence> for ir::DependenceClass {
195    fn from(value: ExpressionDependence) -> Self {
196        match value {
197            ExpressionDependence::ParameterOnly => Self::ParameterOnly,
198            ExpressionDependence::CacheOnly => Self::CacheOnly,
199            ExpressionDependence::Mixed => Self::Mixed,
200        }
201    }
202}
203
204#[cfg(feature = "mpi")]
205use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
206
207#[cfg(feature = "mpi")]
208use crate::mpi::LadduMPI;
209#[cfg(feature = "execution-context-prototype")]
210use crate::ExecutionContext;
211#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
212use crate::ThreadPolicy;
213use crate::{
214    amplitude::{Amplitude, AmplitudeUseSite},
215    data::Dataset,
216    parameters::ParameterMap,
217    resources::{Cache, Parameters, Resources},
218    LadduError, LadduResult,
219};
220
221/// A holder struct that owns both an expression tree and the registered amplitudes.
222#[allow(missing_docs)]
223#[derive(Clone, Serialize, Deserialize, Default)]
224pub struct Expression {
225    registry: ExpressionRegistry,
226    tree: ExpressionNode,
227}
228
229#[derive(Clone, Serialize, Deserialize)]
230#[allow(missing_docs)]
231#[derive(Default)]
232pub struct ExpressionRegistry {
233    amplitudes: Vec<Box<dyn Amplitude>>,
234    amplitude_names: Vec<String>,
235    amplitude_ids: Vec<ComputeAmplitudeId>,
236    amplitude_use_sites: Vec<AmplitudeUseSite>,
237    amplitude_use_site_ids: Vec<AmplitudeUseSiteId>,
238    resources: Resources,
239}
240
impl ExpressionRegistry {
    /// Build a registry holding exactly one amplitude with a single use site.
    ///
    /// The amplitude is registered into fresh [`Resources`], and a new compute id and a new
    /// use-site id are minted for it.
    ///
    /// # Errors
    /// Propagates any error returned by the amplitude's `register` call.
    fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
        let mut resources = Resources::default();
        // `aid.0` holds the use-site tags (also rendered as the display label).
        let aid = amplitude.register(&mut resources)?;
        let compute_id = next_compute_amplitude_id();
        let use_site_id = next_amplitude_use_site_id();
        resources.configure_amplitude_tags(std::slice::from_ref(&aid.0));
        Ok(Self {
            amplitudes: vec![amplitude],
            amplitude_names: vec![aid.0.display_label()],
            amplitude_ids: vec![compute_id],
            amplitude_use_sites: vec![AmplitudeUseSite {
                amplitude_index: 0,
                tags: aid.0,
            }],
            amplitude_use_site_ids: vec![use_site_id],
            resources,
        })
    }

    /// Merge `self` (left) and `other` (right) into a fresh registry.
    ///
    /// Every kept amplitude is cloned and re-registered into new [`Resources`]. Left amplitudes
    /// keep their order. A right amplitude is reused rather than re-registered when its compute
    /// id, or failing that its `semantic_key()`, matches one already registered.
    ///
    /// Use sites follow the same pattern: all left use sites are kept in order, and a right use
    /// site whose use-site id is already present is reused.
    ///
    /// Returns `(merged, left_map, right_map)`, where `left_map[i]` / `right_map[i]` give the
    /// merged use-site index for old use site `i` of the left / right registry (the mapping that
    /// `ExpressionNode::remap` expects).
    ///
    /// # Errors
    /// Propagates any error returned by an amplitude's `register` call.
    fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
        let mut resources = Resources::default();
        let mut amplitudes = Vec::new();
        let mut amplitude_ids = Vec::new();
        let mut compute_id_to_index = HashMap::new();
        // If two left amplitudes share a semantic key, the later one wins (plain `insert`).
        let mut semantic_key_to_index = HashMap::new();

        // Left compute amplitudes: re-register every one, in order.
        let mut left_compute_map = Vec::with_capacity(self.amplitudes.len());
        for (amp_index, (amp, amp_id)) in
            self.amplitudes.iter().zip(&self.amplitude_ids).enumerate()
        {
            let semantic_key = amp.semantic_key();
            let mut cloned_amp = dyn_clone::clone_box(&**amp);
            let aid = cloned_amp.register(&mut resources)?;
            if let Some(key) = semantic_key.clone() {
                semantic_key_to_index.insert(key, aid.1);
            }
            amplitudes.push(cloned_amp);
            amplitude_ids.push(*amp_id);
            compute_id_to_index.insert(*amp_id, aid.1);
            left_compute_map.push(aid.1);
            // Registering into empty resources in order should reproduce the original indices.
            debug_assert_eq!(amp_index, aid.1);
        }

        // Right compute amplitudes: reuse by compute id, then by semantic key, else register.
        let mut right_compute_map = Vec::with_capacity(other.amplitudes.len());
        for (amp, amp_id) in other.amplitudes.iter().zip(&other.amplitude_ids) {
            if let Some(existing) = compute_id_to_index.get(amp_id) {
                right_compute_map.push(*existing);
                continue;
            }
            let incoming_semantic_key = amp.semantic_key();
            // A semantic-key match keeps only the already-registered amplitude and its id; this
            // amplitude's own compute id is not recorded in the merged registry.
            if let Some(existing) = incoming_semantic_key
                .as_ref()
                .and_then(|key| semantic_key_to_index.get(key))
            {
                right_compute_map.push(*existing);
                continue;
            }
            let mut cloned_amp = dyn_clone::clone_box(&**amp);
            let aid = cloned_amp.register(&mut resources)?;
            if let Some(key) = incoming_semantic_key.clone() {
                semantic_key_to_index.insert(key, aid.1);
            }
            amplitudes.push(cloned_amp);
            amplitude_ids.push(*amp_id);
            compute_id_to_index.insert(*amp_id, aid.1);
            right_compute_map.push(aid.1);
        }

        let mut amplitude_use_sites = Vec::new();
        let mut amplitude_use_site_ids = Vec::new();
        let mut amplitude_names = Vec::new();
        let mut use_site_id_to_index = HashMap::new();

        // Left use sites: copied in order, re-pointed at the merged compute amplitudes. No id
        // dedupe here; the left registry's use-site ids are assumed unique.
        let mut left_map = Vec::with_capacity(self.amplitude_use_sites.len());
        for (use_site, use_site_id) in self
            .amplitude_use_sites
            .iter()
            .zip(&self.amplitude_use_site_ids)
        {
            let mapped_index = left_compute_map[use_site.amplitude_index];
            let new_index = amplitude_use_sites.len();
            left_map.push(new_index);
            use_site_id_to_index.insert(*use_site_id, new_index);
            amplitude_use_site_ids.push(*use_site_id);
            amplitude_names.push(use_site.tags.display_label());
            amplitude_use_sites.push(AmplitudeUseSite {
                amplitude_index: mapped_index,
                tags: use_site.tags.clone(),
            });
        }

        // Right use sites: a use-site id already seen maps to the existing merged use site.
        let mut right_map = Vec::with_capacity(other.amplitude_use_sites.len());
        for (use_site, use_site_id) in other
            .amplitude_use_sites
            .iter()
            .zip(&other.amplitude_use_site_ids)
        {
            if let Some(existing) = use_site_id_to_index.get(use_site_id) {
                right_map.push(*existing);
                continue;
            }
            let mapped_index = right_compute_map[use_site.amplitude_index];
            let new_index = amplitude_use_sites.len();
            right_map.push(new_index);
            use_site_id_to_index.insert(*use_site_id, new_index);
            amplitude_use_site_ids.push(*use_site_id);
            amplitude_names.push(use_site.tags.display_label());
            amplitude_use_sites.push(AmplitudeUseSite {
                amplitude_index: mapped_index,
                tags: use_site.tags.clone(),
            });
        }
        // The merged resources receive the tags of every merged use site, in merged order.
        let use_site_tags = amplitude_use_sites
            .iter()
            .map(|use_site| use_site.tags.clone())
            .collect::<Vec<_>>();
        resources.configure_amplitude_tags(&use_site_tags);

        Ok((
            Self {
                amplitudes,
                amplitude_names,
                amplitude_ids,
                amplitude_use_sites,
                amplitude_use_site_ids,
                resources,
            },
            left_map,
            right_map,
        ))
    }
}
374
375/// Expression tree used by [`Expression`].
376#[allow(missing_docs)]
377#[derive(Clone, Serialize, Deserialize, Default, Debug)]
378pub enum ExpressionNode {
379    #[default]
380    /// A expression equal to zero.
381    Zero,
382    /// A expression equal to one.
383    One,
384    /// A real-valued constant.
385    Constant(f64),
386    /// A complex-valued constant.
387    ComplexConstant(Complex64),
388    /// A registered [`Amplitude`] referenced by index.
389    Amp(usize),
390    /// The sum of two [`ExpressionNode`]s.
391    Add(Box<ExpressionNode>, Box<ExpressionNode>),
392    /// The difference of two [`ExpressionNode`]s.
393    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
394    /// The product of two [`ExpressionNode`]s.
395    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
396    /// The division of two [`ExpressionNode`]s.
397    Div(Box<ExpressionNode>, Box<ExpressionNode>),
398    /// The additive inverse of an [`ExpressionNode`].
399    Neg(Box<ExpressionNode>),
400    /// The real part of an [`ExpressionNode`].
401    Real(Box<ExpressionNode>),
402    /// The imaginary part of an [`ExpressionNode`].
403    Imag(Box<ExpressionNode>),
404    /// The complex conjugate of an [`ExpressionNode`].
405    Conj(Box<ExpressionNode>),
406    /// The absolute square of an [`ExpressionNode`].
407    NormSqr(Box<ExpressionNode>),
408    Sqrt(Box<ExpressionNode>),
409    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
410    PowI(Box<ExpressionNode>, i32),
411    PowF(Box<ExpressionNode>, f64),
412    Exp(Box<ExpressionNode>),
413    Sin(Box<ExpressionNode>),
414    Cos(Box<ExpressionNode>),
415    Log(Box<ExpressionNode>),
416    Cis(Box<ExpressionNode>),
417}
418
419impl ExpressionNode {
420    fn remap(&self, mapping: &[usize]) -> Self {
421        match self {
422            Self::Amp(idx) => Self::Amp(mapping[*idx]),
423            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
424            Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
425            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
426            Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
427            Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
428            Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
429            Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
430            Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
431            Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
432            Self::Zero => Self::Zero,
433            Self::One => Self::One,
434            Self::Constant(v) => Self::Constant(*v),
435            Self::ComplexConstant(v) => Self::ComplexConstant(*v),
436            Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
437            Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
438            Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
439            Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
440            Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
441            Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
442            Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
443            Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
444            Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
445        }
446    }
447}
448
449impl From<f64> for Expression {
450    fn from(value: f64) -> Self {
451        if value == 0.0 {
452            Self {
453                registry: ExpressionRegistry::default(),
454                tree: ExpressionNode::Zero,
455            }
456        } else if value == 1.0 {
457            Self {
458                registry: ExpressionRegistry::default(),
459                tree: ExpressionNode::One,
460            }
461        } else {
462            Self {
463                registry: ExpressionRegistry::default(),
464                tree: ExpressionNode::Constant(value),
465            }
466        }
467    }
468}
469impl From<&f64> for Expression {
470    fn from(value: &f64) -> Self {
471        (*value).into()
472    }
473}
474impl From<Complex64> for Expression {
475    fn from(value: Complex64) -> Self {
476        if value == Complex64::ZERO {
477            Self {
478                registry: ExpressionRegistry::default(),
479                tree: ExpressionNode::Zero,
480            }
481        } else if value == Complex64::ONE {
482            Self {
483                registry: ExpressionRegistry::default(),
484                tree: ExpressionNode::One,
485            }
486        } else {
487            Self {
488                registry: ExpressionRegistry::default(),
489                tree: ExpressionNode::ComplexConstant(value),
490            }
491        }
492    }
493}
494impl From<&Complex64> for Expression {
495    fn from(value: &Complex64) -> Self {
496        (*value).into()
497    }
498}
499
500impl Expression {
501    /// Build an [`Expression`] from a single [`Amplitude`].
502    pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
503        let registry = ExpressionRegistry::singleton(amplitude)?;
504        Ok(Self {
505            tree: ExpressionNode::Amp(0),
506            registry,
507        })
508    }
509
510    /// Create an expression representing zero, the additive identity.
511    pub fn zero() -> Self {
512        Self {
513            registry: ExpressionRegistry::default(),
514            tree: ExpressionNode::Zero,
515        }
516    }
517
518    /// Create an expression representing one, the multiplicative identity.
519    pub fn one() -> Self {
520        Self {
521            registry: ExpressionRegistry::default(),
522            tree: ExpressionNode::One,
523        }
524    }
525
526    fn binary_op(
527        a: &Expression,
528        b: &Expression,
529        build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
530    ) -> Expression {
531        let (registry, left_map, right_map) = a
532            .registry
533            .merge(&b.registry)
534            .expect("merging expression registries should not fail");
535        let left_tree = a.tree.remap(&left_map);
536        let right_tree = b.tree.remap(&right_map);
537        Expression {
538            registry,
539            tree: build(Box::new(left_tree), Box::new(right_tree)),
540        }
541    }
542
543    fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
544        Expression {
545            registry: a.registry.clone(),
546            tree: build(Box::new(a.tree.clone())),
547        }
548    }
549
550    /// Get the parameters used by this expression.
551    pub fn parameters(&self) -> ParameterMap {
552        self.registry.resources.parameters()
553    }
554
555    /// Number of free parameters.
556    pub fn n_free(&self) -> usize {
557        self.registry.resources.n_free_parameters()
558    }
559
560    /// Number of fixed parameters.
561    pub fn n_fixed(&self) -> usize {
562        self.registry.resources.n_fixed_parameters()
563    }
564
565    /// Total number of parameters.
566    pub fn n_parameters(&self) -> usize {
567        self.registry.resources.n_parameters()
568    }
569
570    /// Returns a tree-like diagnostic snapshot of this expression's compiled form.
571    ///
572    /// This compiles the expression on each call with every registered amplitude active. Use
573    /// [`Evaluator::compiled_expression`] when you need the compiled form for a loaded evaluator's
574    /// current active-amplitude mask.
575    pub fn compiled_expression(&self) -> CompiledExpression {
576        let active_amplitudes = vec![true; self.registry.amplitude_use_sites.len()];
577        let amplitude_dependencies = self
578            .registry
579            .amplitude_use_sites
580            .iter()
581            .map(|use_site| {
582                ir::DependenceClass::from(
583                    self.registry.amplitudes[use_site.amplitude_index].dependence_hint(),
584                )
585            })
586            .collect::<Vec<_>>();
587        let amplitude_realness = self
588            .registry
589            .amplitude_use_sites
590            .iter()
591            .map(|use_site| self.registry.amplitudes[use_site.amplitude_index].real_valued_hint())
592            .collect::<Vec<_>>();
593        let expression_ir = ir::compile_expression_ir_with_real_hints(
594            &self.tree,
595            &active_amplitudes,
596            &amplitude_dependencies,
597            &amplitude_realness,
598        );
599        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
600    }
601
602    /// Fix a parameter used by this expression's evaluator resources.
603    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
604        self.registry.resources.fix_parameter(name, value)
605    }
606
607    /// Mark a parameter used by this expression's evaluator resources as free.
608    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
609        self.registry.resources.free_parameter(name)
610    }
611
612    /// Return a new [`Expression`] with a single parameter renamed.
613    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
614        self.registry.resources.rename_parameter(old, new)
615    }
616
617    /// Return a new [`Expression`] with several parameters renamed.
618    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
619        self.registry.resources.rename_parameters(mapping)
620    }
621
    /// Load an [`Expression`] against a dataset, binding amplitudes and reserving caches.
    ///
    /// Works on a `clone` of this expression's resources and on `dyn_clone` copies of its
    /// amplitudes. Loading runs in stages:
    /// 1. bind and precompute each amplitude;
    /// 2. compile the IR for the currently active use sites;
    /// 3. precompute cached integrals;
    /// 4. lower the runtime artifacts.
    ///
    /// The time spent in stages 2–4 is recorded in the evaluator's [`ExpressionCompileMetrics`].
    /// The resulting specialization seeds the evaluator's caches.
    ///
    /// # Errors
    /// Propagates errors from amplitude binding, cached-integral precomputation, and runtime
    /// lowering.
    ///
    /// # Panics
    /// Panics if `resources.active_indices()` yields an index outside the registry's use sites.
    pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
        let mut resources = self.registry.resources.clone();
        let metadata = dataset.metadata();
        resources.reserve_cache(dataset.n_events_local());
        resources.refresh_active_indices();
        // Snapshot taken before amplitudes precompute; its free-parameter count is passed to
        // cached-integral precomputation below.
        let parameter_map = resources.parameter_map.clone();
        let mut amplitudes: Vec<Box<dyn Amplitude>> = self
            .registry
            .amplitudes
            .iter()
            .map(|amp| dyn_clone::clone_box(&**amp))
            .collect();
        {
            // Bind each cloned amplitude to the dataset metadata, then let it precompute into
            // the cloned resources.
            for amplitude in amplitudes.iter_mut() {
                amplitude.bind(metadata)?;
                amplitude.precompute_all(dataset, &mut resources);
            }
        }
        let ir_compile_start = Instant::now();
        let expression_ir = {
            // One flag per use site; only the resources' active indices are switched on.
            let mut active_amplitudes = vec![false; self.registry.amplitude_use_sites.len()];
            for &index in resources.active_indices() {
                active_amplitudes[index] = true;
            }
            // Per-use-site hints come from the bound (cloned) amplitudes, not the registry's.
            let amplitude_dependencies = self
                .registry
                .amplitude_use_sites
                .iter()
                .map(|use_site| {
                    ir::DependenceClass::from(
                        amplitudes[use_site.amplitude_index].dependence_hint(),
                    )
                })
                .collect::<Vec<_>>();
            let amplitude_realness = self
                .registry
                .amplitude_use_sites
                .iter()
                .map(|use_site| amplitudes[use_site.amplitude_index].real_valued_hint())
                .collect::<Vec<_>>();
            ir::compile_expression_ir_with_real_hints(
                &self.tree,
                &active_amplitudes,
                &amplitude_dependencies,
                &amplitude_realness,
            )
        };
        let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
        let cached_integrals_start = Instant::now();
        let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
            &expression_ir,
            &amplitudes,
            &self.registry.amplitude_use_sites,
            &resources,
            dataset,
            parameter_map.free().len(),
        )?;
        let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
        let lowering_start = Instant::now();
        let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
            &expression_ir,
            &cached_integrals,
        )?);
        let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
        let execution_sets = expression_ir.normalization_execution_sets().clone();
        let cached_integral_key =
            Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
        let cached_integral_state = Arc::new(CachedIntegralCacheState {
            key: cached_integral_key.clone(),
            expression_ir,
            values: cached_integrals,
            execution_sets,
        });
        // Seed both caches with the initial specialization so that a later return to this active
        // mask can be served from cache.
        let specialization_state = ExpressionSpecializationState {
            cached_integrals: cached_integral_state.clone(),
            lowered_artifacts: lowered_artifacts.clone(),
        };
        let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
        let lowered_artifact_cache =
            HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
        Ok(Evaluator {
            amplitudes,
            amplitude_use_sites: self.registry.amplitude_use_sites.clone(),
            resources: Arc::new(RwLock::new(resources)),
            dataset: dataset.clone(),
            expression: self.tree.clone(),
            ir_planning: ExpressionIrPlanningState {
                cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
                specialization_cache: Arc::new(RwLock::new(specialization_cache)),
                // The initial compile above is counted as one cache miss.
                specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
                    cache_hits: 0,
                    cache_misses: 1,
                })),
                lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
                active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
                specialization_status: Arc::new(RwLock::new(Some(
                    ExpressionSpecializationStatus {
                        origin: ExpressionSpecializationOrigin::InitialLoad,
                    },
                ))),
                // NOTE(review): only the lowering miss is seeded here; `specialization_cache_misses`
                // stays 0 even though `cache_misses` above starts at 1 — confirm this is intended.
                compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
                    initial_ir_compile_nanos,
                    initial_cached_integrals_nanos,
                    initial_lowering_nanos,
                    specialization_lowering_cache_misses: 1,
                    ..Default::default()
                })),
            },
            registry: self.registry.clone(),
        })
    }
734
735    /// Takes the real part of the given [`Expression`].
736    pub fn real(&self) -> Self {
737        Self::unary_op(self, ExpressionNode::Real)
738    }
739    /// Takes the imaginary part of the given [`Expression`].
740    pub fn imag(&self) -> Self {
741        Self::unary_op(self, ExpressionNode::Imag)
742    }
743    /// Takes the complex conjugate of the given [`Expression`].
744    pub fn conj(&self) -> Self {
745        Self::unary_op(self, ExpressionNode::Conj)
746    }
747    /// Takes the absolute square of the given [`Expression`].
748    pub fn norm_sqr(&self) -> Self {
749        Self::unary_op(self, ExpressionNode::NormSqr)
750    }
751    /// Takes the square root of the given [`Expression`].
752    pub fn sqrt(&self) -> Self {
753        Self::unary_op(self, ExpressionNode::Sqrt)
754    }
755    /// Raises the given [`Expression`] to an expression-valued power.
756    pub fn pow(&self, power: &Expression) -> Self {
757        Self::binary_op(self, power, ExpressionNode::Pow)
758    }
759    /// Raises the given [`Expression`] to an integer power.
760    pub fn powi(&self, power: i32) -> Self {
761        Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
762    }
763    /// Raises the given [`Expression`] to a real-valued power.
764    pub fn powf(&self, power: f64) -> Self {
765        Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
766    }
767    /// Takes the exponential of the given [`Expression`].
768    pub fn exp(&self) -> Self {
769        Self::unary_op(self, ExpressionNode::Exp)
770    }
771    /// Takes the sine of the given [`Expression`].
772    pub fn sin(&self) -> Self {
773        Self::unary_op(self, ExpressionNode::Sin)
774    }
775    /// Takes the cosine of the given [`Expression`].
776    pub fn cos(&self) -> Self {
777        Self::unary_op(self, ExpressionNode::Cos)
778    }
779    /// Takes the natural logarithm of the given [`Expression`].
780    pub fn log(&self) -> Self {
781        Self::unary_op(self, ExpressionNode::Log)
782    }
783    /// Takes the complex phase factor exp(i * expression).
784    pub fn cis(&self) -> Self {
785        Self::unary_op(self, ExpressionNode::Cis)
786    }
787
788    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
789    fn write_tree(
790        &self,
791        t: &ExpressionNode,
792        f: &mut std::fmt::Formatter<'_>,
793        parent_prefix: &str,
794        immediate_prefix: &str,
795        parent_suffix: &str,
796    ) -> std::fmt::Result {
797        let display_string = match t {
798            ExpressionNode::Amp(idx) => {
799                let name = self
800                    .registry
801                    .amplitude_names
802                    .get(*idx)
803                    .cloned()
804                    .unwrap_or_else(|| "<unregistered>".to_string());
805                format!("{name}(id={idx})")
806            }
807            ExpressionNode::Add(_, _) => "+".to_string(),
808            ExpressionNode::Sub(_, _) => "-".to_string(),
809            ExpressionNode::Mul(_, _) => "×".to_string(),
810            ExpressionNode::Div(_, _) => "÷".to_string(),
811            ExpressionNode::Neg(_) => "-".to_string(),
812            ExpressionNode::Real(_) => "Re".to_string(),
813            ExpressionNode::Imag(_) => "Im".to_string(),
814            ExpressionNode::Conj(_) => "*".to_string(),
815            ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
816            ExpressionNode::Zero => "0 (exact)".to_string(),
817            ExpressionNode::One => "1 (exact)".to_string(),
818            ExpressionNode::Constant(v) => v.to_string(),
819            ExpressionNode::ComplexConstant(v) => v.to_string(),
820            ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
821            ExpressionNode::Pow(_, _) => "Pow".to_string(),
822            ExpressionNode::PowI(_, power) => format!("PowI({power})"),
823            ExpressionNode::PowF(_, power) => format!("PowF({power})"),
824            ExpressionNode::Exp(_) => "Exp".to_string(),
825            ExpressionNode::Sin(_) => "Sin".to_string(),
826            ExpressionNode::Cos(_) => "Cos".to_string(),
827            ExpressionNode::Log(_) => "Log".to_string(),
828            ExpressionNode::Cis(_) => "Cis".to_string(),
829        };
830        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
831        match t {
832            ExpressionNode::Amp(_)
833            | ExpressionNode::Zero
834            | ExpressionNode::One
835            | ExpressionNode::Constant(_)
836            | ExpressionNode::ComplexConstant(_) => {}
837            ExpressionNode::Add(a, b)
838            | ExpressionNode::Sub(a, b)
839            | ExpressionNode::Mul(a, b)
840            | ExpressionNode::Div(a, b)
841            | ExpressionNode::Pow(a, b) => {
842                let terms = [a, b];
843                let mut it = terms.iter().peekable();
844                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
845                while let Some(child) = it.next() {
846                    match it.peek() {
847                        Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│  "),
848                        None => self.write_tree(child, f, &child_prefix, "└─ ", "   "),
849                    }?;
850                }
851            }
852            ExpressionNode::Neg(a)
853            | ExpressionNode::Real(a)
854            | ExpressionNode::Imag(a)
855            | ExpressionNode::Conj(a)
856            | ExpressionNode::NormSqr(a)
857            | ExpressionNode::Sqrt(a)
858            | ExpressionNode::PowI(a, _)
859            | ExpressionNode::PowF(a, _)
860            | ExpressionNode::Exp(a)
861            | ExpressionNode::Sin(a)
862            | ExpressionNode::Cos(a)
863            | ExpressionNode::Log(a)
864            | ExpressionNode::Cis(a) => {
865                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
866                self.write_tree(a, f, &child_prefix, "└─ ", "   ")?;
867            }
868        }
869        Ok(())
870    }
871}
872
873impl Debug for Expression {
874    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875        self.write_tree(&self.tree, f, "", "", "")
876    }
877}
878
879impl Display for Expression {
880    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881        self.write_tree(&self.tree, f, "", "", "")
882    }
883}
884
885#[rustfmt::skip]
886impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
887    Expression::binary_op(a, b, ExpressionNode::Add)
888});
889#[rustfmt::skip]
890impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
891    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
892});
893#[rustfmt::skip]
894impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
895    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
896});
897#[rustfmt::skip]
898impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
899    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
900});
901#[rustfmt::skip]
902impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
903    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
904});
905
906#[rustfmt::skip]
907impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
908    Expression::binary_op(a, b, ExpressionNode::Sub)
909});
910#[rustfmt::skip]
911impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
912    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
913});
914#[rustfmt::skip]
915impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
916    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
917});
918#[rustfmt::skip]
919impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
920    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
921});
922#[rustfmt::skip]
923impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
924    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
925});
926
927#[rustfmt::skip]
928impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
929    Expression::binary_op(a, b, ExpressionNode::Mul)
930});
931#[rustfmt::skip]
932impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
933    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
934});
935#[rustfmt::skip]
936impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
937    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
938});
939#[rustfmt::skip]
940impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
941    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
942});
943#[rustfmt::skip]
944impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
945    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
946});
947
948#[rustfmt::skip]
949impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
950    Expression::binary_op(a, b, ExpressionNode::Div)
951});
952#[rustfmt::skip]
953impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
954    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
955});
956#[rustfmt::skip]
957impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
958    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
959});
960#[rustfmt::skip]
961impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
962    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
963});
964#[rustfmt::skip]
965impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
966    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
967});
968
969#[rustfmt::skip]
970impl_op_ex!(- |a: &Expression| -> Expression {
971    Expression::unary_op(a, ExpressionNode::Neg)
972});
973// NOTE: no need to add an impl for negating f64 or complex!
974
975#[derive(Clone, Debug)]
976#[doc(hidden)]
977pub struct ExpressionValueProgramSnapshot {
978    lowered_program: lowered::LoweredProgram,
979}
980
981#[derive(Clone, Debug, PartialEq)]
982/// A node in a compiled expression diagnostic snapshot.
983pub enum CompiledExpressionNode {
984    /// A complex constant.
985    Constant(Complex64),
986    /// A registered amplitude use-site by index and display label.
987    Amplitude {
988        /// The amplitude index used by the compiled evaluator.
989        index: usize,
990        /// The registered amplitude display label.
991        name: String,
992    },
993    /// A unary operation and its input node.
994    Unary {
995        /// The display label for the operation.
996        op: String,
997        /// The input node index.
998        input: usize,
999    },
1000    /// A binary operation and its input nodes.
1001    Binary {
1002        /// The display label for the operation.
1003        op: String,
1004        /// The left input node index.
1005        left: usize,
1006        /// The right input node index.
1007        right: usize,
1008    },
1009}
1010
1011#[derive(Clone, Debug, PartialEq)]
1012/// Tree-like diagnostic view of the compiled expression DAG.
1013///
1014/// The compiled expression is a directed acyclic graph because common subexpressions can be
1015/// deduplicated during compilation. The display format expands the graph from the root once and
1016/// marks later visits to the same node with `(ref)`.
1017pub struct CompiledExpression {
1018    nodes: Vec<CompiledExpressionNode>,
1019    root: usize,
1020}
1021
1022impl CompiledExpression {
1023    fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
1024        let nodes = ir
1025            .nodes()
1026            .iter()
1027            .map(|node| match node {
1028                ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
1029                ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
1030                    index: *index,
1031                    name: amplitude_names
1032                        .get(*index)
1033                        .cloned()
1034                        .unwrap_or_else(|| "<unregistered>".to_string()),
1035                },
1036                ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
1037                    op: compiled_unary_op_label(*op),
1038                    input: *input,
1039                },
1040                ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
1041                    op: compiled_binary_op_label(*op),
1042                    left: *left,
1043                    right: *right,
1044                },
1045            })
1046            .collect();
1047        Self {
1048            nodes,
1049            root: ir.root(),
1050        }
1051    }
1052
1053    /// Returns the compiled expression node list in evaluator execution order.
1054    pub fn nodes(&self) -> &[CompiledExpressionNode] {
1055        &self.nodes
1056    }
1057
1058    /// Returns the root node index.
1059    pub fn root(&self) -> usize {
1060        self.root
1061    }
1062
1063    fn node_label(&self, index: usize) -> String {
1064        let Some(node) = self.nodes.get(index) else {
1065            return format!("#{index} <missing>");
1066        };
1067        let label = match node {
1068            CompiledExpressionNode::Constant(value) => format!("const {value}"),
1069            CompiledExpressionNode::Amplitude { index, name } => {
1070                format!("{name}(id={index})")
1071            }
1072            CompiledExpressionNode::Unary { op, .. }
1073            | CompiledExpressionNode::Binary { op, .. } => op.clone(),
1074        };
1075        format!("#{index} {label}")
1076    }
1077
1078    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
1079    fn write_tree(
1080        &self,
1081        index: usize,
1082        f: &mut std::fmt::Formatter<'_>,
1083        parent_prefix: &str,
1084        immediate_prefix: &str,
1085        parent_suffix: &str,
1086        expanded: &mut [bool],
1087    ) -> std::fmt::Result {
1088        let already_expanded = expanded.get(index).copied().unwrap_or(false);
1089        if let Some(slot) = expanded.get_mut(index) {
1090            *slot = true;
1091        }
1092        let ref_suffix = if already_expanded { " (ref)" } else { "" };
1093        writeln!(
1094            f,
1095            "{}{}{}{}",
1096            parent_prefix,
1097            immediate_prefix,
1098            self.node_label(index),
1099            ref_suffix
1100        )?;
1101        if already_expanded {
1102            return Ok(());
1103        }
1104        let Some(node) = self.nodes.get(index) else {
1105            return Ok(());
1106        };
1107        match node {
1108            CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
1109            CompiledExpressionNode::Unary { input, .. } => {
1110                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1111                self.write_tree(*input, f, &child_prefix, "└─ ", "   ", expanded)?;
1112            }
1113            CompiledExpressionNode::Binary { left, right, .. } => {
1114                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1115                self.write_tree(*left, f, &child_prefix, "├─ ", "│  ", expanded)?;
1116                self.write_tree(*right, f, &child_prefix, "└─ ", "   ", expanded)?;
1117            }
1118        }
1119        Ok(())
1120    }
1121}
1122
1123impl Display for CompiledExpression {
1124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1125        if self.nodes.is_empty() {
1126            return writeln!(f, "<empty>");
1127        }
1128        let mut expanded = vec![false; self.nodes.len()];
1129        self.write_tree(self.root, f, "", "", "", &mut expanded)
1130    }
1131}
1132
1133fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
1134    match op {
1135        ir::IrUnaryOp::Neg => "-".to_string(),
1136        ir::IrUnaryOp::Real => "Re".to_string(),
1137        ir::IrUnaryOp::Imag => "Im".to_string(),
1138        ir::IrUnaryOp::Conj => "*".to_string(),
1139        ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
1140        ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
1141        ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
1142        ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
1143        ir::IrUnaryOp::Exp => "Exp".to_string(),
1144        ir::IrUnaryOp::Sin => "Sin".to_string(),
1145        ir::IrUnaryOp::Cos => "Cos".to_string(),
1146        ir::IrUnaryOp::Log => "Log".to_string(),
1147        ir::IrUnaryOp::Cis => "Cis".to_string(),
1148    }
1149}
1150
1151fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
1152    match op {
1153        ir::IrBinaryOp::Add => "+".to_string(),
1154        ir::IrBinaryOp::Sub => "-".to_string(),
1155        ir::IrBinaryOp::Mul => "×".to_string(),
1156        ir::IrBinaryOp::Div => "÷".to_string(),
1157        ir::IrBinaryOp::Pow => "Pow".to_string(),
1158    }
1159}
/// How the evaluator's currently installed expression specialization was obtained.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionSpecializationOrigin {
    /// Installed while the evaluator was being constructed.
    InitialLoad,
    /// Rebuilt because no cached entry matched the current state.
    CacheMissRebuild,
    /// Restored from an already-cached entry.
    CacheHitRestore,
}
1170#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1171/// Current specialization status for the evaluator runtime.
1172pub struct ExpressionSpecializationStatus {
1173    /// How the active specialization was obtained most recently.
1174    pub origin: ExpressionSpecializationOrigin,
1175}
1176#[derive(Clone, Debug, PartialEq, Eq)]
1177/// Diagnostic snapshot of the active runtime specialization state.
1178pub struct ExpressionRuntimeDiagnostics {
1179    /// Whether IR planning state is present for the evaluator.
1180    pub ir_planning_enabled: bool,
1181    /// Whether a lowered value-only program is available for the active specialization.
1182    pub lowered_value_program_present: bool,
1183    /// Whether a lowered gradient-only program is available for the active specialization.
1184    pub lowered_gradient_program_present: bool,
1185    /// Whether a lowered fused value+gradient program is available for the active specialization.
1186    pub lowered_value_gradient_program_present: bool,
1187    /// Number of cached parameter-factor descriptors in the active specialization.
1188    pub cached_parameter_factor_count: usize,
1189    /// Number of cached parameter factors with lowered runtimes available.
1190    pub lowered_cached_parameter_factor_count: usize,
1191    /// Whether a lowered residual normalization runtime is available.
1192    pub residual_runtime_present: bool,
1193    /// Number of cached specialization entries currently retained.
1194    pub specialization_cache_entries: usize,
1195    /// Number of cached lowered-artifact entries currently retained.
1196    pub lowered_artifact_cache_entries: usize,
1197    /// Origin of the currently installed specialization, when available.
1198    pub specialization_status: Option<ExpressionSpecializationStatus>,
1199}
1200#[derive(Clone)]
1201/// IR-planning state derived from the semantic expression tree plus the current active mask.
1202///
1203/// Invariants:
1204/// - `expression_ir` is never a source of truth; it is always derived from `Evaluator::expression`.
1205/// - `cached_integrals` are specialization-dependent and must be treated as invalid once the
1206///   active mask or dataset identity changes.
1207struct ExpressionIrPlanningState {
1208    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
1209    specialization_cache:
1210        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
1211    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
1212    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
1213    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
1214    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
1215    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
1216}
1217/// Evaluator for [`Expression`] that mirrors the existing evaluator behavior.
1218#[allow(missing_docs)]
1219#[derive(Clone)]
1220pub struct Evaluator {
1221    pub amplitudes: Vec<Box<dyn Amplitude>>,
1222    amplitude_use_sites: Vec<AmplitudeUseSite>,
1223    pub resources: Arc<RwLock<Resources>>,
1224    pub dataset: Arc<Dataset>,
1225    pub expression: ExpressionNode,
1226    ir_planning: ExpressionIrPlanningState,
1227    registry: ExpressionRegistry,
1228}
1229
1230#[allow(missing_docs)]
1231impl Evaluator {
1232    /// Internal benchmarking/debug counters for specialization cache reuse.
1233    pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
1234        *self.ir_planning.specialization_metrics.read()
1235    }
1236    /// Reset specialization cache counters while leaving cached specializations intact.
1237    pub fn reset_expression_specialization_metrics(&self) {
1238        *self.ir_planning.specialization_metrics.write() =
1239            ExpressionSpecializationMetrics::default();
1240    }
1241    /// Internal benchmarking/debug metrics for staged IR compile and lowering costs.
1242    pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
1243        *self.ir_planning.compile_metrics.read()
1244    }
1245    /// Internal diagnostics surface for active runtime specialization state.
1246    pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
1247        let active_artifacts = self.active_lowered_artifacts();
1248        let cached_parameter_factor_count = self
1249            .ir_planning
1250            .cached_integrals
1251            .read()
1252            .as_ref()
1253            .map(|state| state.values.len())
1254            .unwrap_or(0);
1255        let lowered_cached_parameter_factor_count = active_artifacts
1256            .as_ref()
1257            .map(|artifacts| {
1258                artifacts
1259                    .lowered_parameter_factors
1260                    .iter()
1261                    .filter(|factor| factor.is_some())
1262                    .count()
1263            })
1264            .unwrap_or(0);
1265        let residual_runtime_present = active_artifacts
1266            .as_ref()
1267            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
1268            .is_some();
1269        ExpressionRuntimeDiagnostics {
1270            ir_planning_enabled: true,
1271            lowered_value_program_present: true,
1272            lowered_gradient_program_present: true,
1273            lowered_value_gradient_program_present: true,
1274            cached_parameter_factor_count,
1275            lowered_cached_parameter_factor_count,
1276            residual_runtime_present,
1277            specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
1278            lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
1279            specialization_status: *self.ir_planning.specialization_status.read(),
1280        }
1281    }
1282    /// Reset post-load compile/lowering counters while preserving initial-load metrics.
1283    pub fn reset_expression_compile_metrics(&self) {
1284        let mut metrics = self.ir_planning.compile_metrics.write();
1285        metrics.specialization_cache_hits = 0;
1286        metrics.specialization_cache_misses = 0;
1287        metrics.specialization_ir_compile_nanos = 0;
1288        metrics.specialization_cached_integrals_nanos = 0;
1289        metrics.specialization_lowering_nanos = 0;
1290        metrics.specialization_lowering_cache_hits = 0;
1291        metrics.specialization_lowering_cache_misses = 0;
1292        metrics.specialization_cache_restore_nanos = 0;
1293    }
1294    #[cfg(test)]
1295    fn expression_ir(&self) -> ir::ExpressionIR {
1296        self.ir_planning
1297            .cached_integrals
1298            .read()
1299            .as_ref()
1300            .map(|state| state.expression_ir.clone())
1301            .expect("cached integral state should exist for evaluator IR access")
1302    }
1303    fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
1304        self.active_lowered_artifacts()
1305            .expect("active lowered artifacts should exist for the current specialization")
1306            .lowered_runtime
1307            .clone()
1308    }
1309    fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
1310        self.ir_planning.active_lowered_artifacts.read().clone()
1311    }
1312    fn lowered_runtime_slot_count(&self) -> usize {
1313        let runtime = self.lowered_runtime();
1314        [
1315            runtime.value_program().scratch_slots(),
1316            runtime.gradient_program().scratch_slots(),
1317            runtime.value_gradient_program().scratch_slots(),
1318        ]
1319        .into_iter()
1320        .max()
1321        .unwrap_or(0)
1322    }
1323    fn lowered_value_runtime_slot_count(&self) -> usize {
1324        self.lowered_runtime().value_program().scratch_slots()
1325    }
1326
1327    #[doc(hidden)]
1328    pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
1329        ExpressionValueProgramSnapshot {
1330            lowered_program: self.lowered_runtime().value_program().clone(),
1331        }
1332    }
1333
1334    #[doc(hidden)]
1335    pub fn expression_value_program_snapshot_for_active_mask(
1336        &self,
1337        active_mask: &[bool],
1338    ) -> LadduResult<ExpressionValueProgramSnapshot> {
1339        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1340        let lowered_program =
1341            lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
1342                LadduError::Custom(format!(
1343                    "Failed to lower value-only active-mask runtime: {error:?}"
1344                ))
1345            })?;
1346        Ok(ExpressionValueProgramSnapshot { lowered_program })
1347    }
1348
1349    #[doc(hidden)]
1350    pub fn expression_value_program_snapshot_slot_count(
1351        &self,
1352        snapshot: &ExpressionValueProgramSnapshot,
1353    ) -> usize {
1354        let _ = self;
1355        snapshot.lowered_program.scratch_slots()
1356    }
1357
1358    /// Returns a tree-like diagnostic snapshot of the compiled expression for the evaluator's
1359    /// current active-amplitude mask.
1360    pub fn compiled_expression(&self) -> CompiledExpression {
1361        let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
1362        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
1363    }
1364
1365    /// Returns the expression represented by this evaluator.
1366    pub fn expression(&self) -> Expression {
1367        Expression {
1368            tree: self.expression.clone(),
1369            registry: self.registry.clone(),
1370        }
1371    }
1372    fn lowered_gradient_runtime_slot_count(&self) -> usize {
1373        self.lowered_runtime().gradient_program().scratch_slots()
1374    }
1375    fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
1376        self.lowered_runtime()
1377            .value_gradient_program()
1378            .scratch_slots()
1379    }
1380
1381    fn expression_value_slot_count(&self) -> usize {
1382        self.lowered_value_runtime_slot_count()
1383    }
1384    fn expression_gradient_slot_count(&self) -> usize {
1385        self.lowered_gradient_runtime_slot_count()
1386    }
1387    fn expression_value_gradient_slot_count(&self) -> usize {
1388        self.lowered_value_gradient_runtime_slot_count()
1389    }
1390
1391    #[doc(hidden)]
1392    pub fn expression_value_gradient_slot_count_public(&self) -> usize {
1393        self.expression_value_gradient_slot_count()
1394    }
1395    #[cfg(test)]
1396    fn specialization_cache_len(&self) -> usize {
1397        self.ir_planning.specialization_cache.read().len()
1398    }
1399    #[cfg(test)]
1400    fn lowered_artifact_cache_len(&self) -> usize {
1401        self.ir_planning.lowered_artifact_cache.read().len()
1402    }
1403    fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
1404        debug_assert!(Self::lowered_artifact_signature_matches(
1405            &specialization.lowered_artifacts,
1406            &specialization.cached_integrals.values,
1407        ));
1408        *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
1409        *self.ir_planning.active_lowered_artifacts.write() =
1410            Some(specialization.lowered_artifacts.clone());
1411        debug_assert_eq!(
1412            self.active_lowered_artifacts()
1413                .as_ref()
1414                .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
1415            Some(true)
1416        );
1417        debug_assert_eq!(
1418            self.lowered_runtime().value_program().scratch_slots(),
1419            specialization
1420                .lowered_artifacts
1421                .lowered_runtime
1422                .value_program()
1423                .scratch_slots()
1424        );
1425    }
1426    fn lower_expression_runtime_artifacts(
1427        expression_ir: &ir::ExpressionIR,
1428        values: &[PrecomputedCachedIntegral],
1429    ) -> LadduResult<LoweredArtifactCacheState> {
1430        let parameter_node_indices = values
1431            .iter()
1432            .map(|value| value.parameter_node_index)
1433            .collect();
1434        let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
1435        let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
1436        let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
1437        let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
1438            expression_ir,
1439        )
1440        .map_err(|error| {
1441            LadduError::Custom(format!(
1442                "Failed to lower expression runtime for specialized IR: {error:?}"
1443            ))
1444        })?;
1445        Ok(LoweredArtifactCacheState {
1446            parameter_node_indices,
1447            mul_node_indices,
1448            lowered_parameter_factors,
1449            residual_runtime,
1450            lowered_runtime,
1451        })
1452    }
1453    fn lowered_artifact_signature_matches(
1454        artifacts: &LoweredArtifactCacheState,
1455        values: &[PrecomputedCachedIntegral],
1456    ) -> bool {
1457        artifacts.parameter_node_indices.len() == values.len()
1458            && artifacts.mul_node_indices.len() == values.len()
1459            && artifacts
1460                .parameter_node_indices
1461                .iter()
1462                .copied()
1463                .eq(values.iter().map(|value| value.parameter_node_index))
1464            && artifacts
1465                .mul_node_indices
1466                .iter()
1467                .copied()
1468                .eq(values.iter().map(|value| value.mul_node_index))
1469    }
1470    fn build_expression_specialization(
1471        &self,
1472        resources: &Resources,
1473        key: CachedIntegralCacheKey,
1474    ) -> LadduResult<ExpressionSpecializationState> {
1475        let ir_compile_start = Instant::now();
1476        let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
1477        let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
1478        let cached_integrals_start = Instant::now();
1479        let values = Self::precompute_cached_integrals_at_load(
1480            &expression_ir,
1481            &self.amplitudes,
1482            &self.amplitude_use_sites,
1483            resources,
1484            &self.dataset,
1485            self.resources.read().n_free_parameters(),
1486        )?;
1487        let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
1488        let execution_sets = expression_ir.normalization_execution_sets().clone();
1489        let active_mask_key = resources.active.clone();
1490        let cached_lowered_artifacts = {
1491            let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
1492            lowered_artifact_cache
1493                .get(&active_mask_key)
1494                .cloned()
1495                .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
1496        };
1497        let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
1498            self.ir_planning
1499                .compile_metrics
1500                .write()
1501                .specialization_lowering_cache_hits += 1;
1502            artifacts
1503        } else {
1504            let lowering_start = Instant::now();
1505            let artifacts = Arc::new(
1506                Self::lower_expression_runtime_artifacts(&expression_ir, &values)
1507                    .expect("specialized lowered runtime should build"),
1508            );
1509            let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
1510            self.ir_planning
1511                .lowered_artifact_cache
1512                .write()
1513                .insert(active_mask_key, artifacts.clone());
1514            let mut compile_metrics = self.ir_planning.compile_metrics.write();
1515            compile_metrics.specialization_lowering_cache_misses += 1;
1516            compile_metrics.specialization_lowering_nanos += lowering_nanos;
1517            artifacts
1518        };
1519        let mut compile_metrics = self.ir_planning.compile_metrics.write();
1520        compile_metrics.specialization_cache_misses += 1;
1521        compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
1522        compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
1523        Ok(ExpressionSpecializationState {
1524            cached_integrals: Arc::new(CachedIntegralCacheState {
1525                key,
1526                expression_ir,
1527                values,
1528                execution_sets,
1529            }),
1530            lowered_artifacts,
1531        })
1532    }
1533    fn ensure_expression_specialization(
1534        &self,
1535        resources: &Resources,
1536    ) -> LadduResult<ExpressionSpecializationState> {
1537        let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
1538        if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
1539            if state.key == key {
1540                return Ok(ExpressionSpecializationState {
1541                    cached_integrals: state.clone(),
1542                    lowered_artifacts: self
1543                        .active_lowered_artifacts()
1544                        .expect("active lowered artifacts should exist for cached specialization"),
1545                });
1546            }
1547        }
1548        let cached_specialization = {
1549            let specialization_cache = self.ir_planning.specialization_cache.read();
1550            specialization_cache.get(&key).cloned()
1551        };
1552        if let Some(specialization) = cached_specialization {
1553            let restore_start = Instant::now();
1554            self.ir_planning.specialization_metrics.write().cache_hits += 1;
1555            self.install_expression_specialization(&specialization);
1556            *self.ir_planning.specialization_status.write() =
1557                Some(ExpressionSpecializationStatus {
1558                    origin: ExpressionSpecializationOrigin::CacheHitRestore,
1559                });
1560            let restore_nanos = restore_start.elapsed().as_nanos() as u64;
1561            let mut compile_metrics = self.ir_planning.compile_metrics.write();
1562            compile_metrics.specialization_cache_hits += 1;
1563            compile_metrics.specialization_cache_restore_nanos += restore_nanos;
1564            return Ok(specialization);
1565        }
1566        let specialization = self.build_expression_specialization(resources, key.clone())?;
1567        self.ir_planning.specialization_metrics.write().cache_misses += 1;
1568        self.ir_planning
1569            .specialization_cache
1570            .write()
1571            .insert(key, specialization.clone());
1572        self.install_expression_specialization(&specialization);
1573        let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
1574            ExpressionSpecializationOrigin::InitialLoad
1575        } else {
1576            ExpressionSpecializationOrigin::CacheMissRebuild
1577        };
1578        *self.ir_planning.specialization_status.write() =
1579            Some(ExpressionSpecializationStatus { origin });
1580        Ok(specialization)
1581    }
1582    fn rebuild_runtime_specializations(&self, resources: &Resources) {
1583        let _ = self.ensure_expression_specialization(resources);
1584    }
1585    fn refresh_runtime_specializations(&self) {
1586        let resources = self.resources.read();
1587        self.rebuild_runtime_specializations(&resources);
1588    }
1589    fn cached_integral_cache_key(
1590        active_mask: Vec<bool>,
1591        dataset: &Dataset,
1592    ) -> CachedIntegralCacheKey {
1593        let (weights_ptr, weights_local_len) = dataset.local_weight_cache_key();
1594        CachedIntegralCacheKey {
1595            active_mask,
1596            n_events_local: dataset.n_events_local(),
1597            weights_local_len,
1598            weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
1599            weights_ptr,
1600        }
1601    }
1602    fn precompute_cached_integrals_at_load(
1603        expression_ir: &ir::ExpressionIR,
1604        amplitudes: &[Box<dyn Amplitude>],
1605        amplitude_use_sites: &[AmplitudeUseSite],
1606        resources: &Resources,
1607        dataset: &Dataset,
1608        n_free_parameters: usize,
1609    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
1610        let descriptors = expression_ir.cached_integral_descriptors();
1611        if descriptors.is_empty() {
1612            return Ok(Vec::new());
1613        }
1614        let execution_sets = expression_ir.normalization_execution_sets();
1615        let seed_parameters = vec![0.0; n_free_parameters];
1616        let parameters = resources.parameter_map.assemble(&seed_parameters)?;
1617        let mut amplitude_values = vec![Complex64::ZERO; amplitude_use_sites.len()];
1618        let mut compute_values = vec![Complex64::ZERO; amplitudes.len()];
1619        let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
1620        let active_set = resources.active_indices();
1621        let cache_active_indices = execution_sets
1622            .cached_cache_amplitudes
1623            .iter()
1624            .copied()
1625            .filter(|index| active_set.binary_search(index).is_ok())
1626            .collect::<Vec<_>>();
1627        let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
1628        for (cache, event) in resources.caches.iter().zip(dataset.weights_local().iter()) {
1629            amplitude_values.fill(Complex64::ZERO);
1630            compute_values.fill(Complex64::ZERO);
1631            let mut computed = vec![false; amplitudes.len()];
1632            for &use_site_idx in &cache_active_indices {
1633                let amp_idx = amplitude_use_sites[use_site_idx].amplitude_index;
1634                if !computed[amp_idx] {
1635                    compute_values[amp_idx] = amplitudes[amp_idx].compute(&parameters, cache);
1636                    computed[amp_idx] = true;
1637                }
1638                amplitude_values[use_site_idx] = compute_values[amp_idx];
1639            }
1640            expression_ir.evaluate_into(&amplitude_values, &mut value_slots);
1641            let weight = *event;
1642            for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
1643                weighted_cache_sums[descriptor_index] +=
1644                    value_slots[descriptor.cache_node_index] * weight;
1645            }
1646        }
1647        Ok(descriptors
1648            .iter()
1649            .zip(weighted_cache_sums)
1650            .map(
1651                |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
1652                    mul_node_index: descriptor.mul_node_index,
1653                    parameter_node_index: descriptor.parameter_node_index,
1654                    cache_node_index: descriptor.cache_node_index,
1655                    coefficient: descriptor.coefficient,
1656                    weighted_cache_sum,
1657                },
1658            )
1659            .collect())
1660    }
1661    fn lower_cached_parameter_factors(
1662        expression_ir: &ir::ExpressionIR,
1663    ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
1664        expression_ir
1665            .cached_integral_descriptors()
1666            .iter()
1667            .map(|descriptor| {
1668                lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
1669                    expression_ir,
1670                    descriptor.parameter_node_index,
1671                )
1672                .ok()
1673            })
1674            .collect()
1675    }
1676    fn lower_residual_runtime(
1677        expression_ir: &ir::ExpressionIR,
1678        descriptors: &[PrecomputedCachedIntegral],
1679    ) -> Option<lowered::LoweredExpressionRuntime> {
1680        let mut zeroed_nodes = vec![false; expression_ir.node_count()];
1681        for descriptor in descriptors {
1682            if descriptor.mul_node_index < zeroed_nodes.len() {
1683                zeroed_nodes[descriptor.mul_node_index] = true;
1684            }
1685        }
1686        lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
1687            expression_ir,
1688            &zeroed_nodes,
1689        )
1690        .ok()
1691    }
1692
1693    /// Number of slots required for buffers passed to expression amplitude evaluation helpers.
1694    ///
1695    /// This count is the number of amplitude use-sites in the expression, not the number of
1696    /// unique registered amplitude computations. Reused amplitudes therefore have one slot per
1697    /// expression use-site, and all expression value/gradient scratch buffers should be sized
1698    /// with this value.
1699    pub fn amplitude_value_slot_count(&self) -> usize {
1700        self.amplitude_use_sites.len()
1701    }
1702
1703    /// Fill amplitude values for one event in expression use-site order.
1704    ///
1705    /// `amplitude_values` must have length [`Evaluator::amplitude_value_slot_count`].
1706    /// `active_indices` must contain active amplitude use-site indices, such as those returned
1707    /// by [`Resources::active_indices`](crate::resources::Resources::active_indices) or derived
1708    /// from an active mask. The buffer is zeroed before active values are written. If multiple
1709    /// use-sites refer to the same registered amplitude computation, that amplitude is computed
1710    /// once and copied into each active use-site slot.
1711    #[inline]
1712    pub fn fill_amplitude_values(
1713        &self,
1714        amplitude_values: &mut [Complex64],
1715        active_indices: &[usize],
1716        parameters: &Parameters,
1717        cache: &Cache,
1718    ) {
1719        amplitude_values.fill(Complex64::ZERO);
1720        let mut compute_values = vec![Complex64::ZERO; self.amplitudes.len()];
1721        let mut computed = vec![false; self.amplitudes.len()];
1722        for &use_site_idx in active_indices {
1723            let amp_idx = self.amplitude_use_sites[use_site_idx].amplitude_index;
1724            if !computed[amp_idx] {
1725                compute_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
1726                computed[amp_idx] = true;
1727            }
1728            amplitude_values[use_site_idx] = compute_values[amp_idx];
1729        }
1730    }
1731
1732    #[inline]
1733    fn fill_amplitude_gradients(
1734        &self,
1735        gradient_values: &mut [DVector<Complex64>],
1736        active_mask: &[bool],
1737        parameters: &Parameters,
1738        cache: &Cache,
1739    ) {
1740        let mut compute_gradients = vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1741        let mut computed = vec![false; self.amplitudes.len()];
1742        for ((use_site, active), grad) in self
1743            .amplitude_use_sites
1744            .iter()
1745            .zip(active_mask.iter())
1746            .zip(gradient_values.iter_mut())
1747        {
1748            grad.fill(Complex64::ZERO);
1749            if *active {
1750                let amp_idx = use_site.amplitude_index;
1751                if !computed[amp_idx] {
1752                    self.amplitudes[amp_idx].compute_gradient(
1753                        parameters,
1754                        cache,
1755                        &mut compute_gradients[amp_idx],
1756                    );
1757                    computed[amp_idx] = true;
1758                }
1759                grad.copy_from(&compute_gradients[amp_idx]);
1760            }
1761        }
1762    }
1763
1764    /// Fill amplitude values and parameter gradients for one event in expression use-site order.
1765    ///
1766    /// `amplitude_values` and `gradient_values` must both have length
1767    /// [`Evaluator::amplitude_value_slot_count`]. `active_indices` and `active_mask` are both
1768    /// expressed in amplitude use-site space. Values and gradients for inactive use-sites are
1769    /// zeroed, while reused amplitude computations are evaluated once and copied into each active
1770    /// use-site slot that references them.
1771    #[inline]
1772    pub fn fill_amplitude_values_and_gradients(
1773        &self,
1774        amplitude_values: &mut [Complex64],
1775        gradient_values: &mut [DVector<Complex64>],
1776        active_indices: &[usize],
1777        active_mask: &[bool],
1778        parameters: &Parameters,
1779        cache: &Cache,
1780    ) {
1781        self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
1782        self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
1783    }
1784
1785    #[cfg(feature = "execution-context-prototype")]
1786    #[inline]
1787    fn evaluate_cache_gradient_with_scratch(
1788        &self,
1789        amplitude_values: &mut [Complex64],
1790        gradient_values: &mut [DVector<Complex64>],
1791        value_slots: &mut [Complex64],
1792        gradient_slots: &mut [DVector<Complex64>],
1793        active_indices: &[usize],
1794        active_mask: &[bool],
1795        parameters: &Parameters,
1796        cache: &Cache,
1797    ) -> DVector<Complex64> {
1798        self.fill_amplitude_values_and_gradients(
1799            amplitude_values,
1800            gradient_values,
1801            active_indices,
1802            active_mask,
1803            parameters,
1804            cache,
1805        );
1806        self.evaluate_expression_gradient_with_scratch(
1807            amplitude_values,
1808            gradient_values,
1809            value_slots,
1810            gradient_slots,
1811        )
1812    }
1813
1814    #[cfg(feature = "execution-context-prototype")]
1815    #[allow(dead_code)]
1816    #[inline]
1817    fn evaluate_cache_value_gradient_with_scratch(
1818        &self,
1819        amplitude_values: &mut [Complex64],
1820        gradient_values: &mut [DVector<Complex64>],
1821        value_slots: &mut [Complex64],
1822        gradient_slots: &mut [DVector<Complex64>],
1823        active_indices: &[usize],
1824        active_mask: &[bool],
1825        parameters: &Parameters,
1826        cache: &Cache,
1827    ) -> (Complex64, DVector<Complex64>) {
1828        self.fill_amplitude_values_and_gradients(
1829            amplitude_values,
1830            gradient_values,
1831            active_indices,
1832            active_mask,
1833            parameters,
1834            cache,
1835        );
1836        self.evaluate_expression_value_gradient_with_scratch(
1837            amplitude_values,
1838            gradient_values,
1839            value_slots,
1840            gradient_slots,
1841        )
1842    }
1843
1844    pub fn expression_slot_count(&self) -> usize {
1845        self.lowered_runtime_slot_count()
1846    }
1847    fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
1848        let amplitude_dependencies = self
1849            .amplitude_use_sites
1850            .iter()
1851            .map(|use_site| {
1852                ir::DependenceClass::from(
1853                    self.amplitudes[use_site.amplitude_index].dependence_hint(),
1854                )
1855            })
1856            .collect::<Vec<_>>();
1857        let amplitude_realness = self
1858            .amplitude_use_sites
1859            .iter()
1860            .map(|use_site| self.amplitudes[use_site.amplitude_index].real_valued_hint())
1861            .collect::<Vec<_>>();
1862        ir::compile_expression_ir_with_real_hints(
1863            &self.expression,
1864            active_mask,
1865            &amplitude_dependencies,
1866            &amplitude_realness,
1867        )
1868    }
1869    fn lower_expression_runtime_for_active_mask(
1870        &self,
1871        active_mask: &[bool],
1872    ) -> LadduResult<lowered::LoweredExpressionRuntime> {
1873        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1874        lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
1875            LadduError::Custom(format!(
1876                "Failed to lower active-mask runtime specialization: {error:?}"
1877            ))
1878        })
1879    }
1880    fn ensure_cached_integral_cache_state(
1881        &self,
1882        resources: &Resources,
1883    ) -> LadduResult<Arc<CachedIntegralCacheState>> {
1884        Ok(self
1885            .ensure_expression_specialization(resources)?
1886            .cached_integrals)
1887    }
1888
1889    fn evaluate_expression_runtime_value_with_scratch(
1890        &self,
1891        amplitude_values: &[Complex64],
1892        scratch: &mut [Complex64],
1893    ) -> Complex64 {
1894        let lowered_runtime = self.lowered_runtime();
1895        lowered_runtime
1896            .value_program()
1897            .evaluate_into(amplitude_values, scratch)
1898    }
1899
1900    #[doc(hidden)]
1901    pub fn evaluate_expression_value_with_program_snapshot(
1902        &self,
1903        program_snapshot: &ExpressionValueProgramSnapshot,
1904        amplitude_values: &[Complex64],
1905        scratch: &mut [Complex64],
1906    ) -> Complex64 {
1907        program_snapshot
1908            .lowered_program
1909            .evaluate_into(amplitude_values, scratch)
1910    }
1911
1912    fn evaluate_expression_runtime_gradient_with_scratch(
1913        &self,
1914        amplitude_values: &[Complex64],
1915        gradient_values: &[DVector<Complex64>],
1916        value_scratch: &mut [Complex64],
1917        gradient_scratch: &mut [DVector<Complex64>],
1918    ) -> DVector<Complex64> {
1919        let lowered_runtime = self.lowered_runtime();
1920        lowered_runtime.gradient_program().evaluate_gradient_into(
1921            amplitude_values,
1922            gradient_values,
1923            value_scratch,
1924            gradient_scratch,
1925        )
1926    }
1927
1928    fn evaluate_expression_runtime_value_gradient_with_scratch(
1929        &self,
1930        amplitude_values: &[Complex64],
1931        gradient_values: &[DVector<Complex64>],
1932        value_scratch: &mut [Complex64],
1933        gradient_scratch: &mut [DVector<Complex64>],
1934    ) -> (Complex64, DVector<Complex64>) {
1935        let lowered_runtime = self.lowered_runtime();
1936        lowered_runtime
1937            .value_gradient_program()
1938            .evaluate_value_gradient_into(
1939                amplitude_values,
1940                gradient_values,
1941                value_scratch,
1942                gradient_scratch,
1943            )
1944    }
1945
1946    fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
1947        let lowered_runtime = self.lowered_runtime();
1948        let program = lowered_runtime.value_program();
1949        let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
1950        program.evaluate_into(amplitude_values, &mut scratch)
1951    }
1952
1953    fn evaluate_expression_runtime_gradient(
1954        &self,
1955        amplitude_values: &[Complex64],
1956        gradient_values: &[DVector<Complex64>],
1957    ) -> DVector<Complex64> {
1958        let lowered_runtime = self.lowered_runtime();
1959        let program = lowered_runtime.gradient_program();
1960        let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
1961        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
1962        let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
1963        program.evaluate_gradient_into_flat(
1964            amplitude_values,
1965            gradient_values,
1966            &mut value_scratch,
1967            &mut gradient_scratch,
1968            grad_dim,
1969        )
1970    }
1971    /// Dependence classification for the compiled expression root.
1972    pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
1973        let resources = self.resources.read();
1974        Ok(self
1975            .ensure_cached_integral_cache_state(&resources)?
1976            .expression_ir
1977            .root_dependence()
1978            .into())
1979    }
1980    /// Dependence classification for each compiled expression node.
1981    pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
1982        let resources = self.resources.read();
1983        Ok(self
1984            .ensure_cached_integral_cache_state(&resources)?
1985            .expression_ir
1986            .node_dependence_annotations()
1987            .iter()
1988            .copied()
1989            .map(Into::into)
1990            .collect())
1991    }
1992    /// Warning-level diagnostics for potentially inconsistent dependence hints.
1993    pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
1994        let resources = self.resources.read();
1995        Ok(self
1996            .ensure_cached_integral_cache_state(&resources)?
1997            .expression_ir
1998            .dependence_warnings()
1999            .to_vec())
2000    }
2001    /// Explain/debug view of IR normalization planning decomposition.
2002    pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
2003        let resources = self.resources.read();
2004        Ok(self
2005            .ensure_cached_integral_cache_state(&resources)?
2006            .expression_ir
2007            .normalization_plan_explain()
2008            .into())
2009    }
2010    /// Explain/debug view of amplitude execution sets used by normalization evaluation.
2011    pub fn expression_normalization_execution_sets(
2012        &self,
2013    ) -> LadduResult<NormalizationExecutionSetsExplain> {
2014        let resources = self.resources.read();
2015        Ok(self
2016            .ensure_cached_integral_cache_state(&resources)?
2017            .execution_sets
2018            .clone()
2019            .into())
2020    }
2021    /// Cached integral terms precomputed at evaluator load.
2022    pub fn expression_precomputed_cached_integrals(
2023        &self,
2024    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2025        let resources = self.resources.read();
2026        Ok(self
2027            .ensure_cached_integral_cache_state(&resources)?
2028            .values
2029            .clone())
2030    }
2031    /// Derivative rules for cached separable terms evaluated at the given parameter point.
2032    ///
2033    /// Each returned term corresponds to a cached separable descriptor and contributes
2034    /// `weighted_gradient` to `d(normalization)/dp` prior to residual-term combination.
2035    pub fn expression_precomputed_cached_integral_gradient_terms(
2036        &self,
2037        parameters: &[f64],
2038    ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
2039        let resources = self.resources.read();
2040        let state = self.ensure_cached_integral_cache_state(&resources)?;
2041        if state.values.is_empty() {
2042            return Ok(Vec::new());
2043        }
2044
2045        let Some(cache) = resources.caches.first() else {
2046            return Ok(state
2047                .values
2048                .iter()
2049                .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
2050                    mul_node_index: descriptor.mul_node_index,
2051                    parameter_node_index: descriptor.parameter_node_index,
2052                    cache_node_index: descriptor.cache_node_index,
2053                    coefficient: descriptor.coefficient,
2054                    weighted_gradient: DVector::zeros(parameters.len()),
2055                })
2056                .collect());
2057        };
2058
2059        let parameter_values = resources.parameter_map.assemble(parameters)?;
2060        let mut amplitude_values = vec![Complex64::ZERO; self.amplitude_use_sites.len()];
2061        self.fill_amplitude_values(
2062            &mut amplitude_values,
2063            resources.active_indices(),
2064            &parameter_values,
2065            cache,
2066        );
2067        let mut amplitude_gradients = (0..self.amplitude_use_sites.len())
2068            .map(|_| DVector::zeros(parameters.len()))
2069            .collect::<Vec<_>>();
2070        self.fill_amplitude_gradients(
2071            &mut amplitude_gradients,
2072            &resources.active,
2073            &parameter_values,
2074            cache,
2075        );
2076        let lowered_artifacts = self.active_lowered_artifacts();
2077        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2078        let mut gradient_slots = (0..state.expression_ir.node_count())
2079            .map(|_| DVector::zeros(parameters.len()))
2080            .collect::<Vec<_>>();
2081        let max_lowered_slots = lowered_artifacts
2082            .as_ref()
2083            .map(|artifacts| {
2084                artifacts
2085                    .lowered_parameter_factors
2086                    .iter()
2087                    .filter_map(|runtime| {
2088                        runtime
2089                            .as_ref()
2090                            .and_then(|runtime| runtime.gradient_program())
2091                            .map(|program| program.scratch_slots())
2092                    })
2093                    .max()
2094                    .unwrap_or(0)
2095            })
2096            .unwrap_or(0);
2097        let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
2098        let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
2099        let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
2100            artifacts.lowered_parameter_factors.len() == state.values.len()
2101                && artifacts.lowered_parameter_factors.iter().all(|runtime| {
2102                    runtime
2103                        .as_ref()
2104                        .and_then(|runtime| runtime.gradient_program())
2105                        .is_some()
2106                })
2107        });
2108
2109        if !use_lowered {
2110            let _ = state.expression_ir.evaluate_gradient_into(
2111                &amplitude_values,
2112                &amplitude_gradients,
2113                &mut value_slots,
2114                &mut gradient_slots,
2115            );
2116        }
2117
2118        if use_lowered {
2119            let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
2120            Ok(state
2121                .values
2122                .iter()
2123                .cloned()
2124                .zip(lowered_artifacts.lowered_parameter_factors.iter())
2125                .map(|(descriptor, runtime)| {
2126                    let parameter_gradient = runtime
2127                        .as_ref()
2128                        .and_then(|runtime| runtime.gradient_program())
2129                        .map(|program| {
2130                            program.evaluate_gradient_into(
2131                                &amplitude_values,
2132                                &amplitude_gradients,
2133                                &mut lowered_value_slots[..program.scratch_slots()],
2134                                &mut lowered_gradient_slots[..program.scratch_slots()],
2135                            )
2136                        })
2137                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
2138                    let weighted_gradient = parameter_gradient.map(|value| {
2139                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2140                    });
2141                    PrecomputedCachedIntegralGradientTerm {
2142                        mul_node_index: descriptor.mul_node_index,
2143                        parameter_node_index: descriptor.parameter_node_index,
2144                        cache_node_index: descriptor.cache_node_index,
2145                        coefficient: descriptor.coefficient,
2146                        weighted_gradient,
2147                    }
2148                })
2149                .collect())
2150        } else {
2151            Ok(state
2152                .values
2153                .iter()
2154                .map(|descriptor| {
2155                    let parameter_gradient = gradient_slots
2156                        .get(descriptor.parameter_node_index)
2157                        .cloned()
2158                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
2159                    let weighted_gradient = parameter_gradient.map(|value| {
2160                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2161                    });
2162                    PrecomputedCachedIntegralGradientTerm {
2163                        mul_node_index: descriptor.mul_node_index,
2164                        parameter_node_index: descriptor.parameter_node_index,
2165                        cache_node_index: descriptor.cache_node_index,
2166                        coefficient: descriptor.coefficient,
2167                        weighted_gradient,
2168                    }
2169                })
2170                .collect())
2171        }
2172    }
2173    fn evaluate_cached_weighted_value_sum_ir(
2174        &self,
2175        state: &CachedIntegralCacheState,
2176        amplitude_values: &[Complex64],
2177    ) -> f64 {
2178        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2179        let _ = state
2180            .expression_ir
2181            .evaluate_into(amplitude_values, &mut value_slots);
2182        state
2183            .values
2184            .iter()
2185            .map(|descriptor| {
2186                let parameter_factor = value_slots[descriptor.parameter_node_index];
2187                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2188                    .re
2189            })
2190            .sum()
2191    }
2192    fn evaluate_cached_weighted_value_sum_lowered(
2193        &self,
2194        state: &CachedIntegralCacheState,
2195        lowered_artifacts: &LoweredArtifactCacheState,
2196        amplitude_values: &[Complex64],
2197    ) -> Option<f64> {
2198        let max_slots = lowered_artifacts
2199            .lowered_parameter_factors
2200            .iter()
2201            .filter_map(|runtime| {
2202                runtime
2203                    .as_ref()
2204                    .and_then(|runtime| runtime.value_program())
2205                    .map(|program| program.scratch_slots())
2206            })
2207            .max()
2208            .unwrap_or(0);
2209        let mut value_slots = vec![Complex64::ZERO; max_slots];
2210        let mut total = 0.0;
2211        for (descriptor, runtime) in state
2212            .values
2213            .iter()
2214            .zip(lowered_artifacts.lowered_parameter_factors.iter())
2215        {
2216            let parameter_factor = runtime
2217                .as_ref()
2218                .and_then(|runtime| runtime.value_program())
2219                .map(|program| {
2220                    program.evaluate_into(
2221                        amplitude_values,
2222                        &mut value_slots[..program.scratch_slots()],
2223                    )
2224                })?;
2225            total +=
2226                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2227                    .re;
2228        }
2229        Some(total)
2230    }
2231    fn evaluate_cached_weighted_gradient_sum_ir(
2232        &self,
2233        state: &CachedIntegralCacheState,
2234        amplitude_values: &[Complex64],
2235        amplitude_gradients: &[DVector<Complex64>],
2236        grad_dim: usize,
2237    ) -> DVector<f64> {
2238        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2239        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2240        let _ = state.expression_ir.evaluate_gradient_into(
2241            amplitude_values,
2242            amplitude_gradients,
2243            &mut value_slots,
2244            &mut gradient_slots,
2245        );
2246        state
2247            .values
2248            .iter()
2249            .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
2250                let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
2251                let coefficient = descriptor.coefficient as f64;
2252                for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
2253                    *accum_item +=
2254                        (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2255                }
2256                accum
2257            })
2258    }
2259    fn evaluate_cached_weighted_gradient_sum_lowered(
2260        &self,
2261        state: &CachedIntegralCacheState,
2262        lowered_artifacts: &LoweredArtifactCacheState,
2263        amplitude_values: &[Complex64],
2264        amplitude_gradients: &[DVector<Complex64>],
2265        grad_dim: usize,
2266    ) -> Option<DVector<f64>> {
2267        let max_value_slots = lowered_artifacts
2268            .lowered_parameter_factors
2269            .iter()
2270            .filter_map(|runtime| {
2271                runtime
2272                    .as_ref()
2273                    .and_then(|runtime| runtime.gradient_program())
2274                    .map(|program| program.scratch_slots())
2275            })
2276            .max()
2277            .unwrap_or(0);
2278        let mut value_slots = vec![Complex64::ZERO; max_value_slots];
2279        let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
2280        let mut total = DVector::zeros(grad_dim);
2281        for (descriptor, runtime) in state
2282            .values
2283            .iter()
2284            .zip(lowered_artifacts.lowered_parameter_factors.iter())
2285        {
2286            let parameter_gradient = runtime
2287                .as_ref()
2288                .and_then(|runtime| runtime.gradient_program())
2289                .map(|program| {
2290                    program.evaluate_gradient_into_flat(
2291                        amplitude_values,
2292                        amplitude_gradients,
2293                        &mut value_slots[..program.scratch_slots()],
2294                        &mut gradient_slots[..program.scratch_slots() * grad_dim],
2295                        grad_dim,
2296                    )
2297                })?;
2298            let coefficient = descriptor.coefficient as f64;
2299            for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
2300                *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2301            }
2302        }
2303        Some(total)
2304    }
2305    fn evaluate_residual_value_ir(
2306        &self,
2307        state: &CachedIntegralCacheState,
2308        amplitude_values: &[Complex64],
2309    ) -> Complex64 {
2310        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2311        for descriptor in &state.values {
2312            if descriptor.mul_node_index < zeroed_nodes.len() {
2313                zeroed_nodes[descriptor.mul_node_index] = true;
2314            }
2315        }
2316        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2317        state.expression_ir.evaluate_into_with_zeroed_nodes(
2318            amplitude_values,
2319            &mut value_slots,
2320            &zeroed_nodes,
2321        )
2322    }
2323    fn evaluate_residual_gradient_ir(
2324        &self,
2325        state: &CachedIntegralCacheState,
2326        amplitude_values: &[Complex64],
2327        amplitude_gradients: &[DVector<Complex64>],
2328        grad_dim: usize,
2329    ) -> DVector<Complex64> {
2330        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2331        for descriptor in &state.values {
2332            if descriptor.mul_node_index < zeroed_nodes.len() {
2333                zeroed_nodes[descriptor.mul_node_index] = true;
2334            }
2335        }
2336        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2337        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2338        state
2339            .expression_ir
2340            .evaluate_gradient_into_with_zeroed_nodes(
2341                amplitude_values,
2342                amplitude_gradients,
2343                &mut value_slots,
2344                &mut gradient_slots,
2345                &zeroed_nodes,
2346            )
2347    }
2348
2349    fn evaluate_weighted_value_sum_local_components(
2350        &self,
2351        parameters: &[f64],
2352    ) -> LadduResult<(f64, f64)> {
2353        let resources = self.resources.read();
2354        let parameters = resources.parameter_map.assemble(parameters)?;
2355        let amplitude_len = self.amplitude_use_sites.len();
2356        let state = self.ensure_cached_integral_cache_state(&resources)?;
2357        let lowered_artifacts = self.active_lowered_artifacts();
2358        let residual_value_slot_count = lowered_artifacts
2359            .as_ref()
2360            .and_then(|artifacts| {
2361                artifacts
2362                    .residual_runtime
2363                    .as_ref()
2364                    .map(|runtime| runtime.value_program())
2365                    .map(|program| program.scratch_slots())
2366            })
2367            .unwrap_or_else(|| self.expression_slot_count());
2368        let residual_value_program = lowered_artifacts
2369            .as_ref()
2370            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2371            .map(|runtime| runtime.value_program());
2372        let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
2373        let residual_active_indices = &state.execution_sets.residual_amplitudes;
2374        debug_assert!(cached_parameter_indices.iter().all(|&index| resources
2375            .active
2376            .get(index)
2377            .copied()
2378            .unwrap_or(false)));
2379        debug_assert!(residual_active_indices.iter().all(|&index| resources
2380            .active
2381            .get(index)
2382            .copied()
2383            .unwrap_or(false)));
2384        let cached_value_sum = {
2385            if let Some(cache) = resources.caches.first() {
2386                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2387                self.fill_amplitude_values(
2388                    &mut amplitude_values,
2389                    cached_parameter_indices,
2390                    &parameters,
2391                    cache,
2392                );
2393                lowered_artifacts
2394                    .as_ref()
2395                    .and_then(|artifacts| {
2396                        self.evaluate_cached_weighted_value_sum_lowered(
2397                            &state,
2398                            artifacts,
2399                            &amplitude_values,
2400                        )
2401                    })
2402                    .unwrap_or_else(|| {
2403                        self.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values)
2404                    })
2405            } else {
2406                0.0
2407            }
2408        };
2409
2410        #[cfg(feature = "rayon")]
2411        let residual_sum: f64 = {
2412            resources
2413                .caches
2414                .par_iter()
2415                .zip(self.dataset.weights_local().par_iter())
2416                .map_init(
2417                    || {
2418                        (
2419                            vec![Complex64::ZERO; amplitude_len],
2420                            vec![Complex64::ZERO; residual_value_slot_count],
2421                        )
2422                    },
2423                    |(amplitude_values, value_slots), (cache, event)| {
2424                        self.fill_amplitude_values(
2425                            amplitude_values,
2426                            residual_active_indices,
2427                            &parameters,
2428                            cache,
2429                        );
2430                        {
2431                            let value = residual_value_program
2432                                .as_ref()
2433                                .map(|program| {
2434                                    program.evaluate_into(
2435                                        amplitude_values,
2436                                        &mut value_slots[..program.scratch_slots()],
2437                                    )
2438                                })
2439                                .unwrap_or_else(|| {
2440                                    self.evaluate_residual_value_ir(&state, amplitude_values)
2441                                });
2442                            *event * value.re
2443                        }
2444                    },
2445                )
2446                .sum()
2447        };
2448
2449        #[cfg(not(feature = "rayon"))]
2450        let residual_sum: f64 = {
2451            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2452            let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
2453            resources
2454                .caches
2455                .iter()
2456                .zip(self.dataset.weights_local().iter())
2457                .map(|(cache, event)| {
2458                    self.fill_amplitude_values(
2459                        &mut amplitude_values,
2460                        &residual_active_indices,
2461                        &parameters,
2462                        cache,
2463                    );
2464                    {
2465                        let value = residual_value_program
2466                            .as_ref()
2467                            .map(|program| {
2468                                program.evaluate_into(
2469                                    &amplitude_values,
2470                                    &mut value_slots[..program.scratch_slots()],
2471                                )
2472                            })
2473                            .unwrap_or_else(|| {
2474                                self.evaluate_residual_value_ir(&state, &amplitude_values)
2475                            });
2476                        *event * value.re
2477                    }
2478                })
2479                .sum()
2480        };
2481        Ok((residual_sum, cached_value_sum))
2482    }
2483
2484    /// Weighted sum over local events of the real expression value.
2485    ///
2486    /// This returns `sum_e(weight_e * Re(L_e))`.
2487    pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
2488        let (residual_sum, cached_value_sum) =
2489            self.evaluate_weighted_value_sum_local_components(parameters)?;
2490        Ok(residual_sum + cached_value_sum)
2491    }
2492
2493    #[cfg(feature = "mpi")]
2494    /// Weighted sum over all ranks of the real expression value.
2495    ///
2496    /// This returns `sum_{r,e}(weight_{r,e} * Re(L_{r,e}))`.
2497    pub fn evaluate_weighted_value_sum_mpi(
2498        &self,
2499        parameters: &[f64],
2500        world: &SimpleCommunicator,
2501    ) -> LadduResult<f64> {
2502        let (residual_sum_local, cached_value_sum_local) =
2503            self.evaluate_weighted_value_sum_local_components(parameters)?;
2504        let mut residual_sum = 0.0;
2505        world.all_reduce_into(
2506            &residual_sum_local,
2507            &mut residual_sum,
2508            mpi::collective::SystemOperation::sum(),
2509        );
2510        let mut cached_value_sum = 0.0;
2511        world.all_reduce_into(
2512            &cached_value_sum_local,
2513            &mut cached_value_sum,
2514            mpi::collective::SystemOperation::sum(),
2515        );
2516        Ok(residual_sum + cached_value_sum)
2517    }
2518
2519    /// Weighted sum over local events of the real gradient of the expression.
2520    ///
2521    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
2522    fn evaluate_weighted_gradient_sum_local_components(
2523        &self,
2524        parameters: &[f64],
2525    ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
2526        let resources = self.resources.read();
2527        let parameters = resources.parameter_map.assemble(parameters)?;
2528        let amplitude_len = self.amplitude_use_sites.len();
2529        let grad_dim = parameters.len();
2530        let state = self.ensure_cached_integral_cache_state(&resources)?;
2531        let lowered_artifacts = self.active_lowered_artifacts();
2532        let active_index_set = resources.active_indices();
2533        let cached_parameter_indices = state
2534            .execution_sets
2535            .cached_parameter_amplitudes
2536            .iter()
2537            .copied()
2538            .filter(|index| active_index_set.binary_search(index).is_ok())
2539            .collect::<Vec<_>>();
2540        let residual_active_indices = state
2541            .execution_sets
2542            .residual_amplitudes
2543            .iter()
2544            .copied()
2545            .filter(|index| active_index_set.binary_search(index).is_ok())
2546            .collect::<Vec<_>>();
2547        let mut cached_parameter_mask = vec![false; amplitude_len];
2548        for &index in &cached_parameter_indices {
2549            cached_parameter_mask[index] = true;
2550        }
2551        let mut residual_active_mask = vec![false; amplitude_len];
2552        for &index in &residual_active_indices {
2553            residual_active_mask[index] = true;
2554        }
2555        let residual_gradient_program = lowered_artifacts
2556            .as_ref()
2557            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2558            .map(|runtime| runtime.gradient_program());
2559        let residual_gradient_slot_count = residual_gradient_program
2560            .as_ref()
2561            .map(|program| program.scratch_slots())
2562            .unwrap_or_else(|| state.expression_ir.node_count());
2563        let cached_term_sum = {
2564            if let Some(cache) = resources.caches.first() {
2565                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2566                self.fill_amplitude_values(
2567                    &mut amplitude_values,
2568                    &cached_parameter_indices,
2569                    &parameters,
2570                    cache,
2571                );
2572                let mut amplitude_gradients = (0..amplitude_len)
2573                    .map(|_| DVector::zeros(grad_dim))
2574                    .collect::<Vec<_>>();
2575                self.fill_amplitude_gradients(
2576                    &mut amplitude_gradients,
2577                    &cached_parameter_mask,
2578                    &parameters,
2579                    cache,
2580                );
2581                lowered_artifacts
2582                    .as_ref()
2583                    .and_then(|artifacts| {
2584                        self.evaluate_cached_weighted_gradient_sum_lowered(
2585                            &state,
2586                            artifacts,
2587                            &amplitude_values,
2588                            &amplitude_gradients,
2589                            grad_dim,
2590                        )
2591                    })
2592                    .unwrap_or_else(|| {
2593                        self.evaluate_cached_weighted_gradient_sum_ir(
2594                            &state,
2595                            &amplitude_values,
2596                            &amplitude_gradients,
2597                            grad_dim,
2598                        )
2599                    })
2600            } else {
2601                DVector::zeros(grad_dim)
2602            }
2603        };
2604
2605        #[cfg(feature = "rayon")]
2606        let residual_sum = {
2607            resources
2608                .caches
2609                .par_iter()
2610                .zip(self.dataset.weights_local().par_iter())
2611                .map_init(
2612                    || {
2613                        (
2614                            vec![Complex64::ZERO; amplitude_len],
2615                            vec![DVector::zeros(grad_dim); amplitude_len],
2616                            vec![Complex64::ZERO; residual_gradient_slot_count],
2617                            vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
2618                        )
2619                    },
2620                    |(amplitude_values, gradient_values, value_slots, gradient_slots),
2621                     (cache, event)| {
2622                        self.fill_amplitude_values_and_gradients(
2623                            amplitude_values,
2624                            gradient_values,
2625                            &residual_active_indices,
2626                            &residual_active_mask,
2627                            &parameters,
2628                            cache,
2629                        );
2630                        let gradient = residual_gradient_program
2631                            .as_ref()
2632                            .map(|program| {
2633                                program.evaluate_gradient_into_flat(
2634                                    amplitude_values,
2635                                    gradient_values,
2636                                    value_slots,
2637                                    gradient_slots,
2638                                    grad_dim,
2639                                )
2640                            })
2641                            .unwrap_or_else(|| {
2642                                self.evaluate_residual_gradient_ir(
2643                                    &state,
2644                                    amplitude_values,
2645                                    gradient_values,
2646                                    grad_dim,
2647                                )
2648                            });
2649                        gradient.map(|value| value.re).scale(*event)
2650                    },
2651                )
2652                .reduce(
2653                    || DVector::zeros(grad_dim),
2654                    |mut accum, value| {
2655                        accum += value;
2656                        accum
2657                    },
2658                )
2659        };
2660
2661        #[cfg(not(feature = "rayon"))]
2662        let residual_sum = {
2663            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2664            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
2665            let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
2666            let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
2667            resources
2668                .caches
2669                .iter()
2670                .zip(self.dataset.weights_local().iter())
2671                .map(|(cache, event)| {
2672                    self.fill_amplitude_values_and_gradients(
2673                        &mut amplitude_values,
2674                        &mut gradient_values,
2675                        &residual_active_indices,
2676                        &residual_active_mask,
2677                        &parameters,
2678                        cache,
2679                    );
2680                    let gradient = residual_gradient_program
2681                        .as_ref()
2682                        .map(|program| {
2683                            program.evaluate_gradient_into_flat(
2684                                &amplitude_values,
2685                                &gradient_values,
2686                                &mut value_slots,
2687                                &mut gradient_slots,
2688                                grad_dim,
2689                            )
2690                        })
2691                        .unwrap_or_else(|| {
2692                            self.evaluate_residual_gradient_ir(
2693                                &state,
2694                                &amplitude_values,
2695                                &gradient_values,
2696                                grad_dim,
2697                            )
2698                        });
2699                    gradient.map(|value| value.re).scale(*event)
2700                })
2701                .sum()
2702        };
2703        Ok((residual_sum, cached_term_sum))
2704    }
2705
2706    /// Weighted sum over local events of the real gradient of the expression.
2707    ///
2708    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
2709    pub fn evaluate_weighted_gradient_sum_local(
2710        &self,
2711        parameters: &[f64],
2712    ) -> LadduResult<DVector<f64>> {
2713        let (residual_sum, cached_term_sum) =
2714            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2715        Ok(residual_sum + cached_term_sum)
2716    }
2717
2718    #[cfg(feature = "mpi")]
2719    /// Weighted sum over all ranks of the real gradient of the expression.
2720    ///
2721    /// This returns `sum_{r,e}(weight_{r,e} * Re(dL_{r,e}/dp))`.
2722    pub fn evaluate_weighted_gradient_sum_mpi(
2723        &self,
2724        parameters: &[f64],
2725        world: &SimpleCommunicator,
2726    ) -> LadduResult<DVector<f64>> {
2727        let (residual_sum_local, cached_term_sum_local) =
2728            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2729        let mut residual_sum = vec![0.0; residual_sum_local.len()];
2730        world.all_reduce_into(
2731            residual_sum_local.as_slice(),
2732            &mut residual_sum,
2733            mpi::collective::SystemOperation::sum(),
2734        );
2735        let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
2736        world.all_reduce_into(
2737            cached_term_sum_local.as_slice(),
2738            &mut cached_term_sum,
2739            mpi::collective::SystemOperation::sum(),
2740        );
2741        let mut total = DVector::from_vec(residual_sum);
2742        total += DVector::from_vec(cached_term_sum);
2743        Ok(total)
2744    }
2745
    /// Evaluate the expression value from one set of amplitude values, reusing the
    /// caller-provided `scratch` buffer.
    ///
    /// Public forwarder to `evaluate_expression_runtime_value_with_scratch`. The required
    /// length of `scratch` is set by that runtime evaluator (presumably at least the
    /// expression's value-slot count; confirm there).
    pub fn evaluate_expression_value_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        scratch: &mut [Complex64],
    ) -> Complex64 {
        self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
    }
2753
    /// Evaluate the expression gradient from amplitude values and their gradients, reusing
    /// caller-provided value and gradient scratch buffers.
    ///
    /// Public forwarder to `evaluate_expression_runtime_gradient_with_scratch`. The scratch
    /// sizing requirements are defined by that runtime evaluator.
    pub fn evaluate_expression_gradient_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_scratch: &mut [Complex64],
        gradient_scratch: &mut [DVector<Complex64>],
    ) -> DVector<Complex64> {
        self.evaluate_expression_runtime_gradient_with_scratch(
            amplitude_values,
            gradient_values,
            value_scratch,
            gradient_scratch,
        )
    }
2768
    /// Evaluate both the expression value and its gradient in one call, reusing
    /// caller-provided scratch buffers.
    ///
    /// Returns `(value, gradient)`. Public forwarder to
    /// `evaluate_expression_runtime_value_gradient_with_scratch`.
    pub fn evaluate_expression_value_gradient_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_scratch: &mut [Complex64],
        gradient_scratch: &mut [DVector<Complex64>],
    ) -> (Complex64, DVector<Complex64>) {
        self.evaluate_expression_runtime_value_gradient_with_scratch(
            amplitude_values,
            gradient_values,
            value_scratch,
            gradient_scratch,
        )
    }
2783
    /// Evaluate the expression value from one set of amplitude values.
    ///
    /// Public forwarder to `evaluate_expression_runtime_value`; scratch storage is managed by
    /// the runtime evaluator.
    pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
        self.evaluate_expression_runtime_value(amplitude_values)
    }
2787
    /// Evaluate the expression gradient from amplitude values and their gradients.
    ///
    /// Public forwarder to `evaluate_expression_runtime_gradient`; scratch storage is managed
    /// by the runtime evaluator.
    pub fn evaluate_expression_gradient(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
    ) -> DVector<Complex64> {
        self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
    }
2795
2796    /// Get the parameters used by this evaluator.
2797    pub fn parameters(&self) -> ParameterMap {
2798        self.resources.read().parameters()
2799    }
2800
2801    /// Number of free parameters.
2802    pub fn n_free(&self) -> usize {
2803        self.resources.read().n_free_parameters()
2804    }
2805
2806    /// Number of fixed parameters.
2807    pub fn n_fixed(&self) -> usize {
2808        self.resources.read().n_fixed_parameters()
2809    }
2810
2811    /// Total number of parameters.
2812    pub fn n_parameters(&self) -> usize {
2813        self.resources.read().n_parameters()
2814    }
2815
2816    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
2817        self.resources.read().fix_parameter(name, value)
2818    }
2819
2820    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
2821        self.resources.read().free_parameter(name)
2822    }
2823
2824    pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
2825        self.resources.write().rename_parameter(old, new)
2826    }
2827
2828    pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
2829        self.resources.write().rename_parameters(mapping)
2830    }
2831
2832    /// Activate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2833    pub fn activate<T: AsRef<str>>(&self, name: T) {
2834        self.resources.write().activate(name);
2835        self.refresh_runtime_specializations();
2836    }
2837    /// Activate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
2838    pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2839        self.resources.write().activate_strict(name)?;
2840        self.refresh_runtime_specializations();
2841        Ok(())
2842    }
2843
2844    /// Activate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2845    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
2846        self.resources.write().activate_many(names);
2847        self.refresh_runtime_specializations();
2848    }
2849    /// Activate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
2850    pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2851        self.resources.write().activate_many_strict(names)?;
2852        self.refresh_runtime_specializations();
2853        Ok(())
2854    }
2855
2856    /// Activate all registered [`Amplitude`]s.
2857    pub fn activate_all(&self) {
2858        self.resources.write().activate_all();
2859        self.refresh_runtime_specializations();
2860    }
2861
2862    /// Deactivate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2863    pub fn deactivate<T: AsRef<str>>(&self, name: T) {
2864        self.resources.write().deactivate(name);
2865        self.refresh_runtime_specializations();
2866    }
2867
2868    /// Deactivate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
2869    pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2870        self.resources.write().deactivate_strict(name)?;
2871        self.refresh_runtime_specializations();
2872        Ok(())
2873    }
2874
2875    /// Deactivate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2876    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
2877        self.resources.write().deactivate_many(names);
2878        self.refresh_runtime_specializations();
2879    }
2880    /// Deactivate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
2881    pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2882        self.resources.write().deactivate_many_strict(names)?;
2883        self.refresh_runtime_specializations();
2884        Ok(())
2885    }
2886
2887    /// Deactivate all tagged [`Amplitude`] use-sites.
2888    pub fn deactivate_all(&self) {
2889        self.resources.write().deactivate_all();
2890        self.refresh_runtime_specializations();
2891    }
2892
2893    /// Isolate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2894    pub fn isolate<T: AsRef<str>>(&self, name: T) {
2895        self.resources.write().isolate(name);
2896        self.refresh_runtime_specializations();
2897    }
2898
2899    /// Isolate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
2900    pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2901        self.resources.write().isolate_strict(name)?;
2902        self.refresh_runtime_specializations();
2903        Ok(())
2904    }
2905
2906    /// Isolate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2907    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
2908        self.resources.write().isolate_many(names);
2909        self.refresh_runtime_specializations();
2910    }
2911
2912    /// Isolate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
2913    pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2914        self.resources.write().isolate_many_strict(names)?;
2915        self.refresh_runtime_specializations();
2916        Ok(())
2917    }
2918
2919    /// Return a copy of the current active-amplitude mask.
2920    pub fn active_mask(&self) -> Vec<bool> {
2921        self.resources.read().active.clone()
2922    }
2923
2924    /// Apply a precomputed active-amplitude mask. Untagged use-sites cannot be deactivated.
2925    pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
2926        let resources = {
2927            let mut resources = self.resources.write();
2928            if mask.len() != resources.active.len() {
2929                return Err(LadduError::LengthMismatch {
2930                    context: "active amplitude mask".to_string(),
2931                    expected: resources.active.len(),
2932                    actual: mask.len(),
2933                });
2934            }
2935            resources.apply_active_mask(mask)?;
2936            resources.clone()
2937        };
2938        self.rebuild_runtime_specializations(&resources);
2939        Ok(())
2940    }
2941
2942    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
2943    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
2944    ///
2945    /// # Notes
2946    ///
2947    /// This method is not intended to be called in analyses but rather in writing methods
2948    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
2949    pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
2950        let resources = self.resources.read();
2951        let parameters = resources.parameter_map.assemble(parameters)?;
2952        let amplitude_len = self.amplitude_use_sites.len();
2953        let active_indices = resources.active_indices().to_vec();
2954        let slot_count = self.expression_value_slot_count();
2955        let program_snapshot = self.expression_value_program_snapshot();
2956        #[cfg(feature = "rayon")]
2957        {
2958            Ok(resources
2959                .caches
2960                .par_iter()
2961                .map_init(
2962                    || {
2963                        (
2964                            vec![Complex64::ZERO; amplitude_len],
2965                            vec![Complex64::ZERO; slot_count],
2966                        )
2967                    },
2968                    |(amplitude_values, expr_slots), cache| {
2969                        self.fill_amplitude_values(
2970                            amplitude_values,
2971                            &active_indices,
2972                            &parameters,
2973                            cache,
2974                        );
2975                        self.evaluate_expression_value_with_program_snapshot(
2976                            &program_snapshot,
2977                            amplitude_values,
2978                            expr_slots,
2979                        )
2980                    },
2981                )
2982                .collect())
2983        }
2984        #[cfg(not(feature = "rayon"))]
2985        {
2986            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2987            let mut expr_slots = vec![Complex64::ZERO; slot_count];
2988            Ok(resources
2989                .caches
2990                .iter()
2991                .map(|cache| {
2992                    self.fill_amplitude_values(
2993                        &mut amplitude_values,
2994                        &active_indices,
2995                        &parameters,
2996                        cache,
2997                    );
2998                    self.evaluate_expression_value_with_program_snapshot(
2999                        &program_snapshot,
3000                        &amplitude_values,
3001                        &mut expr_slots,
3002                    )
3003                })
3004                .collect())
3005        }
3006    }
3007
3008    /// Evaluate local events using an explicit active-amplitude mask without mutating evaluator state.
3009    pub fn evaluate_local_with_active_mask(
3010        &self,
3011        parameters: &[f64],
3012        active_mask: &[bool],
3013    ) -> LadduResult<Vec<Complex64>> {
3014        let resources = self.resources.read();
3015        if active_mask.len() != resources.active.len() {
3016            return Err(LadduError::LengthMismatch {
3017                context: "active amplitude mask".to_string(),
3018                expected: resources.active.len(),
3019                actual: active_mask.len(),
3020            });
3021        }
3022        let parameters = resources.parameter_map.assemble(parameters)?;
3023        let amplitude_len = self.amplitude_use_sites.len();
3024        let active_indices = active_mask
3025            .iter()
3026            .enumerate()
3027            .filter_map(|(index, &active)| if active { Some(index) } else { None })
3028            .collect::<Vec<_>>();
3029        let program_snapshot =
3030            self.expression_value_program_snapshot_for_active_mask(active_mask)?;
3031        let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
3032        #[cfg(feature = "rayon")]
3033        {
3034            Ok(resources
3035                .caches
3036                .par_iter()
3037                .map_init(
3038                    || {
3039                        (
3040                            vec![Complex64::ZERO; amplitude_len],
3041                            vec![Complex64::ZERO; slot_count],
3042                        )
3043                    },
3044                    |(amplitude_values, expr_slots), cache| {
3045                        self.fill_amplitude_values(
3046                            amplitude_values,
3047                            &active_indices,
3048                            &parameters,
3049                            cache,
3050                        );
3051                        self.evaluate_expression_value_with_program_snapshot(
3052                            &program_snapshot,
3053                            amplitude_values,
3054                            expr_slots,
3055                        )
3056                    },
3057                )
3058                .collect())
3059        }
3060        #[cfg(not(feature = "rayon"))]
3061        {
3062            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3063            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3064            Ok(resources
3065                .caches
3066                .iter()
3067                .map(|cache| {
3068                    self.fill_amplitude_values(
3069                        &mut amplitude_values,
3070                        &active_indices,
3071                        &parameters,
3072                        cache,
3073                    );
3074                    self.evaluate_expression_value_with_program_snapshot(
3075                        &program_snapshot,
3076                        &amplitude_values,
3077                        &mut expr_slots,
3078                    )
3079                })
3080                .collect())
3081        }
3082    }
3083
3084    /// Evaluate the stored expression over local events using a reusable execution context.
3085    #[cfg(feature = "execution-context-prototype")]
3086    pub fn evaluate_local_with_ctx(
3087        &self,
3088        parameters: &[f64],
3089        execution_context: &ExecutionContext,
3090    ) -> Vec<Complex64> {
3091        let resources = self.resources.read();
3092        let parameters = resources
3093            .parameter_map
3094            .assemble(parameters)
3095            .expect("parameter slice must match evaluator resources");
3096        let amplitude_len = self.amplitude_use_sites.len();
3097        let active_indices = resources.active_indices().to_vec();
3098        let slot_count = self.expression_value_slot_count();
3099        let program_snapshot = self.expression_value_program_snapshot();
3100        #[cfg(feature = "rayon")]
3101        {
3102            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3103                return execution_context.install(|| {
3104                    resources
3105                        .caches
3106                        .par_iter()
3107                        .map_init(
3108                            || {
3109                                (
3110                                    vec![Complex64::ZERO; amplitude_len],
3111                                    vec![Complex64::ZERO; slot_count],
3112                                )
3113                            },
3114                            |(amplitude_values, expr_slots), cache| {
3115                                self.fill_amplitude_values(
3116                                    amplitude_values,
3117                                    &active_indices,
3118                                    &parameters,
3119                                    cache,
3120                                );
3121                                self.evaluate_expression_value_with_program_snapshot(
3122                                    &program_snapshot,
3123                                    amplitude_values,
3124                                    expr_slots,
3125                                )
3126                            },
3127                        )
3128                        .collect()
3129                });
3130            }
3131        }
3132        execution_context.with_scratch(|scratch| {
3133            let (amplitude_values, expr_slots) =
3134                scratch.reserve_value_workspaces(amplitude_len, slot_count);
3135            resources
3136                .caches
3137                .iter()
3138                .map(|cache| {
3139                    self.fill_amplitude_values(
3140                        amplitude_values,
3141                        &active_indices,
3142                        &parameters,
3143                        cache,
3144                    );
3145                    self.evaluate_expression_value_with_program_snapshot(
3146                        &program_snapshot,
3147                        amplitude_values,
3148                        expr_slots,
3149                    )
3150                })
3151                .collect()
3152        })
3153    }
3154
3155    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3156    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
3157    ///
3158    /// # Notes
3159    ///
3160    /// This method is not intended to be called in analyses but rather in writing methods
3161    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
3162    #[cfg(feature = "mpi")]
3163    fn evaluate_mpi(
3164        &self,
3165        parameters: &[f64],
3166        world: &SimpleCommunicator,
3167    ) -> LadduResult<Vec<Complex64>> {
3168        let local_evaluation = self.evaluate_local(parameters)?;
3169        let n_events = self.dataset.n_events();
3170        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3171        let (counts, displs) = world.get_counts_displs(n_events);
3172        {
3173            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
3174            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
3175            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3176            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3177        }
3178        Ok(buffer)
3179    }
3180
3181    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3182    fn evaluate_mpi_with_ctx(
3183        &self,
3184        parameters: &[f64],
3185        world: &SimpleCommunicator,
3186        execution_context: &ExecutionContext,
3187    ) -> Vec<Complex64> {
3188        let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
3189        let n_events = self.dataset.n_events();
3190        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3191        let (counts, displs) = world.get_counts_displs(n_events);
3192        {
3193            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
3194            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
3195            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3196            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3197        }
3198        buffer
3199    }
3200
3201    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3202    /// [`Evaluator`] with the given values for free parameters.
3203    pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3204        #[cfg(feature = "mpi")]
3205        {
3206            if let Some(world) = crate::mpi::get_world() {
3207                return self.evaluate_mpi(parameters, &world);
3208            }
3209        }
3210        self.evaluate_local(parameters)
3211    }
3212
3213    /// Evaluate the stored expression with a reusable execution context.
3214    ///
3215    /// This is intended for repeated calls with the same context instance.
3216    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
3217    /// [`ExecutionContext`](crate::ExecutionContext).
3218    #[cfg(feature = "execution-context-prototype")]
3219    pub fn evaluate_with_ctx(
3220        &self,
3221        parameters: &[f64],
3222        execution_context: &ExecutionContext,
3223    ) -> Vec<Complex64> {
3224        #[cfg(feature = "mpi")]
3225        {
3226            if let Some(world) = crate::mpi::get_world() {
3227                return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
3228            }
3229        }
3230        self.evaluate_local_with_ctx(parameters, execution_context)
3231    }
3232
3233    /// See [`Evaluator::evaluate_local`]. This method evaluates over a subset of events rather
3234    /// than all events in the total dataset.
3235    pub fn evaluate_batch_local(
3236        &self,
3237        parameters: &[f64],
3238        indices: &[usize],
3239    ) -> LadduResult<Vec<Complex64>> {
3240        let resources = self.resources.read();
3241        let parameters = resources.parameter_map.assemble(parameters)?;
3242        let amplitude_len = self.amplitude_use_sites.len();
3243        let active_indices = resources.active_indices().to_vec();
3244        let slot_count = self.expression_value_slot_count();
3245        let program_snapshot = self.expression_value_program_snapshot();
3246        #[cfg(feature = "rayon")]
3247        {
3248            Ok(indices
3249                .par_iter()
3250                .map_init(
3251                    || {
3252                        (
3253                            vec![Complex64::ZERO; amplitude_len],
3254                            vec![Complex64::ZERO; slot_count],
3255                        )
3256                    },
3257                    |(amplitude_values, expr_slots), &idx| {
3258                        let cache = &resources.caches[idx];
3259                        self.fill_amplitude_values(
3260                            amplitude_values,
3261                            &active_indices,
3262                            &parameters,
3263                            cache,
3264                        );
3265                        self.evaluate_expression_value_with_program_snapshot(
3266                            &program_snapshot,
3267                            amplitude_values,
3268                            expr_slots,
3269                        )
3270                    },
3271                )
3272                .collect())
3273        }
3274        #[cfg(not(feature = "rayon"))]
3275        {
3276            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3277            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3278            Ok(indices
3279                .iter()
3280                .map(|&idx| {
3281                    let cache = &resources.caches[idx];
3282                    self.fill_amplitude_values(
3283                        &mut amplitude_values,
3284                        &active_indices,
3285                        &parameters,
3286                        cache,
3287                    );
3288                    self.evaluate_expression_value_with_program_snapshot(
3289                        &program_snapshot,
3290                        &amplitude_values,
3291                        &mut expr_slots,
3292                    )
3293                })
3294                .collect())
3295        }
3296    }
3297
3298    /// See [`Evaluator::evaluate_mpi`]. This method evaluates over a subset of events rather
3299    /// than all events in the total dataset.
3300    #[cfg(feature = "mpi")]
3301    fn evaluate_batch_mpi(
3302        &self,
3303        parameters: &[f64],
3304        indices: &[usize],
3305        world: &SimpleCommunicator,
3306    ) -> LadduResult<Vec<Complex64>> {
3307        let total = self.dataset.n_events();
3308        let locals = world.locals_from_globals(indices, total);
3309        let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
3310        Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
3311    }
3312
3313    /// Evaluate the stored [`Expression`] over a subset of events in the [`Dataset`] stored by the
3314    /// [`Evaluator`] with the given values for free parameters. See also [`Evaluator::evaluate`].
3315    pub fn evaluate_batch(
3316        &self,
3317        parameters: &[f64],
3318        indices: &[usize],
3319    ) -> LadduResult<Vec<Complex64>> {
3320        #[cfg(feature = "mpi")]
3321        {
3322            if let Some(world) = crate::mpi::get_world() {
3323                return self.evaluate_batch_mpi(parameters, indices, &world);
3324            }
3325        }
3326        self.evaluate_batch_local(parameters, indices)
3327    }
3328
3329    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3330    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
3331    ///
3332    /// # Notes
3333    ///
3334    /// This method is not intended to be called in analyses but rather in writing methods
3335    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
3336    pub fn evaluate_gradient_local(
3337        &self,
3338        parameters: &[f64],
3339    ) -> LadduResult<Vec<DVector<Complex64>>> {
3340        let resources = self.resources.read();
3341        let parameters = resources.parameter_map.assemble(parameters)?;
3342        let amplitude_len = self.amplitude_use_sites.len();
3343        let grad_dim = parameters.len();
3344        let active_indices = resources.active_indices().to_vec();
3345        let lowered_runtime = self.lowered_runtime();
3346        let gradient_program = lowered_runtime.gradient_program();
3347        let slot_count = self.expression_gradient_slot_count();
3348        #[cfg(feature = "rayon")]
3349        {
3350            Ok(resources
3351                .caches
3352                .par_iter()
3353                .map_init(
3354                    || {
3355                        (
3356                            vec![Complex64::ZERO; amplitude_len],
3357                            vec![DVector::zeros(grad_dim); amplitude_len],
3358                            vec![Complex64::ZERO; slot_count],
3359                            vec![Complex64::ZERO; slot_count * grad_dim],
3360                        )
3361                    },
3362                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3363                        self.fill_amplitude_values_and_gradients(
3364                            amplitude_values,
3365                            gradient_values,
3366                            &active_indices,
3367                            &resources.active,
3368                            &parameters,
3369                            cache,
3370                        );
3371                        gradient_program.evaluate_gradient_into_flat(
3372                            amplitude_values,
3373                            gradient_values,
3374                            value_slots,
3375                            gradient_slots,
3376                            grad_dim,
3377                        )
3378                    },
3379                )
3380                .collect())
3381        }
3382        #[cfg(not(feature = "rayon"))]
3383        {
3384            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3385            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3386            let mut value_slots = vec![Complex64::ZERO; slot_count];
3387            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3388            Ok(resources
3389                .caches
3390                .iter()
3391                .map(|cache| {
3392                    self.fill_amplitude_values_and_gradients(
3393                        &mut amplitude_values,
3394                        &mut gradient_values,
3395                        &active_indices,
3396                        &resources.active,
3397                        &parameters,
3398                        cache,
3399                    );
3400                    gradient_program.evaluate_gradient_into_flat(
3401                        &amplitude_values,
3402                        &gradient_values,
3403                        &mut value_slots,
3404                        &mut gradient_slots,
3405                        grad_dim,
3406                    )
3407                })
3408                .collect())
3409        }
3410    }
3411
3412    /// Evaluate the gradient over local events using a reusable execution context.
3413    #[cfg(feature = "execution-context-prototype")]
3414    pub fn evaluate_gradient_local_with_ctx(
3415        &self,
3416        parameters: &[f64],
3417        execution_context: &ExecutionContext,
3418    ) -> Vec<DVector<Complex64>> {
3419        let resources = self.resources.read();
3420        let parameters = resources
3421            .parameter_map
3422            .assemble(parameters)
3423            .expect("parameter slice must match evaluator resources");
3424        let amplitude_len = self.amplitude_use_sites.len();
3425        let grad_dim = parameters.len();
3426        let active_indices = resources.active_indices().to_vec();
3427        let slot_count = self.expression_slot_count();
3428        #[cfg(feature = "rayon")]
3429        {
3430            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3431                return execution_context.install(|| {
3432                    resources
3433                        .caches
3434                        .par_iter()
3435                        .map_init(
3436                            || {
3437                                (
3438                                    vec![Complex64::ZERO; amplitude_len],
3439                                    vec![DVector::zeros(grad_dim); amplitude_len],
3440                                    vec![Complex64::ZERO; slot_count],
3441                                    vec![DVector::zeros(grad_dim); slot_count],
3442                                )
3443                            },
3444                            |(amplitude_values, gradient_values, value_slots, gradient_slots),
3445                             cache| {
3446                                self.evaluate_cache_gradient_with_scratch(
3447                                    amplitude_values,
3448                                    gradient_values,
3449                                    value_slots,
3450                                    gradient_slots,
3451                                    &active_indices,
3452                                    &resources.active,
3453                                    &parameters,
3454                                    cache,
3455                                )
3456                            },
3457                        )
3458                        .collect()
3459                });
3460            }
3461        }
3462        execution_context.with_scratch(|scratch| {
3463            let (amplitude_values, value_slots, gradient_values, gradient_slots) =
3464                scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
3465            resources
3466                .caches
3467                .iter()
3468                .map(|cache| {
3469                    self.evaluate_cache_gradient_with_scratch(
3470                        amplitude_values,
3471                        gradient_values,
3472                        value_slots,
3473                        gradient_slots,
3474                        &active_indices,
3475                        &resources.active,
3476                        &parameters,
3477                        cache,
3478                    )
3479                })
3480                .collect()
3481        })
3482    }
3483
3484    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3485    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
3486    ///
3487    /// # Notes
3488    ///
3489    /// This method is not intended to be called in analyses but rather in writing methods
3490    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
3491    #[cfg(feature = "mpi")]
3492    fn evaluate_gradient_mpi(
3493        &self,
3494        parameters: &[f64],
3495        world: &SimpleCommunicator,
3496    ) -> LadduResult<Vec<DVector<Complex64>>> {
3497        let local_evaluation = self.evaluate_gradient_local(parameters)?;
3498        let n_events = self.dataset.n_events();
3499        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
3500        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
3501        {
3502            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
3503            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
3504            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3505            world.all_gather_varcount_into(
3506                &local_evaluation
3507                    .iter()
3508                    .flat_map(|v| v.data.as_vec())
3509                    .copied()
3510                    .collect::<Vec<_>>(),
3511                &mut partitioned_buffer,
3512            );
3513        }
3514        Ok(buffer
3515            .chunks(parameters.len())
3516            .map(DVector::from_row_slice)
3517            .collect())
3518    }
3519
3520    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3521    fn evaluate_gradient_mpi_with_ctx(
3522        &self,
3523        parameters: &[f64],
3524        world: &SimpleCommunicator,
3525        execution_context: &ExecutionContext,
3526    ) -> Vec<DVector<Complex64>> {
3527        let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
3528        let n_events = self.dataset.n_events();
3529        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
3530        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
3531        {
3532            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
3533            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
3534            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3535            world.all_gather_varcount_into(
3536                &local_evaluation
3537                    .iter()
3538                    .flat_map(|v| v.data.as_vec())
3539                    .copied()
3540                    .collect::<Vec<_>>(),
3541                &mut partitioned_buffer,
3542            );
3543        }
3544        buffer
3545            .chunks(parameters.len())
3546            .map(DVector::from_row_slice)
3547            .collect()
3548    }
3549
3550    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3551    /// [`Evaluator`] with the given values for free parameters.
3552    pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
3553        #[cfg(feature = "mpi")]
3554        {
3555            if let Some(world) = crate::mpi::get_world() {
3556                return self.evaluate_gradient_mpi(parameters, &world);
3557            }
3558        }
3559        self.evaluate_gradient_local(parameters)
3560    }
3561
3562    /// Evaluate the gradient with a reusable execution context.
3563    ///
3564    /// This is intended for repeated calls with the same context instance.
3565    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
3566    /// [`ExecutionContext`](crate::ExecutionContext).
3567    #[cfg(feature = "execution-context-prototype")]
3568    pub fn evaluate_gradient_with_ctx(
3569        &self,
3570        parameters: &[f64],
3571        execution_context: &ExecutionContext,
3572    ) -> Vec<DVector<Complex64>> {
3573        #[cfg(feature = "mpi")]
3574        {
3575            if let Some(world) = crate::mpi::get_world() {
3576                return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
3577            }
3578        }
3579        self.evaluate_gradient_local_with_ctx(parameters, execution_context)
3580    }
3581
3582    /// See [`Evaluator::evaluate_gradient_local`]. This method evaluates over a subset
3583    /// of events rather than all events in the total dataset.
3584    pub fn evaluate_gradient_batch_local(
3585        &self,
3586        parameters: &[f64],
3587        indices: &[usize],
3588    ) -> LadduResult<Vec<DVector<Complex64>>> {
3589        let resources = self.resources.read();
3590        let parameters = resources.parameter_map.assemble(parameters)?;
3591        let amplitude_len = self.amplitude_use_sites.len();
3592        let grad_dim = parameters.len();
3593        let active_indices = resources.active_indices().to_vec();
3594        let lowered_runtime = self.lowered_runtime();
3595        let gradient_program = lowered_runtime.gradient_program();
3596        let slot_count = self.expression_gradient_slot_count();
3597        #[cfg(feature = "rayon")]
3598        {
3599            Ok(indices
3600                .par_iter()
3601                .map_init(
3602                    || {
3603                        (
3604                            vec![Complex64::ZERO; amplitude_len],
3605                            vec![DVector::zeros(grad_dim); amplitude_len],
3606                            vec![Complex64::ZERO; slot_count],
3607                            vec![Complex64::ZERO; slot_count * grad_dim],
3608                        )
3609                    },
3610                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3611                        let cache = &resources.caches[idx];
3612                        self.fill_amplitude_values_and_gradients(
3613                            amplitude_values,
3614                            gradient_values,
3615                            &active_indices,
3616                            &resources.active,
3617                            &parameters,
3618                            cache,
3619                        );
3620                        gradient_program.evaluate_gradient_into_flat(
3621                            amplitude_values,
3622                            gradient_values,
3623                            value_slots,
3624                            gradient_slots,
3625                            grad_dim,
3626                        )
3627                    },
3628                )
3629                .collect())
3630        }
3631        #[cfg(not(feature = "rayon"))]
3632        {
3633            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3634            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3635            let mut value_slots = vec![Complex64::ZERO; slot_count];
3636            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3637            Ok(indices
3638                .iter()
3639                .map(|&idx| {
3640                    let cache = &resources.caches[idx];
3641                    self.fill_amplitude_values_and_gradients(
3642                        &mut amplitude_values,
3643                        &mut gradient_values,
3644                        &active_indices,
3645                        &resources.active,
3646                        &parameters,
3647                        cache,
3648                    );
3649                    gradient_program.evaluate_gradient_into_flat(
3650                        &amplitude_values,
3651                        &gradient_values,
3652                        &mut value_slots,
3653                        &mut gradient_slots,
3654                        grad_dim,
3655                    )
3656                })
3657                .collect())
3658        }
3659    }
3660
3661    /// See [`Evaluator::evaluate_gradient_mpi`]. This method evaluates over a subset
3662    /// of events rather than all events in the total dataset.
3663    #[cfg(feature = "mpi")]
3664    fn evaluate_gradient_batch_mpi(
3665        &self,
3666        parameters: &[f64],
3667        indices: &[usize],
3668        world: &SimpleCommunicator,
3669    ) -> LadduResult<Vec<DVector<Complex64>>> {
3670        let total = self.dataset.n_events();
3671        let locals = world.locals_from_globals(indices, total);
3672        let flattened_local_evaluation = self
3673            .evaluate_gradient_batch_local(parameters, &locals)?
3674            .iter()
3675            .flat_map(|g| g.data.as_vec().to_vec())
3676            .collect::<Vec<Complex64>>();
3677        Ok(world
3678            .all_gather_batched_partitioned(
3679                &flattened_local_evaluation,
3680                indices,
3681                total,
3682                Some(parameters.len()),
3683            )
3684            .chunks(parameters.len())
3685            .map(DVector::from_row_slice)
3686            .collect())
3687    }
3688
3689    /// Evaluate the gradient of the stored [`Expression`] over a subset of the
3690    /// events in the [`Dataset`] stored by the [`Evaluator`] with the given values
3691    /// for free parameters. See also [`Evaluator::evaluate_gradient`].
3692    pub fn evaluate_gradient_batch(
3693        &self,
3694        parameters: &[f64],
3695        indices: &[usize],
3696    ) -> LadduResult<Vec<DVector<Complex64>>> {
3697        #[cfg(feature = "mpi")]
3698        {
3699            if let Some(world) = crate::mpi::get_world() {
3700                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
3701            }
3702        }
3703        self.evaluate_gradient_batch_local(parameters, indices)
3704    }
3705
3706    /// Evaluate the stored expression and its gradient over local events in one fused pass.
3707    pub fn evaluate_with_gradient_local(
3708        &self,
3709        parameters: &[f64],
3710    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3711        let resources = self.resources.read();
3712        let parameters = resources.parameter_map.assemble(parameters)?;
3713        let amplitude_len = self.amplitude_use_sites.len();
3714        let grad_dim = parameters.len();
3715        let active_indices = resources.active_indices().to_vec();
3716        let lowered_runtime = self.lowered_runtime();
3717        let value_gradient_program = lowered_runtime.value_gradient_program();
3718        let slot_count = self.expression_value_gradient_slot_count();
3719        #[cfg(feature = "rayon")]
3720        {
3721            Ok(resources
3722                .caches
3723                .par_iter()
3724                .map_init(
3725                    || {
3726                        (
3727                            vec![Complex64::ZERO; amplitude_len],
3728                            vec![DVector::zeros(grad_dim); amplitude_len],
3729                            vec![Complex64::ZERO; slot_count],
3730                            vec![Complex64::ZERO; slot_count * grad_dim],
3731                        )
3732                    },
3733                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3734                        self.fill_amplitude_values_and_gradients(
3735                            amplitude_values,
3736                            gradient_values,
3737                            &active_indices,
3738                            &resources.active,
3739                            &parameters,
3740                            cache,
3741                        );
3742                        value_gradient_program.evaluate_value_gradient_into_flat(
3743                            amplitude_values,
3744                            gradient_values,
3745                            value_slots,
3746                            gradient_slots,
3747                            grad_dim,
3748                        )
3749                    },
3750                )
3751                .collect())
3752        }
3753        #[cfg(not(feature = "rayon"))]
3754        {
3755            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3756            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3757            let mut value_slots = vec![Complex64::ZERO; slot_count];
3758            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3759            Ok(resources
3760                .caches
3761                .iter()
3762                .map(|cache| {
3763                    self.fill_amplitude_values_and_gradients(
3764                        &mut amplitude_values,
3765                        &mut gradient_values,
3766                        &active_indices,
3767                        &resources.active,
3768                        &parameters,
3769                        cache,
3770                    );
3771                    value_gradient_program.evaluate_value_gradient_into_flat(
3772                        &amplitude_values,
3773                        &gradient_values,
3774                        &mut value_slots,
3775                        &mut gradient_slots,
3776                        grad_dim,
3777                    )
3778                })
3779                .collect())
3780        }
3781    }
3782
3783    /// Evaluate local events and gradients with an explicit active-amplitude mask without mutating evaluator state.
3784    pub fn evaluate_with_gradient_local_with_active_mask(
3785        &self,
3786        parameters: &[f64],
3787        active_mask: &[bool],
3788    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3789        let resources = self.resources.read();
3790        if active_mask.len() != resources.active.len() {
3791            return Err(LadduError::LengthMismatch {
3792                context: "active amplitude mask".to_string(),
3793                expected: resources.active.len(),
3794                actual: active_mask.len(),
3795            });
3796        }
3797        let parameters = resources.parameter_map.assemble(parameters)?;
3798        let amplitude_len = self.amplitude_use_sites.len();
3799        let grad_dim = parameters.len();
3800        let active_indices = active_mask
3801            .iter()
3802            .enumerate()
3803            .filter_map(|(index, &active)| if active { Some(index) } else { None })
3804            .collect::<Vec<_>>();
3805        let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
3806        let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
3807        #[cfg(feature = "rayon")]
3808        {
3809            Ok(resources
3810                .caches
3811                .par_iter()
3812                .map_init(
3813                    || {
3814                        (
3815                            vec![Complex64::ZERO; amplitude_len],
3816                            vec![DVector::zeros(grad_dim); amplitude_len],
3817                            vec![Complex64::ZERO; slot_count],
3818                            vec![Complex64::ZERO; slot_count * grad_dim],
3819                        )
3820                    },
3821                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3822                        self.fill_amplitude_values_and_gradients(
3823                            amplitude_values,
3824                            gradient_values,
3825                            &active_indices,
3826                            active_mask,
3827                            &parameters,
3828                            cache,
3829                        );
3830                        lowered_runtime
3831                            .value_gradient_program()
3832                            .evaluate_value_gradient_into_flat(
3833                                amplitude_values,
3834                                gradient_values,
3835                                value_slots,
3836                                gradient_slots,
3837                                grad_dim,
3838                            )
3839                    },
3840                )
3841                .collect())
3842        }
3843        #[cfg(not(feature = "rayon"))]
3844        {
3845            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3846            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3847            let mut value_slots = vec![Complex64::ZERO; slot_count];
3848            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3849            Ok(resources
3850                .caches
3851                .iter()
3852                .map(|cache| {
3853                    self.fill_amplitude_values_and_gradients(
3854                        &mut amplitude_values,
3855                        &mut gradient_values,
3856                        &active_indices,
3857                        active_mask,
3858                        &parameters,
3859                        cache,
3860                    );
3861                    lowered_runtime
3862                        .value_gradient_program()
3863                        .evaluate_value_gradient_into_flat(
3864                            &amplitude_values,
3865                            &gradient_values,
3866                            &mut value_slots,
3867                            &mut gradient_slots,
3868                            grad_dim,
3869                        )
3870                })
3871                .collect())
3872        }
3873    }
3874
3875    /// Evaluate the stored expression and its gradient over a local subset of events in one fused pass.
3876    pub fn evaluate_with_gradient_batch_local(
3877        &self,
3878        parameters: &[f64],
3879        indices: &[usize],
3880    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3881        let resources = self.resources.read();
3882        let parameters = resources.parameter_map.assemble(parameters)?;
3883        let amplitude_len = self.amplitude_use_sites.len();
3884        let grad_dim = parameters.len();
3885        let active_indices = resources.active_indices().to_vec();
3886        let lowered_runtime = self.lowered_runtime();
3887        let value_gradient_program = lowered_runtime.value_gradient_program();
3888        let slot_count = self.expression_value_gradient_slot_count();
3889        #[cfg(feature = "rayon")]
3890        {
3891            Ok(indices
3892                .par_iter()
3893                .map_init(
3894                    || {
3895                        (
3896                            vec![Complex64::ZERO; amplitude_len],
3897                            vec![DVector::zeros(grad_dim); amplitude_len],
3898                            vec![Complex64::ZERO; slot_count],
3899                            vec![Complex64::ZERO; slot_count * grad_dim],
3900                        )
3901                    },
3902                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3903                        let cache = &resources.caches[idx];
3904                        self.fill_amplitude_values_and_gradients(
3905                            amplitude_values,
3906                            gradient_values,
3907                            &active_indices,
3908                            &resources.active,
3909                            &parameters,
3910                            cache,
3911                        );
3912                        value_gradient_program.evaluate_value_gradient_into_flat(
3913                            amplitude_values,
3914                            gradient_values,
3915                            value_slots,
3916                            gradient_slots,
3917                            grad_dim,
3918                        )
3919                    },
3920                )
3921                .collect())
3922        }
3923        #[cfg(not(feature = "rayon"))]
3924        {
3925            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3926            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3927            let mut value_slots = vec![Complex64::ZERO; slot_count];
3928            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3929            Ok(indices
3930                .iter()
3931                .map(|&idx| {
3932                    let cache = &resources.caches[idx];
3933                    self.fill_amplitude_values_and_gradients(
3934                        &mut amplitude_values,
3935                        &mut gradient_values,
3936                        &active_indices,
3937                        &resources.active,
3938                        &parameters,
3939                        cache,
3940                    );
3941                    value_gradient_program.evaluate_value_gradient_into_flat(
3942                        &amplitude_values,
3943                        &gradient_values,
3944                        &mut value_slots,
3945                        &mut gradient_slots,
3946                        grad_dim,
3947                    )
3948                })
3949                .collect())
3950        }
3951    }
3952}
3953
3954#[cfg(test)]
3955mod tests {
3956    use approx::assert_relative_eq;
3957    #[cfg(feature = "mpi")]
3958    use mpi_test::mpi_test;
3959    use serde::{Deserialize, Serialize};
3960
3961    use super::*;
3962    use crate::{
3963        amplitude::{AmplitudeID, Tags, TestAmplitude},
3964        data::{test_dataset, test_event, DatasetMetadata, Event, EventData},
3965        parameter,
3966        parameters::Parameter,
3967        resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
3968        vectors::Vec4,
3969    };
3970
    /// Test amplitude evaluating to `re + i * im` from two registered parameters.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        /// Name passed to `register_amplitude`.
        name: String,
        /// Parameter supplying the real part.
        re: Parameter,
        /// ID returned when `re` is registered.
        pid_re: ParameterID,
        /// Parameter supplying the imaginary part.
        im: Parameter,
        /// ID returned when `im` is registered.
        pid_im: ParameterID,
    }
3979
3980    impl ComplexScalar {
3981        #[allow(clippy::new_ret_no_self)]
3982        pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
3983            Self {
3984                name: name.to_string(),
3985                re,
3986                pid_re: Default::default(),
3987                im,
3988                pid_im: Default::default(),
3989            }
3990            .into_expression()
3991        }
3992    }
3993
3994    #[typetag::serde]
3995    impl Amplitude for ComplexScalar {
3996        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
3997            self.pid_re = resources.register_parameter(&self.re)?;
3998            self.pid_im = resources.register_parameter(&self.im)?;
3999            resources.register_amplitude(&self.name)
4000        }
4001
4002        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4003            Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
4004        }
4005
4006        fn compute_gradient(
4007            &self,
4008            parameters: &Parameters,
4009            _cache: &Cache,
4010            gradient: &mut DVector<Complex64>,
4011        ) {
4012            if let Some(ind) = parameters.free_index(self.pid_re) {
4013                gradient[ind] = Complex64::ONE;
4014            }
4015            if let Some(ind) = parameters.free_index(self.pid_im) {
4016                gradient[ind] = Complex64::I;
4017            }
4018        }
4019    }
4020
    /// Real-valued test amplitude equal to a single parameter; it hints itself as
    /// parameter-only.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ParameterOnlyScalar {
        /// Name passed to `register_amplitude`.
        name: String,
        /// Parameter supplying the amplitude's value.
        value: Parameter,
        /// ID returned when `value` is registered.
        pid: ParameterID,
    }
4027
4028    impl ParameterOnlyScalar {
4029        #[allow(clippy::new_ret_no_self)]
4030        pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
4031            Self {
4032                name: name.to_string(),
4033                value,
4034                pid: Default::default(),
4035            }
4036            .into_expression()
4037        }
4038    }
4039
4040    #[typetag::serde]
4041    impl Amplitude for ParameterOnlyScalar {
4042        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4043            self.pid = resources.register_parameter(&self.value)?;
4044            resources.register_amplitude(&self.name)
4045        }
4046
4047        fn dependence_hint(&self) -> ExpressionDependence {
4048            ExpressionDependence::ParameterOnly
4049        }
4050
4051        fn real_valued_hint(&self) -> bool {
4052            true
4053        }
4054
4055        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4056            Complex64::new(parameters.get(self.pid), 0.0)
4057        }
4058    }
4059
    /// Real-valued test amplitude read from a per-event cached scalar; it hints itself
    /// as cache-only.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct CacheOnlyScalar {
        /// Name passed to `register_amplitude`; also prefixes the cached scalar's name.
        name: String,
        /// ID of the cached scalar registered as `"{name}.beam_energy"`.
        beam_energy: ScalarID,
    }
4065
4066    impl CacheOnlyScalar {
4067        #[allow(clippy::new_ret_no_self)]
4068        pub fn new(name: &str) -> LadduResult<Expression> {
4069            Self {
4070                name: name.to_string(),
4071                beam_energy: Default::default(),
4072            }
4073            .into_expression()
4074        }
4075    }
4076
4077    #[typetag::serde]
4078    impl Amplitude for CacheOnlyScalar {
4079        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4080            self.beam_energy =
4081                resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
4082            resources.register_amplitude(&self.name)
4083        }
4084
4085        fn dependence_hint(&self) -> ExpressionDependence {
4086            ExpressionDependence::CacheOnly
4087        }
4088
4089        fn real_valued_hint(&self) -> bool {
4090            true
4091        }
4092
4093        fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
4094            cache.store_scalar(self.beam_energy, event.p4_at(0).e());
4095        }
4096
4097        fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
4098            Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
4099        }
4100    }
4101
    /// Selects which expression `make_deterministic_fixture` builds.
    #[derive(Clone, Copy)]
    enum DeterministicFixtureKind {
        /// `p1 * c1 + p2 * c2`: each term is a parameter-only factor times a cache-only factor.
        Separable,
        /// `p * c + m`: one parameter-times-cache term plus a `TestAmplitude` term.
        Partial,
        /// `m1 * m2`: a product of two `TestAmplitude`s.
        NonSeparable,
    }
4108
    /// Expression, dataset, and parameter values for one deterministic test case.
    struct DeterministicFixture {
        /// Expression under test.
        expression: Expression,
        /// Four-event weighted dataset shared by every fixture kind.
        dataset: Arc<Dataset>,
        /// Parameter values passed to the evaluator methods.
        parameters: Vec<f64>,
    }
4114
4115    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
4116    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
4117
4118    fn deterministic_fixture_dataset() -> Arc<Dataset> {
4119        let metadata = Arc::new(DatasetMetadata::default());
4120        let events = vec![
4121            Arc::new(EventData {
4122                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
4123                aux: vec![],
4124                weight: 0.5,
4125            }),
4126            Arc::new(EventData {
4127                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
4128                aux: vec![],
4129                weight: -1.25,
4130            }),
4131            Arc::new(EventData {
4132                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
4133                aux: vec![],
4134                weight: 2.0,
4135            }),
4136            Arc::new(EventData {
4137                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
4138                aux: vec![],
4139                weight: 0.75,
4140            }),
4141        ];
4142        Arc::new(Dataset::new_with_metadata(events, metadata))
4143    }
4144
4145    fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
4146        let dataset = deterministic_fixture_dataset();
4147        match kind {
4148            DeterministicFixtureKind::Separable => {
4149                let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
4150                    .expect("separable p1 should build");
4151                let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
4152                    .expect("separable p2 should build");
4153                let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
4154                let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
4155                DeterministicFixture {
4156                    expression: (&p1 * &c1) + &(&p2 * &c2),
4157                    dataset,
4158                    parameters: vec![0.4, -0.3],
4159                }
4160            }
4161            DeterministicFixtureKind::Partial => {
4162                let p =
4163                    ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
4164                let c = CacheOnlyScalar::new("c").expect("partial c should build");
4165                let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
4166                    .expect("partial m should build");
4167                DeterministicFixture {
4168                    expression: (&p * &c) + &m,
4169                    dataset,
4170                    parameters: vec![0.55, 0.2, -0.15],
4171                }
4172            }
4173            DeterministicFixtureKind::NonSeparable => {
4174                let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
4175                    .expect("non-separable m1 should build");
4176                let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
4177                    .expect("non-separable m2 should build");
4178                DeterministicFixture {
4179                    expression: &m1 * &m2,
4180                    dataset,
4181                    parameters: vec![0.25, -0.4, 0.6, 0.1],
4182                }
4183            }
4184        }
4185    }
4186
4187    fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
4188        let evaluator = fixture
4189            .expression
4190            .load(&fixture.dataset)
4191            .expect("fixture evaluator should load");
4192        let expected_value = evaluator
4193            .evaluate_local(&fixture.parameters)
4194            .expect("evaluation should succeed")
4195            .iter()
4196            .zip(fixture.dataset.weights_local().iter())
4197            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
4198        let expected_gradient = evaluator
4199            .evaluate_gradient_local(&fixture.parameters)
4200            .expect("evaluation should succeed")
4201            .iter()
4202            .zip(fixture.dataset.weights_local().iter())
4203            .fold(
4204                DVector::zeros(fixture.parameters.len()),
4205                |mut accum, (gradient, event)| {
4206                    accum += gradient.map(|value| value.re).scale(*event);
4207                    accum
4208                },
4209            );
4210        let actual_value = evaluator
4211            .evaluate_weighted_value_sum_local(&fixture.parameters)
4212            .expect("evaluation should succeed");
4213        let actual_gradient = evaluator
4214            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4215            .expect("evaluation should succeed");
4216        assert_relative_eq!(
4217            actual_value,
4218            expected_value,
4219            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4220            max_relative = DETERMINISTIC_STRICT_REL_TOL
4221        );
4222        for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
4223            assert_relative_eq!(
4224                *actual_item,
4225                *expected_item,
4226                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4227                max_relative = DETERMINISTIC_STRICT_REL_TOL
4228            );
4229        }
4230    }
4231    fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
4232        let evaluator = fixture
4233            .expression
4234            .load(&fixture.dataset)
4235            .expect("fixture evaluator should load");
4236        let state = {
4237            let resources = evaluator.resources.read();
4238            evaluator.ensure_cached_integral_cache_state(&resources)
4239        }
4240        .expect("state should be available");
4241        assert!(
4242            !state.values.is_empty(),
4243            "fixture should exercise cached normalization terms"
4244        );
4245        assert!(
4246            !state.execution_sets.residual_amplitudes.is_empty(),
4247            "fixture should exercise residual normalization amplitudes"
4248        );
4249
4250        let (residual_value_sum, cached_value_sum) = evaluator
4251            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4252            .expect("evaluation should succeed");
4253        assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4254        assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4255        let combined_value = evaluator
4256            .evaluate_weighted_value_sum_local(&fixture.parameters)
4257            .expect("evaluation should succeed");
4258        assert_relative_eq!(
4259            residual_value_sum + cached_value_sum,
4260            combined_value,
4261            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4262            max_relative = DETERMINISTIC_STRICT_REL_TOL
4263        );
4264
4265        let (residual_gradient_sum, cached_gradient_sum) = evaluator
4266            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4267            .expect("evaluation should succeed");
4268        let combined_gradient = evaluator
4269            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4270            .expect("evaluation should succeed");
4271        assert!(residual_gradient_sum
4272            .iter()
4273            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4274        assert!(cached_gradient_sum
4275            .iter()
4276            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4277        for ((residual_item, cached_item), combined_item) in residual_gradient_sum
4278            .iter()
4279            .zip(cached_gradient_sum.iter())
4280            .zip(combined_gradient.iter())
4281        {
4282            assert_relative_eq!(
4283                residual_item + cached_item,
4284                *combined_item,
4285                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4286                max_relative = DETERMINISTIC_STRICT_REL_TOL
4287            );
4288        }
4289    }
4290
4291    #[test]
4292    fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
4293        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4294        let evaluator = fixture
4295            .expression
4296            .load(&fixture.dataset)
4297            .expect("fixture evaluator should load");
4298        let original_mask = evaluator.active_mask();
4299
4300        let original_value = evaluator
4301            .evaluate_weighted_value_sum_local(&fixture.parameters)
4302            .expect("evaluation should succeed");
4303
4304        evaluator.isolate_many(&["p", "c"]);
4305        assert_ne!(evaluator.active_mask(), original_mask);
4306
4307        evaluator
4308            .set_active_mask(&original_mask)
4309            .expect("original fixture active mask should restore");
4310        assert_eq!(evaluator.active_mask(), original_mask);
4311        let actual_value = evaluator
4312            .evaluate_weighted_value_sum_local(&fixture.parameters)
4313            .expect("evaluation should succeed");
4314        assert_relative_eq!(
4315            actual_value,
4316            original_value,
4317            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4318            max_relative = DETERMINISTIC_STRICT_REL_TOL
4319        );
4320    }
4321
4322    #[test]
4323    fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
4324        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4325        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4326        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4327
4328        assert_weighted_sum_matches_eventwise_baseline(&separable);
4329        assert_weighted_sum_matches_eventwise_baseline(&partial);
4330        assert_weighted_sum_matches_eventwise_baseline(&non_separable);
4331    }
4332    #[test]
4333    fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
4334        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4335        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4336        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4337
4338        let separable_evaluator = separable
4339            .expression
4340            .load(&separable.dataset)
4341            .expect("separable evaluator should load");
4342        let partial_evaluator = partial
4343            .expression
4344            .load(&partial.dataset)
4345            .expect("partial evaluator should load");
4346        let non_separable_evaluator = non_separable
4347            .expression
4348            .load(&non_separable.dataset)
4349            .expect("non-separable evaluator should load");
4350
4351        assert_eq!(
4352            separable_evaluator
4353                .expression_precomputed_cached_integrals()
4354                .expect("integrals should be computed")
4355                .len(),
4356            2
4357        );
4358        assert_eq!(
4359            partial_evaluator
4360                .expression_precomputed_cached_integrals()
4361                .expect("integrals should be computed")
4362                .len(),
4363            1
4364        );
4365        assert!(non_separable_evaluator
4366            .expression_precomputed_cached_integrals()
4367            .expect("integrals should be computed")
4368            .is_empty());
4369    }
4370    #[test]
4371    fn test_partial_fixture_combined_normalization_components_match_total() {
4372        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4373        assert_mixed_normalization_components_match_combined_path(&partial);
4374    }
4375    #[test]
4376    fn test_non_separable_fixture_normalization_components_stay_residual_only() {
4377        let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4378        let evaluator = fixture
4379            .expression
4380            .load(&fixture.dataset)
4381            .expect("fixture evaluator should load");
4382        let resources = evaluator.resources.read();
4383        let state = evaluator
4384            .ensure_cached_integral_cache_state(&resources)
4385            .expect("state should be available");
4386        assert!(state.values.is_empty());
4387
4388        let (residual_value_sum, cached_value_sum) = evaluator
4389            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4390            .expect("evaluation should succeed");
4391        assert_relative_eq!(
4392            cached_value_sum,
4393            0.0,
4394            epsilon = DETERMINISTIC_STRICT_ABS_TOL
4395        );
4396        assert_relative_eq!(
4397            residual_value_sum,
4398            evaluator
4399                .evaluate_weighted_value_sum_local(&fixture.parameters)
4400                .expect("evaluation should succeed"),
4401            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4402            max_relative = DETERMINISTIC_STRICT_REL_TOL
4403        );
4404
4405        let (residual_gradient_sum, cached_gradient_sum) = evaluator
4406            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4407            .expect("evaluation should succeed");
4408        assert!(cached_gradient_sum
4409            .iter()
4410            .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
4411        let combined_gradient = evaluator
4412            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4413            .expect("evaluation should succeed");
4414        for (residual_item, combined_item) in
4415            residual_gradient_sum.iter().zip(combined_gradient.iter())
4416        {
4417            assert_relative_eq!(
4418                *residual_item,
4419                *combined_item,
4420                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4421                max_relative = DETERMINISTIC_STRICT_REL_TOL
4422            );
4423        }
4424    }
4425
4426    #[test]
4427    fn test_batch_evaluation() {
4428        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
4429        let mut event1 = test_event();
4430        event1.p4s[0].t = 10.0;
4431        let mut event2 = test_event();
4432        event2.p4s[0].t = 11.0;
4433        let mut event3 = test_event();
4434        event3.p4s[0].t = 12.0;
4435        let dataset = Arc::new(Dataset::new_with_metadata(
4436            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
4437            Arc::new(DatasetMetadata::default()),
4438        ));
4439        let evaluator = expr.load(&dataset).unwrap();
4440        let result = evaluator
4441            .evaluate_batch(&[1.1, 2.2], &[0, 2])
4442            .expect("evaluation should succeed");
4443        assert_eq!(result.len(), 2);
4444        assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
4445        assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
4446        let result_grad = evaluator
4447            .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
4448            .expect("evaluation should succeed");
4449        assert_eq!(result_grad.len(), 2);
4450        assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
4451        assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
4452        assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
4453        assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
4454    }
4455
4456    #[test]
4457    fn test_load_compiles_expression_ir_once() {
4458        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4459            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4460        .norm_sqr();
4461        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4462        let evaluator = expr.load(&dataset).unwrap();
4463        assert!(evaluator.expression_slot_count() > 0);
4464    }
4465    #[test]
4466    fn test_expression_ir_value_matches_lowered_runtime() {
4467        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4468            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4469            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4470        .conj()
4471        .norm_sqr();
4472        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4473        let evaluator = expr.load(&dataset).unwrap();
4474        let resources = evaluator.resources.read();
4475        let parameters = resources
4476            .parameter_map
4477            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4478            .expect("parameters should assemble");
4479        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4480        evaluator.fill_amplitude_values(
4481            &mut amplitude_values,
4482            resources.active_indices(),
4483            &parameters,
4484            &resources.caches[0],
4485        );
4486        let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4487        let lowered_runtime = evaluator.lowered_runtime();
4488        let lowered_program = lowered_runtime.value_program();
4489        let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4490        let lowered_value =
4491            evaluator.evaluate_expression_value_with_scratch(&amplitude_values, &mut ir_slots);
4492        let direct_lowered_value =
4493            lowered_program.evaluate_into(&amplitude_values, &mut lowered_slots);
4494        let ir_value = evaluator
4495            .expression_ir()
4496            .evaluate_into(&amplitude_values, &mut ir_slots);
4497        assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
4498        assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
4499        assert_relative_eq!(lowered_value.re, ir_value.re);
4500        assert_relative_eq!(lowered_value.im, ir_value.im);
4501    }
4502    #[test]
4503    fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
4504        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
4505            .unwrap()
4506            .norm_sqr();
4507        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4508        let evaluator = expr.load(&dataset).unwrap();
4509        let lowered_runtime = evaluator.lowered_runtime();
4510        assert_eq!(
4511            lowered_runtime.value_program().kind(),
4512            lowered::LoweredProgramKind::Value
4513        );
4514        assert_eq!(
4515            lowered_runtime.gradient_program().kind(),
4516            lowered::LoweredProgramKind::Gradient
4517        );
4518        assert_eq!(
4519            lowered_runtime.value_gradient_program().kind(),
4520            lowered::LoweredProgramKind::ValueGradient
4521        );
4522    }
4523    #[test]
4524    fn test_expression_ir_gradient_matches_lowered_runtime() {
4525        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4526            * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4527        .norm_sqr();
4528        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4529        let evaluator = expr.load(&dataset).unwrap();
4530        let resources = evaluator.resources.read();
4531        let parameters = resources
4532            .parameter_map
4533            .assemble(&[1.0, 0.25, -0.8, 0.5])
4534            .expect("parameters should assemble");
4535        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4536        evaluator.fill_amplitude_values(
4537            &mut amplitude_values,
4538            resources.active_indices(),
4539            &parameters,
4540            &resources.caches[0],
4541        );
4542        let mut active_mask = vec![false; evaluator.amplitudes.len()];
4543        for &index in resources.active_indices() {
4544            active_mask[index] = true;
4545        }
4546        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4547            .map(|_| DVector::zeros(parameters.len()))
4548            .collect::<Vec<_>>();
4549        evaluator.fill_amplitude_gradients(
4550            &mut amplitude_gradients,
4551            &active_mask,
4552            &parameters,
4553            &resources.caches[0],
4554        );
4555        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4556        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4557            (0..evaluator.expression_ir().node_count())
4558                .map(|_| DVector::zeros(parameters.len()))
4559                .collect();
4560        let lowered_runtime = evaluator.lowered_runtime();
4561        let lowered_program = lowered_runtime.gradient_program();
4562        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4563        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4564            .scratch_slots())
4565            .map(|_| DVector::zeros(parameters.len()))
4566            .collect();
4567        let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
4568            &amplitude_values,
4569            &amplitude_gradients,
4570            &mut ir_value_slots,
4571            &mut ir_gradient_slots,
4572        );
4573        let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
4574            &amplitude_values,
4575            &amplitude_gradients,
4576            &mut ir_value_slots,
4577            &mut ir_gradient_slots,
4578        );
4579        let lowered_gradient = lowered_program.evaluate_gradient_into(
4580            &amplitude_values,
4581            &amplitude_gradients,
4582            &mut lowered_value_slots,
4583            &mut lowered_gradient_slots,
4584        );
4585        for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
4586            assert_relative_eq!(active.re, lowered.re);
4587            assert_relative_eq!(active.im, lowered.im);
4588        }
4589        for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
4590            assert_relative_eq!(lowered.re, ir.re);
4591            assert_relative_eq!(lowered.im, ir.im);
4592        }
4593    }
4594    #[test]
4595    fn test_expression_ir_value_gradient_matches_lowered_runtime() {
4596        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4597            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4598            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4599        .norm_sqr();
4600        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4601        let evaluator = expr.load(&dataset).unwrap();
4602        let resources = evaluator.resources.read();
4603        let parameters = resources
4604            .parameter_map
4605            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4606            .expect("parameters should assemble");
4607        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4608        evaluator.fill_amplitude_values(
4609            &mut amplitude_values,
4610            resources.active_indices(),
4611            &parameters,
4612            &resources.caches[0],
4613        );
4614        let mut active_mask = vec![false; evaluator.amplitudes.len()];
4615        for &index in resources.active_indices() {
4616            active_mask[index] = true;
4617        }
4618        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4619            .map(|_| DVector::zeros(parameters.len()))
4620            .collect::<Vec<_>>();
4621        evaluator.fill_amplitude_gradients(
4622            &mut amplitude_gradients,
4623            &active_mask,
4624            &parameters,
4625            &resources.caches[0],
4626        );
4627        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4628        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4629            (0..evaluator.expression_ir().node_count())
4630                .map(|_| DVector::zeros(parameters.len()))
4631                .collect();
4632        let lowered_runtime = evaluator.lowered_runtime();
4633        let lowered_program = lowered_runtime.value_gradient_program();
4634        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4635        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4636            .scratch_slots())
4637            .map(|_| DVector::zeros(parameters.len()))
4638            .collect();
4639
4640        let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
4641            &amplitude_values,
4642            &amplitude_gradients,
4643            &mut ir_value_slots,
4644            &mut ir_gradient_slots,
4645        );
4646        let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
4647            &amplitude_values,
4648            &amplitude_gradients,
4649            &mut ir_value_slots,
4650            &mut ir_gradient_slots,
4651        );
4652        let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
4653            &amplitude_values,
4654            &amplitude_gradients,
4655            &mut lowered_value_slots,
4656            &mut lowered_gradient_slots,
4657        );
4658
4659        assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
4660        assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
4661        for (active, lowered) in active_value_gradient
4662            .1
4663            .iter()
4664            .zip(lowered_value_gradient.1.iter())
4665        {
4666            assert_relative_eq!(active.re, lowered.re);
4667            assert_relative_eq!(active.im, lowered.im);
4668        }
4669        assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
4670        assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
4671        for (lowered, ir) in lowered_value_gradient
4672            .1
4673            .iter()
4674            .zip(ir_value_gradient.1.iter())
4675        {
4676            assert_relative_eq!(lowered.re, ir.re);
4677            assert_relative_eq!(lowered.im, ir.im);
4678        }
4679    }
4680    #[test]
4681    fn test_expression_runtime_diagnostics_reports_lowered_programs() {
4682        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4683        let evaluator = fixture
4684            .expression
4685            .load(&fixture.dataset)
4686            .expect("fixture evaluator should load");
4687
4688        let diagnostics = evaluator.expression_runtime_diagnostics();
4689        assert!(diagnostics.ir_planning_enabled);
4690        assert!(diagnostics.lowered_value_program_present);
4691        assert!(diagnostics.lowered_gradient_program_present);
4692        assert!(diagnostics.lowered_value_gradient_program_present);
4693        assert!(diagnostics.residual_runtime_present);
4694        assert_eq!(
4695            diagnostics.specialization_status,
4696            Some(ExpressionSpecializationStatus {
4697                origin: ExpressionSpecializationOrigin::InitialLoad,
4698            })
4699        );
4700    }
4701    #[test]
4702    fn test_expression_runtime_diagnostics_reports_specialization_origin() {
4703        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4704        let evaluator = fixture
4705            .expression
4706            .load(&fixture.dataset)
4707            .expect("fixture evaluator should load");
4708
4709        assert_eq!(
4710            evaluator
4711                .expression_runtime_diagnostics()
4712                .specialization_status,
4713            Some(ExpressionSpecializationStatus {
4714                origin: ExpressionSpecializationOrigin::InitialLoad,
4715            })
4716        );
4717
4718        evaluator.isolate_many(&["p"]);
4719        assert_eq!(
4720            evaluator
4721                .expression_runtime_diagnostics()
4722                .specialization_status,
4723            Some(ExpressionSpecializationStatus {
4724                origin: ExpressionSpecializationOrigin::CacheMissRebuild,
4725            })
4726        );
4727    }
4728    #[test]
4729    fn test_compiled_expression_display_reports_dag_refs() {
4730        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4731        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4732        let term = &a * &b;
4733        let expr = &term + &term;
4734        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4735        let evaluator = expr.load(&dataset).unwrap();
4736
4737        let compiled = evaluator.compiled_expression();
4738        let display = compiled.to_string();
4739
4740        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4741        assert!(display.contains("#"));
4742        assert!(display.contains("+"));
4743        assert!(display.contains("×"));
4744        assert!(display.contains("a(id=0)"));
4745        assert!(display.contains("b(id=1)"));
4746        assert!(display.contains("(ref)"));
4747    }
4748
4749    #[test]
4750    fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
4751        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4752        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4753        let term = &a * &b;
4754        let expr = &term + &term;
4755
4756        let compiled = expr.compiled_expression();
4757        let display = compiled.to_string();
4758
4759        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4760        assert!(display.contains("#"));
4761        assert!(display.contains("+"));
4762        assert!(display.contains("×"));
4763        assert!(display.contains("a(id=0)"));
4764        assert!(display.contains("b(id=1)"));
4765        assert!(display.contains("(ref)"));
4766    }
4767
4768    #[test]
4769    fn test_compiled_expression_display_uses_current_active_mask() {
4770        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4771            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4772        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4773        let evaluator = expr.load(&dataset).unwrap();
4774        evaluator.deactivate("b");
4775
4776        let compiled = evaluator.compiled_expression().to_string();
4777
4778        assert!(compiled.contains("a(id=0)"));
4779        assert!(!compiled.contains("b(id=1)"));
4780        assert!(!compiled.contains("const 0"));
4781        assert!(!compiled.contains("+"));
4782    }
4783
4784    fn assert_compiled_single_amplitude(expr: &Expression, expected_label: &str) {
4785        let compiled = expr.compiled_expression();
4786        assert_eq!(compiled.nodes().len(), 1);
4787        assert_eq!(compiled.root(), 0);
4788        match &compiled.nodes()[0] {
4789            CompiledExpressionNode::Amplitude { index, name } => {
4790                assert_eq!(*index, 0);
4791                assert_eq!(name, expected_label);
4792            }
4793            node => panic!("expected one amplitude node, got {node:?}"),
4794        }
4795    }
4796
4797    fn assert_compiled_constant(expr: &Expression, expected: Complex64) {
4798        let compiled = expr.compiled_expression();
4799        assert_eq!(compiled.nodes().len(), 1);
4800        assert_eq!(compiled.root(), 0);
4801        match compiled.nodes()[0] {
4802            CompiledExpressionNode::Constant(value) => {
4803                assert_relative_eq!(value.re, expected.re);
4804                assert_relative_eq!(value.im, expected.im);
4805            }
4806            ref node => panic!("expected one constant node, got {node:?}"),
4807        }
4808    }
4809
4810    #[test]
4811    fn test_compiled_expression_simplifies_arithmetic_identities() {
4812        let amp = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4813        let zero = Expression::zero();
4814        let one = Expression::one();
4815
4816        assert_compiled_single_amplitude(&(&amp + &zero), "a");
4817        assert_compiled_single_amplitude(&(&zero + &amp), "a");
4818        assert_compiled_single_amplitude(&(&amp - &zero), "a");
4819        assert_compiled_single_amplitude(&(&amp * &one), "a");
4820        assert_compiled_single_amplitude(&(&one * &amp), "a");
4821        assert_compiled_single_amplitude(&(&amp / &one), "a");
4822        assert_compiled_single_amplitude(&amp.pow(&one), "a");
4823        assert_compiled_single_amplitude(&amp.powi(1), "a");
4824        assert_compiled_single_amplitude(&amp.powf(1.0), "a");
4825
4826        let times_zero = &amp * &zero;
4827        assert_compiled_constant(&times_zero, Complex64::ZERO);
4828        assert!(times_zero.parameters().contains_key("ar"));
4829        assert!(times_zero.parameters().contains_key("ai"));
4830
4831        assert_compiled_constant(&(&zero * &amp), Complex64::ZERO);
4832        assert_compiled_constant(&(&zero / &Expression::from(2.0)), Complex64::ZERO);
4833        assert_compiled_constant(&amp.powi(0), Complex64::ONE);
4834        assert_compiled_constant(
4835            &Expression::from(2.0).pow(&Expression::zero()),
4836            Complex64::ONE,
4837        );
4838        assert_compiled_constant(&Expression::from(2.0).powf(0.0), Complex64::ONE);
4839
4840        let unsafe_zero_division = (&zero / &amp).compiled_expression().to_string();
4841        assert!(unsafe_zero_division.contains("÷"));
4842        assert!(unsafe_zero_division.contains("a(id=0)"));
4843    }
4844
4845    #[test]
4846    fn test_compiled_expression_folds_unary_constant_functions() {
4847        assert_compiled_constant(&Expression::from(0.0).exp(), Complex64::ONE);
4848        assert_compiled_constant(&Expression::from(0.0).sin(), Complex64::ZERO);
4849        assert_compiled_constant(&Expression::from(0.0).cos(), Complex64::ONE);
4850        assert_compiled_constant(&Expression::from(1.0).log(), Complex64::ZERO);
4851        assert_compiled_constant(&Expression::from(4.0).sqrt(), Complex64::new(2.0, 0.0));
4852        assert_compiled_constant(&Expression::from(0.0).cis(), Complex64::ONE);
4853    }
4854
4855    #[test]
4856    fn test_evaluator_expression_reconstructs_expression() {
4857        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4858        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4859        let evaluator = expr.load(&dataset).unwrap();
4860
4861        assert_eq!(
4862            evaluator.expression().compiled_expression(),
4863            expr.compiled_expression()
4864        );
4865    }
4866
4867    #[test]
4868    fn test_active_mask_override_ignores_current_ir_specialization() {
4869        let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
4870            .unwrap()
4871            .norm_sqr();
4872        let dataset = Arc::new(test_dataset());
4873        let evaluator = expr.load(&dataset).unwrap();
4874        let params = vec![2.0];
4875
4876        evaluator.deactivate("amp");
4877        assert_eq!(
4878            evaluator
4879                .evaluate(&params)
4880                .expect("evaluation should succeed")[0],
4881            Complex64::new(0.0, 0.0)
4882        );
4883
4884        let overridden = evaluator
4885            .evaluate_local_with_active_mask(&params, &[true])
4886            .unwrap();
4887        assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
4888
4889        let overridden_fused = evaluator
4890            .evaluate_with_gradient_local_with_active_mask(&params, &[true])
4891            .unwrap();
4892        assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
4893        assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
4894    }
4895    #[test]
4896    fn test_expression_ir_dependence_diagnostics_surface() {
4897        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4898            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4899        .norm_sqr();
4900        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4901        let evaluator = expr.load(&dataset).unwrap();
4902        let annotations = evaluator
4903            .expression_node_dependence_annotations()
4904            .expect("annotations should exist");
4905        assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
4906        assert!(annotations
4907            .iter()
4908            .all(|dependence| *dependence == ExpressionDependence::Mixed));
4909        assert_eq!(
4910            evaluator
4911                .expression_root_dependence()
4912                .expect("root dependence should exist"),
4913            ExpressionDependence::Mixed
4914        );
4915    }
4916    #[test]
4917    fn test_expression_ir_default_dependence_hint_is_mixed() {
4918        let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
4919        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4920        let evaluator = expr.load(&dataset).unwrap();
4921        assert_eq!(
4922            evaluator
4923                .expression_root_dependence()
4924                .expect("root dependence should exist"),
4925            ExpressionDependence::Mixed
4926        );
4927    }
4928    #[test]
4929    fn test_expression_ir_parameter_only_dependence_hint_propagates() {
4930        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
4931        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4932        let evaluator = expr.load(&dataset).unwrap();
4933        assert_eq!(
4934            evaluator
4935                .expression_root_dependence()
4936                .expect("root dependence should exist"),
4937            ExpressionDependence::ParameterOnly
4938        );
4939    }
4940    #[test]
4941    fn test_expression_ir_cache_only_dependence_hint_propagates() {
4942        let expr = CacheOnlyScalar::new("k").unwrap();
4943        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4944        let evaluator = expr.load(&dataset).unwrap();
4945        assert_eq!(
4946            evaluator
4947                .expression_root_dependence()
4948                .expect("root dependence should exist"),
4949            ExpressionDependence::CacheOnly
4950        );
4951    }
4952    #[test]
4953    fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
4954        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4955            .unwrap()
4956            .imag();
4957        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4958        let evaluator = expr.load(&dataset).unwrap();
4959        let ir = evaluator.expression_ir();
4960
4961        assert!(matches!(
4962            ir.nodes()[ir.root()],
4963            ir::IrNode::Constant(value) if value == Complex64::ZERO
4964        ));
4965        assert_eq!(
4966            evaluator
4967                .evaluate(&[2.5])
4968                .expect("evaluation should succeed")[0],
4969            Complex64::ZERO
4970        );
4971    }
4972    #[test]
4973    fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
4974        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4975            .unwrap()
4976            .conj();
4977        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4978        let evaluator = expr.load(&dataset).unwrap();
4979        let ir = evaluator.expression_ir();
4980
4981        assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
4982        assert_eq!(
4983            evaluator
4984                .evaluate(&[2.5])
4985                .expect("evaluation should succeed")[0],
4986            Complex64::new(2.5, 0.0)
4987        );
4988    }
4989    #[test]
4990    fn test_expression_ir_dependence_warnings_surface() {
4991        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
4992            + &CacheOnlyScalar::new("k").unwrap();
4993        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4994        let evaluator = expr.load(&dataset).unwrap();
4995        assert!(evaluator
4996            .expression_dependence_warnings()
4997            .expect("warnings should exist")
4998            .iter()
4999            .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
5000    }
5001    #[test]
5002    fn test_expression_ir_normalization_plan_explain_surface() {
5003        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5004            * &CacheOnlyScalar::new("k").unwrap();
5005        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5006        let evaluator = expr.load(&dataset).unwrap();
5007        let explain = evaluator
5008            .expression_normalization_plan_explain()
5009            .expect("plan should exist");
5010        assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
5011        assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
5012        assert_eq!(
5013            explain.cached_separable_nodes,
5014            explain.separable_mul_candidate_nodes
5015        );
5016        assert!(explain.residual_terms.iter().all(|index| {
5017            !explain
5018                .separable_mul_candidate_nodes
5019                .iter()
5020                .any(|candidate| candidate == index)
5021        }));
5022    }
5023    #[test]
5024    fn test_expression_ir_normalization_execution_sets_surface() {
5025        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5026            * &CacheOnlyScalar::new("k").unwrap();
5027        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5028        let evaluator = expr.load(&dataset).unwrap();
5029        let sets = evaluator
5030            .expression_normalization_execution_sets()
5031            .expect("sets should exist");
5032        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5033        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5034        assert!(sets.residual_amplitudes.is_empty());
5035    }
5036    #[test]
5037    fn test_expression_ir_normalization_execution_sets_partial_surface() {
5038        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5039            * &CacheOnlyScalar::new("k").unwrap())
5040            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5041        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5042        let evaluator = expr.load(&dataset).unwrap();
5043        let sets = evaluator
5044            .expression_normalization_execution_sets()
5045            .expect("sets should exist");
5046        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5047        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5048        assert_eq!(sets.residual_amplitudes, vec![2]);
5049    }
5050    #[test]
5051    fn test_expression_ir_precomputed_cached_integrals_at_load() {
5052        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5053            * &CacheOnlyScalar::new("k").unwrap();
5054        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5055        let evaluator = expr.load(&dataset).unwrap();
5056        let precomputed = evaluator
5057            .expression_precomputed_cached_integrals()
5058            .expect("integrals should exist");
5059        assert_eq!(precomputed.len(), 1);
5060        let cache_reference = CacheOnlyScalar::new("k_ref")
5061            .unwrap()
5062            .load(&dataset)
5063            .unwrap();
5064        let cache_values = cache_reference
5065            .evaluate_local(&[])
5066            .expect("evaluation should succeed");
5067        let expected_weighted_sum = cache_values
5068            .iter()
5069            .zip(dataset.weights_local().iter())
5070            .fold(Complex64::ZERO, |acc, (value, event)| {
5071                acc + (*value * *event)
5072            });
5073        assert_relative_eq!(
5074            precomputed[0].weighted_cache_sum.re,
5075            expected_weighted_sum.re
5076        );
5077        assert_relative_eq!(
5078            precomputed[0].weighted_cache_sum.im,
5079            expected_weighted_sum.im
5080        );
5081    }
5082    #[test]
5083    fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
5084        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5085            * &CacheOnlyScalar::new("k").unwrap();
5086        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5087        let evaluator = expr.load(&dataset).unwrap();
5088        assert!(evaluator
5089            .expression_precomputed_cached_integrals()
5090            .expect("integrals should exist")
5091            .is_empty());
5092    }
5093    #[test]
5094    fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
5095        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5096            * &CacheOnlyScalar::new("k").unwrap();
5097        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5098        let evaluator = expr.load(&dataset).unwrap();
5099        assert_eq!(
5100            evaluator
5101                .expression_precomputed_cached_integrals()
5102                .expect("integrals should exist")
5103                .len(),
5104            1
5105        );
5106
5107        evaluator.isolate_many(&["p"]);
5108        assert!(evaluator
5109            .expression_precomputed_cached_integrals()
5110            .expect("integrals should exist")
5111            .is_empty());
5112    }
5113    #[test]
5114    fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
5115        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5116            * &CacheOnlyScalar::new("k").unwrap();
5117        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5118        let mut evaluator = expr.load(&dataset).unwrap();
5119        drop(dataset);
5120        let before = evaluator
5121            .expression_precomputed_cached_integrals()
5122            .expect("integrals should exist");
5123        assert_eq!(before.len(), 1);
5124
5125        Arc::get_mut(&mut evaluator.dataset)
5126            .expect("evaluator should own dataset Arc in this test")
5127            .clear_events_local();
5128        let after = evaluator
5129            .expression_precomputed_cached_integrals()
5130            .expect("integrals should exist");
5131        assert_eq!(after.len(), 1);
5132        assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
5133        assert!(before[0].weighted_cache_sum != after[0].weighted_cache_sum);
5134    }
5135    #[test]
5136    fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
5137        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5138            * &CacheOnlyScalar::new("k").unwrap();
5139        let dataset = Arc::new(Dataset::new(vec![
5140            Arc::new(test_event()),
5141            Arc::new(test_event()),
5142        ]));
5143        let evaluator = expr.load(&dataset).unwrap();
5144        let cached_integrals = evaluator
5145            .expression_precomputed_cached_integrals()
5146            .expect("integrals should exist");
5147        assert_eq!(cached_integrals.len(), 1);
5148        let gradient_terms = evaluator
5149            .expression_precomputed_cached_integral_gradient_terms(&[1.25])
5150            .expect("evaluation should succeed");
5151        assert_eq!(gradient_terms.len(), 1);
5152        assert_eq!(gradient_terms[0].weighted_gradient.len(), 1);
5153        assert_relative_eq!(
5154            gradient_terms[0].weighted_gradient[0].re,
5155            cached_integrals[0].weighted_cache_sum.re,
5156            epsilon = 1e-6
5157        );
5158        assert_relative_eq!(
5159            gradient_terms[0].weighted_gradient[0].im,
5160            cached_integrals[0].weighted_cache_sum.im,
5161            epsilon = 1e-6
5162        );
5163    }
5164    #[test]
5165    fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
5166        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5167            * &CacheOnlyScalar::new("k").unwrap();
5168        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5169        let evaluator = expr.load(&dataset).unwrap();
5170        assert!(evaluator
5171            .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
5172            .expect("evaluation should succeed")
5173            .is_empty());
5174    }
5175    #[test]
5176    fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
5177        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5178            * &CacheOnlyScalar::new("k").unwrap())
5179            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5180        let dataset = Arc::new(test_dataset());
5181        let evaluator = expr.load(&dataset).unwrap();
5182        let resources = evaluator.resources.read();
5183        let state = evaluator
5184            .ensure_cached_integral_cache_state(&resources)
5185            .expect("state should be available");
5186        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5187        let parameters = resources
5188            .parameter_map
5189            .assemble(&[0.55, 0.2, -0.15])
5190            .expect("parameters should assemble");
5191
5192        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5193        evaluator.fill_amplitude_values(
5194            &mut amplitude_values,
5195            &state.execution_sets.cached_parameter_amplitudes,
5196            &parameters,
5197            &resources.caches[0],
5198        );
5199        let cached_value_ir =
5200            evaluator.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values);
5201        let cached_value_lowered = evaluator
5202            .evaluate_cached_weighted_value_sum_lowered(
5203                &state,
5204                lowered_artifacts.as_ref(),
5205                &amplitude_values,
5206            )
5207            .expect("cached value lowering should succeed");
5208        assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);
5209
5210        let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
5211        for &index in &state.execution_sets.cached_parameter_amplitudes {
5212            cached_parameter_mask[index] = true;
5213        }
5214        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5215            .map(|_| DVector::zeros(parameters.len()))
5216            .collect::<Vec<_>>();
5217        evaluator.fill_amplitude_gradients(
5218            &mut amplitude_gradients,
5219            &cached_parameter_mask,
5220            &parameters,
5221            &resources.caches[0],
5222        );
5223        let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
5224            &state,
5225            &amplitude_values,
5226            &amplitude_gradients,
5227            parameters.len(),
5228        );
5229        let cached_gradient_lowered = evaluator
5230            .evaluate_cached_weighted_gradient_sum_lowered(
5231                &state,
5232                lowered_artifacts.as_ref(),
5233                &amplitude_values,
5234                &amplitude_gradients,
5235                parameters.len(),
5236            )
5237            .expect("cached gradient lowering should succeed");
5238        for (lowered, ir) in cached_gradient_lowered
5239            .iter()
5240            .zip(cached_gradient_ir.iter())
5241        {
5242            assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
5243        }
5244    }
    // Checks that the lowered residual runtime (its value program and its gradient program)
    // agrees with the IR residual evaluation for `(p * k) + m`, using the first event's cache.
    #[test]
    fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(test_dataset());
        let evaluator = expr.load(&dataset).unwrap();
        // Keep the resources read guard alive for the whole test: the parameter map, the
        // per-event caches, and the cached-integral state are all borrowed through it.
        let resources = evaluator.resources.read();
        let state = evaluator
            .ensure_cached_integral_cache_state(&resources)
            .expect("state should be available");
        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
        // One value per free parameter (presumably p, mr, mi in that order — TODO confirm).
        let parameters = resources
            .parameter_map
            .assemble(&[0.55, 0.2, -0.15])
            .expect("parameters should assemble");

        // Evaluate only the residual amplitudes for event 0; every slot starts at zero.
        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
        evaluator.fill_amplitude_values(
            &mut amplitude_values,
            &state.execution_sets.residual_amplitudes,
            &parameters,
            &resources.caches[0],
        );
        let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &amplitude_values);
        // The residual runtime must exist for this expression (the `expect` panics otherwise);
        // its value program runs into a scratch buffer sized by the program itself.
        let residual_program = lowered_artifacts
            .residual_runtime
            .as_ref()
            .map(|runtime| runtime.value_program())
            .expect("residual value lowering should succeed");
        let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
        let residual_value_lowered =
            residual_program.evaluate_into(&amplitude_values, &mut value_slots);
        assert_relative_eq!(
            residual_value_lowered.re,
            residual_value_ir.re,
            epsilon = 1e-12
        );
        assert_relative_eq!(
            residual_value_lowered.im,
            residual_value_ir.im,
            epsilon = 1e-12
        );

        // Gradients: activate only the residual amplitudes. Every amplitude gradient starts
        // as a zero vector with one entry per parameter.
        let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
        for &index in &state.execution_sets.residual_amplitudes {
            residual_active_mask[index] = true;
        }
        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
            .map(|_| DVector::zeros(parameters.len()))
            .collect::<Vec<_>>();
        evaluator.fill_amplitude_gradients(
            &mut amplitude_gradients,
            &residual_active_mask,
            &parameters,
            &resources.caches[0],
        );
        let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
            &state,
            &amplitude_values,
            &amplitude_gradients,
            parameters.len(),
        );

        // Fresh scratch buffers for the gradient program. The gradient buffer holds
        // `scratch_slots() * parameters.len()` entries (presumably one block per slot).
        let program = lowered_artifacts
            .residual_runtime
            .as_ref()
            .map(|runtime| runtime.gradient_program())
            .expect("gradient lowering should succeed");
        let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
        let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
        let residual_gradient_lowered = program.evaluate_gradient_into_flat(
            &amplitude_values,
            &amplitude_gradients,
            &mut value_slots,
            &mut gradient_slots,
            parameters.len(),
        );

        // NOTE(review): `zip` truncates to the shorter side; the lengths are not pinned here.
        for (lowered, ir) in residual_gradient_lowered
            .iter()
            .zip(residual_gradient_ir.iter())
        {
            assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
            assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
        }
    }
    // Changing only the dataset contents (not the activation mask) should add a new
    // specialization entry while reusing the already-lowered artifacts.
    #[test]
    fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(test_dataset());
        let mut evaluator = expr.load(&dataset).unwrap();
        // Release the test's handle so `Arc::get_mut` below can obtain exclusive access.
        drop(dataset);

        // Loading populates one specialization and one set of lowered artifacts.
        assert_eq!(evaluator.specialization_cache_len(), 1);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);

        // Reset both metric sets so the assertions below count only the work triggered by
        // the dataset change.
        evaluator.reset_expression_compile_metrics();
        evaluator.reset_expression_specialization_metrics();

        Arc::get_mut(&mut evaluator.dataset)
            .expect("evaluator should own dataset Arc in this test")
            .clear_events_local();

        // Querying the integrals after the dataset change is expected to (lazily) re-specialize;
        // with no events, the weighted cache sum is zero.
        let cached_integrals = evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist");
        assert_eq!(cached_integrals.len(), 1);
        assert_eq!(cached_integrals[0].weighted_cache_sum, Complex64::ZERO);

        // A new specialization entry exists, but the lowered artifact cache did not grow.
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 1,
            }
        );

        // IR compilation and cached-integral work were redone; lowering was a cache hit
        // and so recorded no lowering time.
        let compile_metrics = evaluator.expression_compile_metrics();
        assert_eq!(compile_metrics.specialization_cache_hits, 0);
        assert_eq!(compile_metrics.specialization_cache_misses, 1);
        assert_eq!(compile_metrics.specialization_lowering_cache_hits, 1);
        assert_eq!(compile_metrics.specialization_lowering_cache_misses, 0);
        assert!(compile_metrics.specialization_ir_compile_nanos > 0);
        assert!(compile_metrics.specialization_cached_integrals_nanos > 0);
        assert_eq!(compile_metrics.specialization_lowering_nanos, 0);
    }
5376
5377    #[test]
5378    fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
5379        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5380        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5381        let c1 = CacheOnlyScalar::new("c1").unwrap();
5382        let c2 = CacheOnlyScalar::new("c2").unwrap();
5383        let c3 = CacheOnlyScalar::new("c3").unwrap();
5384        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5385        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5386        let dataset = Arc::new(test_dataset());
5387        let evaluator = expr.load(&dataset).unwrap();
5388        assert_eq!(
5389            evaluator
5390                .expression_precomputed_cached_integrals()
5391                .expect("integrals should exist")
5392                .len(),
5393            2
5394        );
5395        let params = vec![0.2, -0.3, 1.1, -0.7];
5396        let expected = evaluator
5397            .evaluate_gradient_local(&params)
5398            .expect("evaluation should succeed")
5399            .iter()
5400            .zip(dataset.weights_local().iter())
5401            .fold(
5402                DVector::zeros(params.len()),
5403                |mut accum, (gradient, event)| {
5404                    accum += gradient.map(|value| value.re).scale(*event);
5405                    accum
5406                },
5407            );
5408        let actual = evaluator
5409            .evaluate_weighted_gradient_sum_local(&params)
5410            .expect("evaluation should succeed");
5411        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
5412            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
5413        }
5414    }
5415
5416    #[test]
5417    fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
5418        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5419        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5420        let c1 = CacheOnlyScalar::new("c1").unwrap();
5421        let c2 = CacheOnlyScalar::new("c2").unwrap();
5422        let c3 = CacheOnlyScalar::new("c3").unwrap();
5423        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5424        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5425        let dataset = Arc::new(test_dataset());
5426        let evaluator = expr.load(&dataset).unwrap();
5427        assert_eq!(
5428            evaluator
5429                .expression_precomputed_cached_integrals()
5430                .expect("integrals should exist")
5431                .len(),
5432            2
5433        );
5434        let params = vec![0.2, -0.3, 1.1, -0.7];
5435        let expected = evaluator
5436            .evaluate_local(&params)
5437            .expect("evaluation should succeed")
5438            .iter()
5439            .zip(dataset.weights_local().iter())
5440            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
5441        let actual = evaluator
5442            .evaluate_weighted_value_sum_local(&params)
5443            .expect("evaluation should succeed");
5444        assert_relative_eq!(actual, expected, epsilon = 1e-10);
5445    }
5446
5447    #[test]
5448    fn test_weighted_sums_match_hardcoded_reference_values() {
5449        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5450        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5451        let c1 = CacheOnlyScalar::new("c1").unwrap();
5452        let c2 = CacheOnlyScalar::new("c2").unwrap();
5453        let c3 = CacheOnlyScalar::new("c3").unwrap();
5454        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5455        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5456
5457        let metadata = Arc::new(DatasetMetadata::default());
5458        let dataset = Arc::new(Dataset::new_with_metadata(
5459            vec![
5460                Arc::new(EventData {
5461                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
5462                    aux: vec![],
5463                    weight: 0.5,
5464                }),
5465                Arc::new(EventData {
5466                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
5467                    aux: vec![],
5468                    weight: -1.25,
5469                }),
5470                Arc::new(EventData {
5471                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
5472                    aux: vec![],
5473                    weight: 2.0,
5474                }),
5475            ],
5476            metadata,
5477        ));
5478        let evaluator = expr.load(&dataset).unwrap();
5479        let params = vec![0.7, -1.1, 0.9, -0.4];
5480
5481        let weighted_value_sum = evaluator
5482            .evaluate_weighted_value_sum_local(&params)
5483            .expect("evaluation should succeed");
5484        assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);
5485
5486        let weighted_gradient_sum = evaluator
5487            .evaluate_weighted_gradient_sum_local(&params)
5488            .expect("evaluation should succeed");
5489        let free_parameters = evaluator
5490            .parameters()
5491            .free()
5492            .names()
5493            .into_iter()
5494            .map(|name| name.to_string())
5495            .collect::<Vec<_>>();
5496        assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
5497        let expected_gradient = [43.925, 7.25, 28.525, 0.0];
5498        assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
5499        for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
5500            assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
5501        }
5502    }
5503    #[test]
5504    fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
5505        let expr = Expression::one()
5506            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5507                * &CacheOnlyScalar::new("k").unwrap());
5508        let dataset = Arc::new(test_dataset());
5509        let evaluator = expr.load(&dataset).unwrap();
5510        assert_eq!(
5511            evaluator
5512                .expression_precomputed_cached_integrals()
5513                .expect("integrals should exist")
5514                .len(),
5515            1
5516        );
5517        assert_eq!(
5518            evaluator
5519                .expression_precomputed_cached_integrals()
5520                .expect("integrals should exist")[0]
5521                .coefficient,
5522            -1
5523        );
5524        let params = vec![0.75];
5525        let expected = evaluator
5526            .evaluate_gradient_local(&params)
5527            .expect("evaluation should succeed")
5528            .iter()
5529            .zip(dataset.weights_local().iter())
5530            .fold(
5531                DVector::zeros(params.len()),
5532                |mut accum, (gradient, event)| {
5533                    accum += gradient.map(|value| value.re).scale(*event);
5534                    accum
5535                },
5536            );
5537        let actual = evaluator
5538            .evaluate_weighted_gradient_sum_local(&params)
5539            .expect("evaluation should succeed");
5540        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
5541            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
5542        }
5543    }
5544    #[test]
5545    fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
5546        let expr = Expression::one()
5547            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5548                * &CacheOnlyScalar::new("k").unwrap());
5549        let dataset = Arc::new(test_dataset());
5550        let evaluator = expr.load(&dataset).unwrap();
5551        assert_eq!(
5552            evaluator
5553                .expression_precomputed_cached_integrals()
5554                .expect("integrals should exist")
5555                .len(),
5556            1
5557        );
5558        assert_eq!(
5559            evaluator
5560                .expression_precomputed_cached_integrals()
5561                .expect("integrals should exist")[0]
5562                .coefficient,
5563            -1
5564        );
5565        let params = vec![0.75];
5566        let expected = evaluator
5567            .evaluate_local(&params)
5568            .expect("evaluation should succeed")
5569            .iter()
5570            .zip(dataset.weights_local().iter())
5571            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
5572        let actual = evaluator
5573            .evaluate_weighted_value_sum_local(&params)
5574            .expect("evaluation should succeed");
5575        assert_relative_eq!(actual, expected, epsilon = 1e-10);
5576    }
5577    #[test]
5578    fn test_expression_ir_diagnostics_follow_activation_changes() {
5579        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5580            * &CacheOnlyScalar::new("k").unwrap())
5581            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5582        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5583        let evaluator = expr.load(&dataset).unwrap();
5584
5585        let all_active = evaluator
5586            .expression_normalization_plan_explain()
5587            .expect("plan should exist");
5588        assert_eq!(all_active.cached_separable_nodes.len(), 1);
5589        assert_eq!(
5590            evaluator
5591                .expression_root_dependence()
5592                .expect("root dependence should exist"),
5593            ExpressionDependence::Mixed
5594        );
5595
5596        evaluator.isolate_many(&["p"]);
5597        let param_only = evaluator
5598            .expression_normalization_plan_explain()
5599            .expect("plan should exist");
5600        assert!(param_only.cached_separable_nodes.is_empty());
5601        assert_eq!(
5602            evaluator
5603                .expression_root_dependence()
5604                .expect("root dependence should exist"),
5605            ExpressionDependence::ParameterOnly
5606        );
5607    }
    // Switching to a new activation mask should miss the specialization cache, and switching
    // back to a previously seen mask should hit it and restore identical cached integrals.
    #[test]
    fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();

        // After load: initial compile timings are recorded. The compile-metrics
        // specialization hit/miss counters start at 0, while lowering has already
        // recorded one miss.
        let initial_compile_metrics = evaluator.expression_compile_metrics();
        assert!(initial_compile_metrics.initial_ir_compile_nanos > 0);
        assert!(initial_compile_metrics.initial_cached_integrals_nanos > 0);
        assert!(initial_compile_metrics.initial_lowering_nanos > 0);
        assert_eq!(initial_compile_metrics.specialization_cache_hits, 0);
        assert_eq!(initial_compile_metrics.specialization_cache_misses, 0);
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_misses,
            1
        );

        // The separate specialization metrics already count the load-time specialization
        // as one miss.
        assert_eq!(evaluator.specialization_cache_len(), 1);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 1,
            }
        );
        // Snapshot to compare against after returning to the all-active mask.
        let all_active_cached_integrals = evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist");

        // New mask (only "p" active): a second specialization is built and lowered from
        // scratch, and no cached integrals remain.
        evaluator.isolate_many(&["p"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 2,
            }
        );
        let after_cache_miss_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_miss_metrics.specialization_cache_hits, 0);
        assert_eq!(after_cache_miss_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_misses,
            2
        );
        assert!(after_cache_miss_metrics.specialization_ir_compile_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_cached_integrals_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_lowering_nanos > 0);
        assert!(evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .is_empty());

        // Re-activating "k" and "m" returns to the load-time mask: the cached
        // specialization is restored (one hit, no new lowering) and the integrals
        // equal the earlier snapshot.
        evaluator.activate_many(&["k", "m"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 1,
                cache_misses: 2,
            }
        );
        assert_eq!(
            evaluator
                .expression_precomputed_cached_integrals()
                .expect("integrals should exist"),
            all_active_cached_integrals
        );
        let after_cache_hit_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_hit_metrics.specialization_cache_hits, 1);
        assert_eq!(after_cache_hit_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_misses,
            2
        );
        assert!(after_cache_hit_metrics.specialization_cache_restore_nanos > 0);
    }
5700
5701    #[test]
5702    fn test_weighted_sums_match_baseline_after_activation_changes() {
5703        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5704        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5705        let c1 = CacheOnlyScalar::new("c1").unwrap();
5706        let c2 = CacheOnlyScalar::new("c2").unwrap();
5707        let c3 = CacheOnlyScalar::new("c3").unwrap();
5708        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5709        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5710        let dataset = Arc::new(test_dataset());
5711        let evaluator = expr.load(&dataset).unwrap();
5712        let params = vec![0.2, -0.3, 1.1, -0.7];
5713
5714        evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);
5715
5716        let expected_value = evaluator
5717            .evaluate_local(&params)
5718            .expect("evaluation should succeed")
5719            .iter()
5720            .zip(dataset.weights_local().iter())
5721            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
5722        assert_relative_eq!(
5723            evaluator
5724                .evaluate_weighted_value_sum_local(&params)
5725                .expect("evaluation should succeed"),
5726            expected_value,
5727            epsilon = 1e-10
5728        );
5729    }
5730
5731    #[test]
5732    fn test_evaluate_local_does_not_depend_on_dataset_rows() {
5733        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5734            .unwrap()
5735            .norm_sqr();
5736        let mut event1 = test_event();
5737        event1.p4s[0].t = 7.5;
5738        let mut event2 = test_event();
5739        event2.p4s[0].t = 8.25;
5740        let mut event3 = test_event();
5741        event3.p4s[0].t = 9.0;
5742        let dataset = Arc::new(Dataset::new_with_metadata(
5743            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5744            Arc::new(DatasetMetadata::default()),
5745        ));
5746        let mut evaluator = expr.load(&dataset).unwrap();
5747        drop(dataset);
5748        let expected_len = evaluator.resources.read().caches.len();
5749        Arc::get_mut(&mut evaluator.dataset)
5750            .expect("evaluator should own dataset Arc in this test")
5751            .clear_events_local();
5752        let cached = evaluator
5753            .evaluate_local(&[1.25, -0.75])
5754            .expect("evaluation should succeed");
5755        assert_eq!(cached.len(), expected_len);
5756    }
5757
5758    #[test]
5759    fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
5760        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5761            .unwrap()
5762            .norm_sqr();
5763        let mut event1 = test_event();
5764        event1.p4s[0].t = 7.5;
5765        let mut event2 = test_event();
5766        event2.p4s[0].t = 8.25;
5767        let mut event3 = test_event();
5768        event3.p4s[0].t = 9.0;
5769        let dataset = Arc::new(Dataset::new_with_metadata(
5770            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5771            Arc::new(DatasetMetadata::default()),
5772        ));
5773        let mut evaluator = expr.load(&dataset).unwrap();
5774        drop(dataset);
5775        let expected_len = evaluator.resources.read().caches.len();
5776        Arc::get_mut(&mut evaluator.dataset)
5777            .expect("evaluator should own dataset Arc in this test")
5778            .clear_events_local();
5779        let cached = evaluator
5780            .evaluate_gradient_local(&[1.25, -0.75])
5781            .expect("evaluation should succeed");
5782        assert_eq!(cached.len(), expected_len);
5783    }
5784
5785    #[test]
5786    fn test_evaluate_with_gradient_local_matches_separate_paths() {
5787        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5788            .unwrap()
5789            .norm_sqr();
5790        let dataset = Arc::new(Dataset::new(vec![
5791            Arc::new(test_event()),
5792            Arc::new(test_event()),
5793            Arc::new(test_event()),
5794        ]));
5795        let evaluator = expr.load(&dataset).unwrap();
5796        let params = [1.25, -0.75];
5797        let values = evaluator
5798            .evaluate_local(&params)
5799            .expect("evaluation should succeed");
5800        let gradients = evaluator
5801            .evaluate_gradient_local(&params)
5802            .expect("evaluation should succeed");
5803        let fused = evaluator
5804            .evaluate_with_gradient_local(&params)
5805            .expect("evaluation should succeed");
5806        assert_eq!(fused.len(), values.len());
5807        assert_eq!(fused.len(), gradients.len());
5808        for ((value_gradient, value), gradient) in
5809            fused.iter().zip(values.iter()).zip(gradients.iter())
5810        {
5811            let (fused_value, fused_gradient) = value_gradient;
5812            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
5813            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
5814            assert_eq!(fused_gradient.len(), gradient.len());
5815            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
5816                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
5817                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
5818            }
5819        }
5820    }
5821
5822    #[test]
5823    fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
5824        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5825            .unwrap()
5826            .norm_sqr();
5827        let dataset = Arc::new(Dataset::new(vec![
5828            Arc::new(test_event()),
5829            Arc::new(test_event()),
5830            Arc::new(test_event()),
5831            Arc::new(test_event()),
5832        ]));
5833        let evaluator = expr.load(&dataset).unwrap();
5834        let params = [0.5, -1.25];
5835        let indices = vec![0, 2, 3];
5836        let values = evaluator
5837            .evaluate_batch_local(&params, &indices)
5838            .expect("evaluation should succeed");
5839        let gradients = evaluator
5840            .evaluate_gradient_batch_local(&params, &indices)
5841            .expect("evaluation should succeed");
5842        let fused = evaluator
5843            .evaluate_with_gradient_batch_local(&params, &indices)
5844            .expect("evaluation should succeed");
5845        assert_eq!(fused.len(), values.len());
5846        assert_eq!(fused.len(), gradients.len());
5847        for ((value_gradient, value), gradient) in
5848            fused.iter().zip(values.iter()).zip(gradients.iter())
5849        {
5850            let (fused_value, fused_gradient) = value_gradient;
5851            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
5852            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
5853            assert_eq!(fused_gradient.len(), gradient.len());
5854            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
5855                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
5856                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
5857            }
5858        }
5859    }
5860
5861    #[test]
5862    fn test_precompute_all_columnar_populates_cache() {
5863        let mut event1 = test_event();
5864        event1.p4s[0].t = 7.5;
5865        let mut event2 = test_event();
5866        event2.p4s[0].t = 8.25;
5867        let mut event3 = test_event();
5868        event3.p4s[0].t = 9.0;
5869        let dataset = Dataset::new_with_metadata(
5870            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5871            Arc::new(DatasetMetadata::default()),
5872        );
5873        let mut amplitude = TestAmplitude {
5874            tags: Tags::new(["test"]),
5875            re: parameter!("real"),
5876            pid_re: ParameterID::default(),
5877            im: parameter!("imag"),
5878            pid_im: ParameterID::default(),
5879            beam_energy: Default::default(),
5880        };
5881        let mut resources = Resources::default();
5882        amplitude
5883            .register(&mut resources)
5884            .expect("test amplitude should register");
5885        resources.reserve_cache(dataset.n_events());
5886        amplitude.precompute_all(&dataset, &mut resources);
5887        for cache in &resources.caches {
5888            assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
5889        }
5890    }
5891
5892    #[cfg(feature = "mpi")]
5893    #[mpi_test(np = [2])]
5894    fn test_load_reserves_local_cache_size_in_mpi() {
5895        use crate::mpi::{finalize_mpi, get_world, use_mpi};
5896
5897        use_mpi(true);
5898        assert!(get_world().is_some(), "MPI world should be initialized");
5899
5900        let expr = ComplexScalar::new(
5901            "constant",
5902            parameter!("const_re", 2.0),
5903            parameter!("const_im", 3.0),
5904        )
5905        .expect("constant amplitude should construct");
5906        let events = vec![
5907            Arc::new(test_event()),
5908            Arc::new(test_event()),
5909            Arc::new(test_event()),
5910            Arc::new(test_event()),
5911        ];
5912        let dataset = Arc::new(Dataset::new_with_metadata(
5913            events,
5914            Arc::new(DatasetMetadata::default()),
5915        ));
5916        let evaluator = expr.load(&dataset).expect("evaluator should load");
5917        let local_events = dataset.n_events_local();
5918        let cache_len = evaluator.resources.read().caches.len();
5919
5920        assert_eq!(
5921            cache_len, local_events,
5922            "cache length must match local event count under MPI"
5923        );
5924        finalize_mpi();
5925    }
5926
5927    #[cfg(feature = "mpi")]
5928    #[mpi_test(np = [2])]
5929    fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
5930        use mpi::{collective::SystemOperation, topology::Communicator, traits::*};
5931
5932        use crate::mpi::{finalize_mpi, get_world, use_mpi};
5933
5934        use_mpi(true);
5935        let world = get_world().expect("MPI world should be initialized");
5936
5937        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5938            * &CacheOnlyScalar::new("k").unwrap();
5939        let events = vec![
5940            Arc::new(EventData {
5941                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
5942                aux: vec![],
5943                weight: 0.5,
5944            }),
5945            Arc::new(EventData {
5946                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
5947                aux: vec![],
5948                weight: 1.0,
5949            }),
5950            Arc::new(EventData {
5951                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
5952                aux: vec![],
5953                weight: 1.5,
5954            }),
5955            Arc::new(EventData {
5956                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
5957                aux: vec![],
5958                weight: 2.0,
5959            }),
5960            Arc::new(EventData {
5961                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
5962                aux: vec![],
5963                weight: 2.5,
5964            }),
5965            Arc::new(EventData {
5966                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
5967                aux: vec![],
5968                weight: 3.0,
5969            }),
5970            Arc::new(EventData {
5971                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
5972                aux: vec![],
5973                weight: 3.5,
5974            }),
5975            Arc::new(EventData {
5976                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
5977                aux: vec![],
5978                weight: 4.0,
5979            }),
5980        ];
5981        let dataset = Arc::new(Dataset::new_with_metadata(
5982            events,
5983            Arc::new(DatasetMetadata::default()),
5984        ));
5985        let evaluator = expr.load(&dataset).expect("evaluator should load");
5986        let cached_integrals = evaluator
5987            .expression_precomputed_cached_integrals()
5988            .expect("integrals should exist");
5989        assert_eq!(cached_integrals.len(), 1);
5990
5991        let local_expected =
5992            dataset
5993                .weights_local()
5994                .iter()
5995                .enumerate()
5996                .fold(0.0, |acc, (index, weight)| {
5997                    let event = dataset.event_local(index).expect("event should exist");
5998                    acc + *weight * event.p4_at(0).e()
5999                });
6000        let cached_local = cached_integrals[0].weighted_cache_sum;
6001        assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
6002        assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);
6003
6004        let weighted_value_sum = evaluator
6005            .evaluate_weighted_value_sum_local(&[2.0])
6006            .expect("evaluate should succeed");
6007        assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);
6008
6009        let mut global_expected = 0.0;
6010        world.all_reduce_into(
6011            &local_expected,
6012            &mut global_expected,
6013            SystemOperation::sum(),
6014        );
6015        if world.size() > 1 {
6016            assert!(
6017                (cached_local.re - global_expected).abs() > 1e-12,
6018                "cached integral should remain rank-local before MPI reduction"
6019            );
6020        }
6021        finalize_mpi();
6022    }
6023
6024    #[cfg(feature = "mpi")]
6025    #[mpi_test(np = [2])]
6026    fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
6027        use mpi::{collective::SystemOperation, traits::*};
6028
6029        use crate::mpi::{finalize_mpi, get_world, use_mpi};
6030
6031        use_mpi(true);
6032        let world = get_world().expect("MPI world should be initialized");
6033
6034        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6035        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6036        let c1 = CacheOnlyScalar::new("c1").unwrap();
6037        let c2 = CacheOnlyScalar::new("c2").unwrap();
6038        let c3 = CacheOnlyScalar::new("c3").unwrap();
6039        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6040        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6041        let events = vec![
6042            Arc::new(EventData {
6043                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
6044                aux: vec![],
6045                weight: 0.5,
6046            }),
6047            Arc::new(EventData {
6048                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6049                aux: vec![],
6050                weight: -1.25,
6051            }),
6052            Arc::new(EventData {
6053                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6054                aux: vec![],
6055                weight: 0.75,
6056            }),
6057            Arc::new(EventData {
6058                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
6059                aux: vec![],
6060                weight: 1.5,
6061            }),
6062            Arc::new(EventData {
6063                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6064                aux: vec![],
6065                weight: 2.25,
6066            }),
6067            Arc::new(EventData {
6068                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
6069                aux: vec![],
6070                weight: -0.5,
6071            }),
6072            Arc::new(EventData {
6073                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
6074                aux: vec![],
6075                weight: 3.5,
6076            }),
6077            Arc::new(EventData {
6078                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
6079                aux: vec![],
6080                weight: 1.25,
6081            }),
6082        ];
6083        let dataset = Arc::new(Dataset::new_with_metadata(
6084            events,
6085            Arc::new(DatasetMetadata::default()),
6086        ));
6087        let evaluator = expr.load(&dataset).expect("evaluator should load");
6088        let params = vec![0.2, -0.3, 1.1, -0.7];
6089
6090        let local_expected_value = evaluator
6091            .evaluate_local(&params)
6092            .expect("evaluate should succeed")
6093            .iter()
6094            .zip(dataset.weights_local().iter())
6095            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6096        let mut global_expected_value = 0.0;
6097        world.all_reduce_into(
6098            &local_expected_value,
6099            &mut global_expected_value,
6100            SystemOperation::sum(),
6101        );
6102        let mpi_value = evaluator
6103            .evaluate_weighted_value_sum_mpi(&params, &world)
6104            .expect("evaluate should succeed");
6105        assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);
6106
6107        let local_expected_gradient = evaluator
6108            .evaluate_gradient_local(&params)
6109            .expect("evaluate should succeed")
6110            .iter()
6111            .zip(dataset.weights_local().iter())
6112            .fold(
6113                DVector::zeros(params.len()),
6114                |mut accum, (gradient, event)| {
6115                    accum += gradient.map(|value| value.re).scale(*event);
6116                    accum
6117                },
6118            );
6119        let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
6120        world.all_reduce_into(
6121            local_expected_gradient.as_slice(),
6122            &mut global_expected_gradient,
6123            SystemOperation::sum(),
6124        );
6125        let mpi_gradient = evaluator
6126            .evaluate_weighted_gradient_sum_mpi(&params, &world)
6127            .expect("evaluate should succeed");
6128        for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
6129            assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
6130        }
6131
6132        finalize_mpi();
6133    }
6134
6135    #[test]
6136    fn test_evaluate_local_succeeds_for_constant_amplitude() {
6137        let expr = ComplexScalar::new(
6138            "constant",
6139            parameter!("const_re", 2.0),
6140            parameter!("const_im", 3.0),
6141        )
6142        .unwrap();
6143        let dataset = Arc::new(Dataset::new_with_metadata(
6144            vec![Arc::new(test_event())],
6145            Arc::new(DatasetMetadata::default()),
6146        ));
6147        let evaluator = expr.load(&dataset).unwrap();
6148        let values = evaluator
6149            .evaluate_local(&[])
6150            .expect("evaluation should succeed");
6151        assert_eq!(values.len(), 1);
6152        let gradients = evaluator
6153            .evaluate_gradient_local(&[])
6154            .expect("evaluation should succeed");
6155        assert_eq!(gradients.len(), 1);
6156    }
6157
6158    #[test]
6159    fn test_constant_amplitude() {
6160        let expr = ComplexScalar::new(
6161            "constant",
6162            parameter!("const_re", 2.0),
6163            parameter!("const_im", 3.0),
6164        )
6165        .unwrap();
6166        let dataset = Arc::new(Dataset::new_with_metadata(
6167            vec![Arc::new(test_event())],
6168            Arc::new(DatasetMetadata::default()),
6169        ));
6170        let evaluator = expr.load(&dataset).unwrap();
6171        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6172        assert_eq!(result[0], Complex64::new(2.0, 3.0));
6173    }
6174
6175    #[test]
6176    fn test_parametric_amplitude() {
6177        let expr = ComplexScalar::new(
6178            "parametric",
6179            parameter!("test_param_re"),
6180            parameter!("test_param_im"),
6181        )
6182        .unwrap();
6183        let dataset = Arc::new(test_dataset());
6184        let evaluator = expr.load(&dataset).unwrap();
6185        let result = evaluator
6186            .evaluate(&[2.0, 3.0])
6187            .expect("evaluation should succeed");
6188        assert_eq!(result[0], Complex64::new(2.0, 3.0));
6189    }
6190
6191    #[test]
6192    fn test_expression_operations() {
6193        let expr1 = ComplexScalar::new(
6194            "const1",
6195            parameter!("const1_re", 2.0),
6196            parameter!("const1_im", 0.0),
6197        )
6198        .unwrap();
6199        let expr2 = ComplexScalar::new(
6200            "const2",
6201            parameter!("const2_re", 0.0),
6202            parameter!("const2_im", 1.0),
6203        )
6204        .unwrap();
6205        let expr3 = ComplexScalar::new(
6206            "const3",
6207            parameter!("const3_re", 3.0),
6208            parameter!("const3_im", 4.0),
6209        )
6210        .unwrap();
6211
6212        let dataset = Arc::new(test_dataset());
6213
6214        // Test (amp) addition
6215        let expr_add = &expr1 + &expr2;
6216        let result_add = expr_add
6217            .load(&dataset)
6218            .unwrap()
6219            .evaluate(&[])
6220            .expect("evaluation should succeed");
6221        assert_eq!(result_add[0], Complex64::new(2.0, 1.0));
6222
6223        // Test (amp) subtraction
6224        let expr_sub = &expr1 - &expr2;
6225        let result_sub = expr_sub
6226            .load(&dataset)
6227            .unwrap()
6228            .evaluate(&[])
6229            .expect("evaluation should succeed");
6230        assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));
6231
6232        // Test (amp) multiplication
6233        let expr_mul = &expr1 * &expr2;
6234        let result_mul = expr_mul
6235            .load(&dataset)
6236            .unwrap()
6237            .evaluate(&[])
6238            .expect("evaluation should succeed");
6239        assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));
6240
6241        // Test (amp) division
6242        let expr_div = &expr1 / &expr3;
6243        let result_div = expr_div
6244            .load(&dataset)
6245            .unwrap()
6246            .evaluate(&[])
6247            .expect("evaluation should succeed");
6248        assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));
6249
6250        // Test (amp) neg
6251        let expr_neg = -&expr3;
6252        let result_neg = expr_neg
6253            .load(&dataset)
6254            .unwrap()
6255            .evaluate(&[])
6256            .expect("evaluation should succeed");
6257        assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));
6258
6259        // Test (expr) addition
6260        let expr_add2 = &expr_add + &expr_mul;
6261        let result_add2 = expr_add2
6262            .load(&dataset)
6263            .unwrap()
6264            .evaluate(&[])
6265            .expect("evaluation should succeed");
6266        assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));
6267
6268        // Test (expr) subtraction
6269        let expr_sub2 = &expr_add - &expr_mul;
6270        let result_sub2 = expr_sub2
6271            .load(&dataset)
6272            .unwrap()
6273            .evaluate(&[])
6274            .expect("evaluation should succeed");
6275        assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));
6276
6277        // Test (expr) multiplication
6278        let expr_mul2 = &expr_add * &expr_mul;
6279        let result_mul2 = expr_mul2
6280            .load(&dataset)
6281            .unwrap()
6282            .evaluate(&[])
6283            .expect("evaluation should succeed");
6284        assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));
6285
6286        // Test (expr) division
6287        let expr_div2 = &expr_add / &expr_add2;
6288        let result_div2 = expr_div2
6289            .load(&dataset)
6290            .unwrap()
6291            .evaluate(&[])
6292            .expect("evaluation should succeed");
6293        assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));
6294
6295        // Test (expr) neg
6296        let expr_neg2 = -&expr_mul2;
6297        let result_neg2 = expr_neg2
6298            .load(&dataset)
6299            .unwrap()
6300            .evaluate(&[])
6301            .expect("evaluation should succeed");
6302        assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));
6303
6304        // Test (amp) real
6305        let expr_real = expr3.real();
6306        let result_real = expr_real
6307            .load(&dataset)
6308            .unwrap()
6309            .evaluate(&[])
6310            .expect("evaluation should succeed");
6311        assert_eq!(result_real[0], Complex64::new(3.0, 0.0));
6312
6313        // Test (expr) real
6314        let expr_mul2_real = expr_mul2.real();
6315        let result_mul2_real = expr_mul2_real
6316            .load(&dataset)
6317            .unwrap()
6318            .evaluate(&[])
6319            .expect("evaluation should succeed");
6320        assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));
6321
6322        // Test (amp) imag
6323        let expr_imag = expr3.imag();
6324        let result_imag = expr_imag
6325            .load(&dataset)
6326            .unwrap()
6327            .evaluate(&[])
6328            .expect("evaluation should succeed");
6329        assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));
6330
6331        // Test (expr) imag
6332        let expr_mul2_imag = expr_mul2.imag();
6333        let result_mul2_imag = expr_mul2_imag
6334            .load(&dataset)
6335            .unwrap()
6336            .evaluate(&[])
6337            .expect("evaluation should succeed");
6338        assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));
6339
6340        // Test (amp) conj
6341        let expr_conj = expr3.conj();
6342        let result_conj = expr_conj
6343            .load(&dataset)
6344            .unwrap()
6345            .evaluate(&[])
6346            .expect("evaluation should succeed");
6347        assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));
6348
6349        // Test (expr) conj
6350        let expr_mul2_conj = expr_mul2.conj();
6351        let result_mul2_conj = expr_mul2_conj
6352            .load(&dataset)
6353            .unwrap()
6354            .evaluate(&[])
6355            .expect("evaluation should succeed");
6356        assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));
6357
6358        // Test (amp) norm_sqr
6359        let expr_norm = expr1.norm_sqr();
6360        let result_norm = expr_norm
6361            .load(&dataset)
6362            .unwrap()
6363            .evaluate(&[])
6364            .expect("evaluation should succeed");
6365        assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));
6366
6367        // Test (expr) norm_sqr
6368        let expr_mul2_norm = expr_mul2.norm_sqr();
6369        let result_mul2_norm = expr_mul2_norm
6370            .load(&dataset)
6371            .unwrap()
6372            .evaluate(&[])
6373            .expect("evaluation should succeed");
6374        assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
6375    }
6376
6377    #[test]
6378    fn test_amplitude_activation() {
6379        let expr1 = ComplexScalar::new(
6380            "const1",
6381            parameter!("const1_re_act", 1.0),
6382            parameter!("const1_im_act", 0.0),
6383        )
6384        .unwrap();
6385        let expr2 = ComplexScalar::new(
6386            "const2",
6387            parameter!("const2_re_act", 2.0),
6388            parameter!("const2_im_act", 0.0),
6389        )
6390        .unwrap();
6391
6392        let dataset = Arc::new(test_dataset());
6393        let expr = &expr1 + &expr2;
6394        let evaluator = expr.load(&dataset).unwrap();
6395
6396        // Test initial state (all active)
6397        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6398        assert_eq!(result[0], Complex64::new(3.0, 0.0));
6399
6400        // Test deactivation
6401        evaluator.deactivate_strict("const1").unwrap();
6402        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6403        assert_eq!(result[0], Complex64::new(2.0, 0.0));
6404
6405        // Test isolation
6406        evaluator.isolate_strict("const1").unwrap();
6407        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6408        assert_eq!(result[0], Complex64::new(1.0, 0.0));
6409
6410        // Test reactivation
6411        evaluator.activate_all();
6412        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6413        assert_eq!(result[0], Complex64::new(3.0, 0.0));
6414    }
6415
6416    #[test]
6417    fn test_gradient() {
6418        let expr1 = ComplexScalar::new(
6419            "parametric_1",
6420            parameter!("test_param_re_1"),
6421            parameter!("test_param_im_1"),
6422        )
6423        .unwrap();
6424        let expr2 = ComplexScalar::new(
6425            "parametric_2",
6426            parameter!("test_param_re_2"),
6427            parameter!("test_param_im_2"),
6428        )
6429        .unwrap();
6430
6431        let dataset = Arc::new(test_dataset());
6432        let params = vec![2.0, 3.0, 4.0, 5.0];
6433
6434        let expr = &expr1 + &expr2;
6435        let evaluator = expr.load(&dataset).unwrap();
6436
6437        let gradient = evaluator
6438            .evaluate_gradient(&params)
6439            .expect("evaluation should succeed");
6440
6441        assert_relative_eq!(gradient[0][0].re, 1.0);
6442        assert_relative_eq!(gradient[0][0].im, 0.0);
6443        assert_relative_eq!(gradient[0][1].re, 0.0);
6444        assert_relative_eq!(gradient[0][1].im, 1.0);
6445        assert_relative_eq!(gradient[0][2].re, 1.0);
6446        assert_relative_eq!(gradient[0][2].im, 0.0);
6447        assert_relative_eq!(gradient[0][3].re, 0.0);
6448        assert_relative_eq!(gradient[0][3].im, 1.0);
6449
6450        let expr = &expr1 - &expr2;
6451        let evaluator = expr.load(&dataset).unwrap();
6452
6453        let gradient = evaluator
6454            .evaluate_gradient(&params)
6455            .expect("evaluation should succeed");
6456
6457        assert_relative_eq!(gradient[0][0].re, 1.0);
6458        assert_relative_eq!(gradient[0][0].im, 0.0);
6459        assert_relative_eq!(gradient[0][1].re, 0.0);
6460        assert_relative_eq!(gradient[0][1].im, 1.0);
6461        assert_relative_eq!(gradient[0][2].re, -1.0);
6462        assert_relative_eq!(gradient[0][2].im, 0.0);
6463        assert_relative_eq!(gradient[0][3].re, 0.0);
6464        assert_relative_eq!(gradient[0][3].im, -1.0);
6465
6466        let expr = &expr1 * &expr2;
6467        let evaluator = expr.load(&dataset).unwrap();
6468
6469        let gradient = evaluator
6470            .evaluate_gradient(&params)
6471            .expect("evaluation should succeed");
6472
6473        assert_relative_eq!(gradient[0][0].re, 4.0);
6474        assert_relative_eq!(gradient[0][0].im, 5.0);
6475        assert_relative_eq!(gradient[0][1].re, -5.0);
6476        assert_relative_eq!(gradient[0][1].im, 4.0);
6477        assert_relative_eq!(gradient[0][2].re, 2.0);
6478        assert_relative_eq!(gradient[0][2].im, 3.0);
6479        assert_relative_eq!(gradient[0][3].re, -3.0);
6480        assert_relative_eq!(gradient[0][3].im, 2.0);
6481
6482        let expr = &expr1 / &expr2;
6483        let evaluator = expr.load(&dataset).unwrap();
6484
6485        let gradient = evaluator
6486            .evaluate_gradient(&params)
6487            .expect("evaluation should succeed");
6488
6489        assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
6490        assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
6491        assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
6492        assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
6493        assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
6494        assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
6495        assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
6496        assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
6497
6498        let expr = -(&expr1 * &expr2);
6499        let evaluator = expr.load(&dataset).unwrap();
6500
6501        let gradient = evaluator
6502            .evaluate_gradient(&params)
6503            .expect("evaluation should succeed");
6504
6505        assert_relative_eq!(gradient[0][0].re, -4.0);
6506        assert_relative_eq!(gradient[0][0].im, -5.0);
6507        assert_relative_eq!(gradient[0][1].re, 5.0);
6508        assert_relative_eq!(gradient[0][1].im, -4.0);
6509        assert_relative_eq!(gradient[0][2].re, -2.0);
6510        assert_relative_eq!(gradient[0][2].im, -3.0);
6511        assert_relative_eq!(gradient[0][3].re, 3.0);
6512        assert_relative_eq!(gradient[0][3].im, -2.0);
6513
6514        let expr = (&expr1 * &expr2).real();
6515        let evaluator = expr.load(&dataset).unwrap();
6516
6517        let gradient = evaluator
6518            .evaluate_gradient(&params)
6519            .expect("evaluation should succeed");
6520
6521        assert_relative_eq!(gradient[0][0].re, 4.0);
6522        assert_relative_eq!(gradient[0][0].im, 0.0);
6523        assert_relative_eq!(gradient[0][1].re, -5.0);
6524        assert_relative_eq!(gradient[0][1].im, 0.0);
6525        assert_relative_eq!(gradient[0][2].re, 2.0);
6526        assert_relative_eq!(gradient[0][2].im, 0.0);
6527        assert_relative_eq!(gradient[0][3].re, -3.0);
6528        assert_relative_eq!(gradient[0][3].im, 0.0);
6529
6530        let expr = (&expr1 * &expr2).imag();
6531        let evaluator = expr.load(&dataset).unwrap();
6532
6533        let gradient = evaluator
6534            .evaluate_gradient(&params)
6535            .expect("evaluation should succeed");
6536
6537        assert_relative_eq!(gradient[0][0].re, 5.0);
6538        assert_relative_eq!(gradient[0][0].im, 0.0);
6539        assert_relative_eq!(gradient[0][1].re, 4.0);
6540        assert_relative_eq!(gradient[0][1].im, 0.0);
6541        assert_relative_eq!(gradient[0][2].re, 3.0);
6542        assert_relative_eq!(gradient[0][2].im, 0.0);
6543        assert_relative_eq!(gradient[0][3].re, 2.0);
6544        assert_relative_eq!(gradient[0][3].im, 0.0);
6545
6546        let expr = (&expr1 * &expr2).conj();
6547        let evaluator = expr.load(&dataset).unwrap();
6548
6549        let gradient = evaluator
6550            .evaluate_gradient(&params)
6551            .expect("evaluation should succeed");
6552
6553        assert_relative_eq!(gradient[0][0].re, 4.0);
6554        assert_relative_eq!(gradient[0][0].im, -5.0);
6555        assert_relative_eq!(gradient[0][1].re, -5.0);
6556        assert_relative_eq!(gradient[0][1].im, -4.0);
6557        assert_relative_eq!(gradient[0][2].re, 2.0);
6558        assert_relative_eq!(gradient[0][2].im, -3.0);
6559        assert_relative_eq!(gradient[0][3].re, -3.0);
6560        assert_relative_eq!(gradient[0][3].im, -2.0);
6561
6562        let expr = (&expr1 * &expr2).norm_sqr();
6563        let evaluator = expr.load(&dataset).unwrap();
6564
6565        let gradient = evaluator
6566            .evaluate_gradient(&params)
6567            .expect("evaluation should succeed");
6568
6569        assert_relative_eq!(gradient[0][0].re, 164.0);
6570        assert_relative_eq!(gradient[0][0].im, 0.0);
6571        assert_relative_eq!(gradient[0][1].re, 246.0);
6572        assert_relative_eq!(gradient[0][1].im, 0.0);
6573        assert_relative_eq!(gradient[0][2].re, 104.0);
6574        assert_relative_eq!(gradient[0][2].im, 0.0);
6575        assert_relative_eq!(gradient[0][3].re, 130.0);
6576        assert_relative_eq!(gradient[0][3].im, 0.0);
6577    }
6578
6579    #[test]
6580    fn test_expression_function_gradients() {
6581        let expr1 = ComplexScalar::new(
6582            "function_parametric_1",
6583            parameter!("function_test_param_re_1"),
6584            parameter!("function_test_param_im_1"),
6585        )
6586        .unwrap();
6587        let expr2 = ComplexScalar::new(
6588            "function_parametric_2",
6589            parameter!("function_test_param_re_2"),
6590            parameter!("function_test_param_im_2"),
6591        )
6592        .unwrap();
6593
6594        let sin = expr1.sin();
6595        let cos = expr1.cos();
6596        let trig = &sin * &cos;
6597        let pow = expr1.pow(&expr2);
6598        let mut expr = expr1.sqrt();
6599        expr = &expr + &expr1.exp();
6600        expr = &expr + &expr1.powi(2);
6601        expr = &expr + &expr1.powf(1.7);
6602        expr = &expr + &trig;
6603        expr = &expr + &expr1.log();
6604        expr = &expr + &expr1.cis();
6605        expr = &expr + &pow;
6606
6607        let dataset = Arc::new(test_dataset());
6608        let evaluator = expr.load(&dataset).unwrap();
6609        let params = vec![2.0, 0.5, 1.2, -0.3];
6610        let gradient = evaluator
6611            .evaluate_gradient(&params)
6612            .expect("evaluation should succeed");
6613        let eps = 1e-6;
6614
6615        for param_index in 0..params.len() {
6616            let mut plus = params.clone();
6617            plus[param_index] += eps;
6618            let mut minus = params.clone();
6619            minus[param_index] -= eps;
6620            let finite_diff = (evaluator
6621                .evaluate(&plus)
6622                .expect("evaluation should succeed")[0]
6623                - evaluator
6624                    .evaluate(&minus)
6625                    .expect("evaluation should succeed")[0])
6626                / Complex64::new(2.0 * eps, 0.0);
6627
6628            assert_relative_eq!(
6629                gradient[0][param_index].re,
6630                finite_diff.re,
6631                epsilon = 1e-6,
6632                max_relative = 1e-6
6633            );
6634            assert_relative_eq!(
6635                gradient[0][param_index].im,
6636                finite_diff.im,
6637                epsilon = 1e-6,
6638                max_relative = 1e-6
6639            );
6640        }
6641    }
6642
6643    #[test]
6644    fn test_zeros_and_ones() {
6645        let amp = ComplexScalar::new(
6646            "parametric",
6647            parameter!("test_param_re"),
6648            parameter!("fixed_two", 2.0),
6649        )
6650        .unwrap();
6651        let dataset = Arc::new(test_dataset());
6652        let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
6653        let evaluator = expr.load(&dataset).unwrap();
6654
6655        let params = vec![2.0];
6656        let value = evaluator
6657            .evaluate(&params)
6658            .expect("evaluation should succeed");
6659        let gradient = evaluator
6660            .evaluate_gradient(&params)
6661            .expect("evaluation should succeed");
6662
6663        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the value should be x^2 + 4
6664        assert_relative_eq!(value[0].re, 8.0);
6665        assert_relative_eq!(value[0].im, 0.0);
6666
6667        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the derivative should be 2x
6668        assert_relative_eq!(gradient[0][0].re, 4.0);
6669        assert_relative_eq!(gradient[0][0].im, 0.0);
6670    }
6671    #[test]
6672    fn test_default_build_uses_lowered_expression_runtime() {
6673        let expr = ComplexScalar::new(
6674            "opt_in_gate",
6675            parameter!("opt_in_gate_re", 2.0),
6676            parameter!("opt_in_gate_im", 0.0),
6677        )
6678        .unwrap()
6679        .norm_sqr();
6680        let dataset = Arc::new(test_dataset());
6681        let evaluator = expr.load(&dataset).unwrap();
6682
6683        let diagnostics = evaluator.expression_runtime_diagnostics();
6684        assert!(diagnostics.ir_planning_enabled);
6685        assert!(diagnostics.lowered_value_program_present);
6686        assert!(diagnostics.lowered_gradient_program_present);
6687        assert!(diagnostics.lowered_value_gradient_program_present);
6688        assert_eq!(
6689            evaluator.evaluate(&[]).expect("evaluation should succeed")[0],
6690            Complex64::new(4.0, 0.0)
6691        );
6692    }
6693
6694    #[test]
6695    fn parameter_name_only_creates_free_parameter() {
6696        let p = parameter!("mass");
6697
6698        assert_eq!(p.name(), "mass");
6699        assert_eq!(p.fixed(), None);
6700        assert_eq!(p.initial(), None);
6701        assert_eq!(p.bounds(), (None, None));
6702        assert_eq!(p.unit(), None);
6703        assert_eq!(p.latex(), None);
6704        assert_eq!(p.description(), None);
6705        assert!(p.is_free());
6706        assert!(!p.is_fixed());
6707    }
6708
6709    #[test]
6710    fn parameter_name_and_value_creates_fixed_parameter() {
6711        let p = parameter!("width", 0.15);
6712
6713        assert_eq!(p.name(), "width");
6714        assert_eq!(p.fixed(), Some(0.15));
6715        assert_eq!(p.initial(), Some(0.15));
6716        assert!(p.is_fixed());
6717        assert!(!p.is_free());
6718    }
6719
6720    #[test]
6721    fn keyword_initial_sets_initial_only() {
6722        let p = parameter!("alpha", initial: 1.25);
6723
6724        assert_eq!(p.name(), "alpha");
6725        assert_eq!(p.fixed(), None);
6726        assert_eq!(p.initial(), Some(1.25));
6727        assert_eq!(p.bounds(), (None, None));
6728        assert!(p.is_free());
6729    }
6730
6731    #[test]
6732    fn keyword_fixed_sets_fixed_and_initial() {
6733        let p = parameter!("beta", fixed: 2.5);
6734
6735        assert_eq!(p.name(), "beta");
6736        assert_eq!(p.fixed(), Some(2.5));
6737        assert_eq!(p.initial(), Some(2.5));
6738        assert!(p.is_fixed());
6739    }
6740
6741    #[test]
6742    fn bounds_accept_plain_numbers() {
6743        let p = parameter!("x", bounds: (0.0, 10.0));
6744
6745        assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
6746    }
6747
6748    #[test]
6749    fn bounds_accept_none_and_number() {
6750        let p = parameter!("x", bounds: (None, 10.0));
6751
6752        assert_eq!(p.bounds(), (None, Some(10.0)));
6753    }
6754
6755    #[test]
6756    fn bounds_accept_number_and_none() {
6757        let p = parameter!("x", bounds: (-1.0, None));
6758
6759        assert_eq!(p.bounds(), (Some(-1.0), None));
6760    }
6761
6762    #[test]
6763    fn bounds_accept_both_none() {
6764        let p = parameter!("x", bounds: (None, None));
6765
6766        assert_eq!(p.bounds(), (None, None));
6767    }
6768
6769    #[test]
6770    fn bounds_accept_arbitrary_expressions() {
6771        let lo = 1.0;
6772        let hi = 2.0 * 3.0;
6773        let p = parameter!("x", bounds: (lo - 0.5, hi));
6774
6775        assert_eq!(p.bounds(), (Some(0.5), Some(6.0)));
6776    }
6777
6778    #[test]
6779    fn multiple_keyword_arguments_work_together() {
6780        let p = parameter!(
6781            "gamma",
6782            initial: 1.0,
6783            bounds: (0.0, 5.0),
6784            unit: "GeV",
6785            latex: r"\gamma",
6786            description: "test parameter",
6787        );
6788
6789        assert_eq!(p.name(), "gamma");
6790        assert_eq!(p.fixed(), None);
6791        assert_eq!(p.initial(), Some(1.0));
6792        assert_eq!(p.bounds(), (Some(0.0), Some(5.0)));
6793        assert_eq!(p.unit().as_deref(), Some("GeV"));
6794        assert_eq!(p.latex().as_deref(), Some(r"\gamma"));
6795        assert_eq!(p.description().as_deref(), Some("test parameter"));
6796    }
6797
6798    #[test]
6799    fn fixed_can_be_combined_with_other_fields() {
6800        let p = parameter!(
6801            "delta",
6802            fixed: 3.0,
6803            bounds: (0.0, 10.0),
6804            unit: "rad",
6805        );
6806
6807        assert_eq!(p.name(), "delta");
6808        assert_eq!(p.fixed(), Some(3.0));
6809        assert_eq!(p.initial(), Some(3.0));
6810        assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
6811        assert_eq!(p.unit().as_deref(), Some("rad"));
6812    }
6813
6814    #[test]
6815    fn trailing_comma_is_accepted() {
6816        let p = parameter!(
6817            "eps",
6818            initial: 0.5,
6819            bounds: (None, 1.0),
6820            unit: "arb",
6821        );
6822
6823        assert_eq!(p.initial(), Some(0.5));
6824        assert_eq!(p.bounds(), (None, Some(1.0)));
6825        assert_eq!(p.unit().as_deref(), Some("arb"));
6826    }
6827
6828    #[test]
6829    fn test_parameter_registration() {
6830        let expr = ComplexScalar::new(
6831            "parametric",
6832            parameter!("test_param_re"),
6833            parameter!("fixed_two", 2.0),
6834        )
6835        .unwrap();
6836        let parameters = expr.parameters().free().names();
6837        assert_eq!(parameters.len(), 1);
6838        assert_eq!(parameters[0], "test_param_re");
6839    }
6840
6841    #[test]
6842    fn test_duplicate_amplitude_tag_registration_is_allowed() {
6843        let amp1 = ComplexScalar::new(
6844            "same_name",
6845            parameter!("dup_re1", 1.0),
6846            parameter!("dup_im1", 0.0),
6847        )
6848        .unwrap();
6849        let amp2 = ComplexScalar::new(
6850            "same_name",
6851            parameter!("dup_re2", 2.0),
6852            parameter!("dup_im2", 0.0),
6853        )
6854        .unwrap();
6855        let expr = amp1 + amp2;
6856        assert_eq!(
6857            expr.parameters().fixed().names(),
6858            vec!["dup_re1", "dup_im1", "dup_re2", "dup_im2"]
6859        );
6860    }
6861
6862    #[test]
6863    fn test_tree_printing() {
6864        let amp1 = ComplexScalar::new(
6865            "parametric_1",
6866            parameter!("test_param_re_1"),
6867            parameter!("test_param_im_1"),
6868        )
6869        .unwrap();
6870        let amp2 = ComplexScalar::new(
6871            "parametric_2",
6872            parameter!("test_param_re_2"),
6873            parameter!("test_param_im_2"),
6874        )
6875        .unwrap();
6876        let expr =
6877            &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
6878                - Expression::zero() / 1.0
6879                + (&amp1 * &amp2).norm_sqr();
6880        assert_eq!(
6881            expr.to_string(),
6882            concat!(
6883                "+\n",
6884                "├─ -\n",
6885                "│  ├─ +\n",
6886                "│  │  ├─ +\n",
6887                "│  │  │  ├─ Re\n",
6888                "│  │  │  │  └─ parametric_1(id=0)\n",
6889                "│  │  │  └─ Im\n",
6890                "│  │  │     └─ *\n",
6891                "│  │  │        └─ parametric_2(id=1)\n",
6892                "│  │  └─ ×\n",
6893                "│  │     ├─ 1 (exact)\n",
6894                "│  │     └─ -1.4+2i\n",
6895                "│  └─ ÷\n",
6896                "│     ├─ 0 (exact)\n",
6897                "│     └─ 1 (exact)\n",
6898                "└─ NormSqr\n",
6899                "   └─ ×\n",
6900                "      ├─ parametric_1(id=0)\n",
6901                "      └─ parametric_2(id=1)\n",
6902            )
6903        );
6904    }
6905}