// laddu_core/expression/evaluator.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    sync::{
5        atomic::{AtomicU64, Ordering},
6        Arc,
7    },
8    time::Instant,
9};
10
11use auto_ops::*;
12use nalgebra::DVector;
13use num::complex::Complex64;
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
19use super::{ir, lowered};
20
21static COMPUTE_AMPLITUDE_COUNTER: AtomicU64 = AtomicU64::new(0);
22static AMPLITUDE_USE_SITE_COUNTER: AtomicU64 = AtomicU64::new(0);
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
25struct ComputeAmplitudeId(u64);
26
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
28struct AmplitudeUseSiteId(u64);
29
30fn next_compute_amplitude_id() -> ComputeAmplitudeId {
31    ComputeAmplitudeId(COMPUTE_AMPLITUDE_COUNTER.fetch_add(1, Ordering::Relaxed))
32}
33
34fn next_amplitude_use_site_id() -> AmplitudeUseSiteId {
35    AmplitudeUseSiteId(AMPLITUDE_USE_SITE_COUNTER.fetch_add(1, Ordering::Relaxed))
36}
37#[derive(Clone, Copy, Debug, PartialEq, Eq)]
38/// Dependence classification used by expression-IR diagnostics.
39pub enum ExpressionDependence {
40    /// Depends only on fixed/free parameter values.
41    ParameterOnly,
42    /// Depends only on event-local cached values.
43    CacheOnly,
44    /// Depends on both parameter values and cached event values.
45    Mixed,
46}
47impl From<ir::DependenceClass> for ExpressionDependence {
48    fn from(value: ir::DependenceClass) -> Self {
49        match value {
50            ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
51            ir::DependenceClass::CacheOnly => Self::CacheOnly,
52            ir::DependenceClass::Mixed => Self::Mixed,
53        }
54    }
55}
56#[derive(Clone, Debug, PartialEq, Eq)]
57/// Explain/debug view of the IR normalization planning decomposition.
58pub struct NormalizationPlanExplain {
59    /// Dependence classification at the expression root.
60    pub root_dependence: ExpressionDependence,
61    /// Warning-level diagnostics collected during planning.
62    pub warnings: Vec<String>,
63    /// Candidate multiply node indices identified as separable.
64    pub separable_mul_candidate_nodes: Vec<usize>,
65    /// Candidate separable node indices selected for caching.
66    pub cached_separable_nodes: Vec<usize>,
67    /// Node indices planned for residual per-event evaluation.
68    pub residual_terms: Vec<usize>,
69}
/// Explain/debug view of the amplitude execution sets used during normalization evaluation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationExecutionSetsExplain {
    /// Amplitudes needed for the parameter factors of cached separable terms.
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitudes needed for the cache factors of cached separable terms.
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitudes needed for residual (non-cached) normalization evaluation.
    pub residual_amplitudes: Vec<usize>,
}
80#[derive(Clone, Debug, PartialEq)]
81/// Load-time precomputed integral metadata for a separable cached term.
82pub struct PrecomputedCachedIntegral {
83    /// Node index of the separable multiplication term.
84    pub mul_node_index: usize,
85    /// Node index of the parameter-dependent factor.
86    pub parameter_node_index: usize,
87    /// Node index of the cache-dependent factor.
88    pub cache_node_index: usize,
89    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
90    pub coefficient: i32,
91    /// Weighted sum over local events of the cache-dependent factor.
92    pub weighted_cache_sum: Complex64,
93}
94#[derive(Clone, Debug, PartialEq)]
95/// Parameter-gradient contribution for a load-time precomputed cached integral term.
96pub struct PrecomputedCachedIntegralGradientTerm {
97    /// Node index of the separable multiplication term.
98    pub mul_node_index: usize,
99    /// Node index of the parameter-dependent factor.
100    pub parameter_node_index: usize,
101    /// Node index of the cache-dependent factor.
102    pub cache_node_index: usize,
103    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
104    pub coefficient: i32,
105    /// Gradient contribution `(d/dp parameter_factor) * weighted_cache_sum`.
106    pub weighted_gradient: DVector<Complex64>,
107}
/// Lookup key for cached-integral / specialization state.
///
/// Built by `Evaluator::cached_integral_cache_key` from the resources' active-amplitude mask and
/// the loaded dataset (that function is not shown here).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CachedIntegralCacheKey {
    /// Active-amplitude mask the state was compiled for.
    active_mask: Vec<bool>,
    /// Number of events held locally. Presumably the dataset's `n_events_local()`; confirm.
    n_events_local: usize,
    /// Length of the local weight buffer. Presumably used to detect a changed dataset; confirm.
    weights_local_len: usize,
    /// Bit pattern of the weighted sum. Presumably `f64::to_bits` so the key stays `Eq + Hash`;
    /// confirm.
    weighted_sum_bits: u64,
    /// NOTE(review): looks like the address of the weight buffer, used as a cheap identity
    /// check. Confirm.
    weights_ptr: usize,
}
116#[derive(Clone, Debug)]
117struct CachedIntegralCacheState {
118    key: CachedIntegralCacheKey,
119    expression_ir: ir::ExpressionIR,
120    values: Vec<PrecomputedCachedIntegral>,
121    execution_sets: ir::NormalizationExecutionSets,
122}
123#[derive(Clone, Debug)]
124struct LoweredArtifactCacheState {
125    parameter_node_indices: Vec<usize>,
126    mul_node_indices: Vec<usize>,
127    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
128    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
129    lowered_runtime: lowered::LoweredExpressionRuntime,
130}
131#[derive(Clone)]
132struct ExpressionSpecializationState {
133    cached_integrals: Arc<CachedIntegralCacheState>,
134    lowered_artifacts: Arc<LoweredArtifactCacheState>,
135}
/// Debug/benchmark counters for reuse of active-mask specializations under `expression-ir`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionSpecializationMetrics {
    /// Specialization requests answered from the cache without recompiling.
    pub cache_hits: usize,
    /// Specialization requests that needed a fresh compile/lower pass.
    pub cache_misses: usize,
}
/// Per-stage compile/lowering metrics for building expression IR and refreshing specializations.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionCompileMetrics {
    /// Nanoseconds spent compiling the semantic tree into IR during the initial load.
    pub initial_ir_compile_nanos: u64,
    /// Nanoseconds spent precomputing cached-integral planning artifacts during the initial load.
    pub initial_cached_integrals_nanos: u64,
    /// Nanoseconds spent lowering IR-derived runtimes during the initial load.
    pub initial_lowering_nanos: u64,
    /// Specialization cache hits restored without recompiling.
    pub specialization_cache_hits: usize,
    /// Specialization cache misses that forced a recompile.
    pub specialization_cache_misses: usize,
    /// Total nanoseconds spent compiling active-mask-specialized IR after the initial load.
    pub specialization_ir_compile_nanos: u64,
    /// Total nanoseconds spent recomputing cached-integral planning artifacts after the load.
    pub specialization_cached_integrals_nanos: u64,
    /// Total nanoseconds spent lowering specialized runtimes after the load.
    pub specialization_lowering_nanos: u64,
    /// Specialization rebuilds that reused already-lowered artifacts.
    pub specialization_lowering_cache_hits: usize,
    /// Specialization rebuilds that had to lower fresh artifacts.
    pub specialization_lowering_cache_misses: usize,
    /// Total nanoseconds spent restoring specializations from the cache.
    pub specialization_cache_restore_nanos: u64,
}
170impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
171    fn from(value: ir::NormalizationPlanExplain) -> Self {
172        Self {
173            root_dependence: value.root_dependence.into(),
174            warnings: value.warnings,
175            separable_mul_candidate_nodes: value
176                .separable_mul_candidates
177                .into_iter()
178                .map(|candidate| candidate.node_index)
179                .collect(),
180            cached_separable_nodes: value.cached_separable_nodes,
181            residual_terms: value.residual_terms,
182        }
183    }
184}
185impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
186    fn from(value: ir::NormalizationExecutionSets) -> Self {
187        Self {
188            cached_parameter_amplitudes: value.cached_parameter_amplitudes,
189            cached_cache_amplitudes: value.cached_cache_amplitudes,
190            residual_amplitudes: value.residual_amplitudes,
191        }
192    }
193}
194impl From<ExpressionDependence> for ir::DependenceClass {
195    fn from(value: ExpressionDependence) -> Self {
196        match value {
197            ExpressionDependence::ParameterOnly => Self::ParameterOnly,
198            ExpressionDependence::CacheOnly => Self::CacheOnly,
199            ExpressionDependence::Mixed => Self::Mixed,
200        }
201    }
202}
203
204#[cfg(feature = "mpi")]
205use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
206
207#[cfg(feature = "mpi")]
208use crate::mpi::LadduMPI;
209#[cfg(feature = "execution-context-prototype")]
210use crate::ExecutionContext;
211#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
212use crate::ThreadPolicy;
213use crate::{
214    amplitude::{Amplitude, AmplitudeUseSite},
215    data::Dataset,
216    parameters::ParameterMap,
217    resources::{Cache, Parameters, Resources},
218    LadduError, LadduResult,
219};
220
221/// A holder struct that owns both an expression tree and the registered amplitudes.
222#[allow(missing_docs)]
223#[derive(Clone, Serialize, Deserialize, Default)]
224pub struct Expression {
225    registry: ExpressionRegistry,
226    tree: ExpressionNode,
227}
228
229#[derive(Clone, Serialize, Deserialize)]
230#[allow(missing_docs)]
231#[derive(Default)]
232pub struct ExpressionRegistry {
233    amplitudes: Vec<Box<dyn Amplitude>>,
234    amplitude_names: Vec<String>,
235    amplitude_ids: Vec<ComputeAmplitudeId>,
236    amplitude_use_sites: Vec<AmplitudeUseSite>,
237    amplitude_use_site_ids: Vec<AmplitudeUseSiteId>,
238    resources: Resources,
239}
240
241impl ExpressionRegistry {
242    fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
243        let mut resources = Resources::default();
244        let aid = amplitude.register(&mut resources)?;
245        let compute_id = next_compute_amplitude_id();
246        let use_site_id = next_amplitude_use_site_id();
247        resources.configure_amplitude_tags(std::slice::from_ref(&aid.0));
248        Ok(Self {
249            amplitudes: vec![amplitude],
250            amplitude_names: vec![aid.0.display_label()],
251            amplitude_ids: vec![compute_id],
252            amplitude_use_sites: vec![AmplitudeUseSite {
253                amplitude_index: 0,
254                tags: aid.0,
255            }],
256            amplitude_use_site_ids: vec![use_site_id],
257            resources,
258        })
259    }
260
261    fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
262        let mut resources = Resources::default();
263        let mut amplitudes = Vec::new();
264        let mut amplitude_ids = Vec::new();
265        let mut compute_id_to_index = HashMap::new();
266        let mut semantic_key_to_index = HashMap::new();
267
268        let mut left_compute_map = Vec::with_capacity(self.amplitudes.len());
269        for (amp_index, (amp, amp_id)) in
270            self.amplitudes.iter().zip(&self.amplitude_ids).enumerate()
271        {
272            let semantic_key = amp.semantic_key();
273            let mut cloned_amp = dyn_clone::clone_box(&**amp);
274            let aid = cloned_amp.register(&mut resources)?;
275            if let Some(key) = semantic_key.clone() {
276                semantic_key_to_index.insert(key, aid.1);
277            }
278            amplitudes.push(cloned_amp);
279            amplitude_ids.push(*amp_id);
280            compute_id_to_index.insert(*amp_id, aid.1);
281            left_compute_map.push(aid.1);
282            debug_assert_eq!(amp_index, aid.1);
283        }
284
285        let mut right_compute_map = Vec::with_capacity(other.amplitudes.len());
286        for (amp, amp_id) in other.amplitudes.iter().zip(&other.amplitude_ids) {
287            if let Some(existing) = compute_id_to_index.get(amp_id) {
288                right_compute_map.push(*existing);
289                continue;
290            }
291            let incoming_semantic_key = amp.semantic_key();
292            if let Some(existing) = incoming_semantic_key
293                .as_ref()
294                .and_then(|key| semantic_key_to_index.get(key))
295            {
296                right_compute_map.push(*existing);
297                continue;
298            }
299            let mut cloned_amp = dyn_clone::clone_box(&**amp);
300            let aid = cloned_amp.register(&mut resources)?;
301            if let Some(key) = incoming_semantic_key.clone() {
302                semantic_key_to_index.insert(key, aid.1);
303            }
304            amplitudes.push(cloned_amp);
305            amplitude_ids.push(*amp_id);
306            compute_id_to_index.insert(*amp_id, aid.1);
307            right_compute_map.push(aid.1);
308        }
309
310        let mut amplitude_use_sites = Vec::new();
311        let mut amplitude_use_site_ids = Vec::new();
312        let mut amplitude_names = Vec::new();
313        let mut use_site_id_to_index = HashMap::new();
314
315        let mut left_map = Vec::with_capacity(self.amplitude_use_sites.len());
316        for (use_site, use_site_id) in self
317            .amplitude_use_sites
318            .iter()
319            .zip(&self.amplitude_use_site_ids)
320        {
321            let mapped_index = left_compute_map[use_site.amplitude_index];
322            let new_index = amplitude_use_sites.len();
323            left_map.push(new_index);
324            use_site_id_to_index.insert(*use_site_id, new_index);
325            amplitude_use_site_ids.push(*use_site_id);
326            amplitude_names.push(use_site.tags.display_label());
327            amplitude_use_sites.push(AmplitudeUseSite {
328                amplitude_index: mapped_index,
329                tags: use_site.tags.clone(),
330            });
331        }
332
333        let mut right_map = Vec::with_capacity(other.amplitude_use_sites.len());
334        for (use_site, use_site_id) in other
335            .amplitude_use_sites
336            .iter()
337            .zip(&other.amplitude_use_site_ids)
338        {
339            if let Some(existing) = use_site_id_to_index.get(use_site_id) {
340                right_map.push(*existing);
341                continue;
342            }
343            let mapped_index = right_compute_map[use_site.amplitude_index];
344            let new_index = amplitude_use_sites.len();
345            right_map.push(new_index);
346            use_site_id_to_index.insert(*use_site_id, new_index);
347            amplitude_use_site_ids.push(*use_site_id);
348            amplitude_names.push(use_site.tags.display_label());
349            amplitude_use_sites.push(AmplitudeUseSite {
350                amplitude_index: mapped_index,
351                tags: use_site.tags.clone(),
352            });
353        }
354        let use_site_tags = amplitude_use_sites
355            .iter()
356            .map(|use_site| use_site.tags.clone())
357            .collect::<Vec<_>>();
358        resources.configure_amplitude_tags(&use_site_tags);
359
360        Ok((
361            Self {
362                amplitudes,
363                amplitude_names,
364                amplitude_ids,
365                amplitude_use_sites,
366                amplitude_use_site_ids,
367                resources,
368            },
369            left_map,
370            right_map,
371        ))
372    }
373}
374
375/// Expression tree used by [`Expression`].
376#[allow(missing_docs)]
377#[derive(Clone, Serialize, Deserialize, Default, Debug)]
378pub enum ExpressionNode {
379    #[default]
380    /// A expression equal to zero.
381    Zero,
382    /// A expression equal to one.
383    One,
384    /// A real-valued constant.
385    Constant(f64),
386    /// A complex-valued constant.
387    ComplexConstant(Complex64),
388    /// A registered [`Amplitude`] referenced by index.
389    Amp(usize),
390    /// The sum of two [`ExpressionNode`]s.
391    Add(Box<ExpressionNode>, Box<ExpressionNode>),
392    /// The difference of two [`ExpressionNode`]s.
393    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
394    /// The product of two [`ExpressionNode`]s.
395    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
396    /// The division of two [`ExpressionNode`]s.
397    Div(Box<ExpressionNode>, Box<ExpressionNode>),
398    /// The additive inverse of an [`ExpressionNode`].
399    Neg(Box<ExpressionNode>),
400    /// The real part of an [`ExpressionNode`].
401    Real(Box<ExpressionNode>),
402    /// The imaginary part of an [`ExpressionNode`].
403    Imag(Box<ExpressionNode>),
404    /// The complex conjugate of an [`ExpressionNode`].
405    Conj(Box<ExpressionNode>),
406    /// The absolute square of an [`ExpressionNode`].
407    NormSqr(Box<ExpressionNode>),
408    Sqrt(Box<ExpressionNode>),
409    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
410    PowI(Box<ExpressionNode>, i32),
411    PowF(Box<ExpressionNode>, f64),
412    Exp(Box<ExpressionNode>),
413    Sin(Box<ExpressionNode>),
414    Cos(Box<ExpressionNode>),
415    Log(Box<ExpressionNode>),
416    Cis(Box<ExpressionNode>),
417}
418
419impl ExpressionNode {
420    fn remap(&self, mapping: &[usize]) -> Self {
421        match self {
422            Self::Amp(idx) => Self::Amp(mapping[*idx]),
423            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
424            Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
425            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
426            Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
427            Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
428            Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
429            Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
430            Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
431            Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
432            Self::Zero => Self::Zero,
433            Self::One => Self::One,
434            Self::Constant(v) => Self::Constant(*v),
435            Self::ComplexConstant(v) => Self::ComplexConstant(*v),
436            Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
437            Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
438            Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
439            Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
440            Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
441            Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
442            Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
443            Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
444            Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
445        }
446    }
447}
448
449impl From<f64> for Expression {
450    fn from(value: f64) -> Self {
451        if value == 0.0 {
452            Self {
453                registry: ExpressionRegistry::default(),
454                tree: ExpressionNode::Zero,
455            }
456        } else if value == 1.0 {
457            Self {
458                registry: ExpressionRegistry::default(),
459                tree: ExpressionNode::One,
460            }
461        } else {
462            Self {
463                registry: ExpressionRegistry::default(),
464                tree: ExpressionNode::Constant(value),
465            }
466        }
467    }
468}
469impl From<&f64> for Expression {
470    fn from(value: &f64) -> Self {
471        (*value).into()
472    }
473}
474impl From<Complex64> for Expression {
475    fn from(value: Complex64) -> Self {
476        if value == Complex64::ZERO {
477            Self {
478                registry: ExpressionRegistry::default(),
479                tree: ExpressionNode::Zero,
480            }
481        } else if value == Complex64::ONE {
482            Self {
483                registry: ExpressionRegistry::default(),
484                tree: ExpressionNode::One,
485            }
486        } else {
487            Self {
488                registry: ExpressionRegistry::default(),
489                tree: ExpressionNode::ComplexConstant(value),
490            }
491        }
492    }
493}
494impl From<&Complex64> for Expression {
495    fn from(value: &Complex64) -> Self {
496        (*value).into()
497    }
498}
499
500impl Expression {
501    /// Build an [`Expression`] from a single [`Amplitude`].
502    pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
503        let registry = ExpressionRegistry::singleton(amplitude)?;
504        Ok(Self {
505            tree: ExpressionNode::Amp(0),
506            registry,
507        })
508    }
509
510    /// Create an expression representing zero, the additive identity.
511    pub fn zero() -> Self {
512        Self {
513            registry: ExpressionRegistry::default(),
514            tree: ExpressionNode::Zero,
515        }
516    }
517
518    /// Create an expression representing one, the multiplicative identity.
519    pub fn one() -> Self {
520        Self {
521            registry: ExpressionRegistry::default(),
522            tree: ExpressionNode::One,
523        }
524    }
525
526    fn binary_op(
527        a: &Expression,
528        b: &Expression,
529        build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
530    ) -> Expression {
531        let (registry, left_map, right_map) = a
532            .registry
533            .merge(&b.registry)
534            .expect("merging expression registries should not fail");
535        let left_tree = a.tree.remap(&left_map);
536        let right_tree = b.tree.remap(&right_map);
537        Expression {
538            registry,
539            tree: build(Box::new(left_tree), Box::new(right_tree)),
540        }
541    }
542
543    fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
544        Expression {
545            registry: a.registry.clone(),
546            tree: build(Box::new(a.tree.clone())),
547        }
548    }
549
550    /// Get the parameters used by this expression.
551    pub fn parameters(&self) -> ParameterMap {
552        self.registry.resources.parameters()
553    }
554
555    /// Number of free parameters.
556    pub fn n_free(&self) -> usize {
557        self.registry.resources.n_free_parameters()
558    }
559
560    /// Number of fixed parameters.
561    pub fn n_fixed(&self) -> usize {
562        self.registry.resources.n_fixed_parameters()
563    }
564
565    /// Total number of parameters.
566    pub fn n_parameters(&self) -> usize {
567        self.registry.resources.n_parameters()
568    }
569
570    /// Returns a tree-like diagnostic snapshot of this expression's compiled form.
571    ///
572    /// This compiles the expression on each call with every registered amplitude active. Use
573    /// [`Evaluator::compiled_expression`] when you need the compiled form for a loaded evaluator's
574    /// current active-amplitude mask.
575    pub fn compiled_expression(&self) -> CompiledExpression {
576        let active_amplitudes = vec![true; self.registry.amplitude_use_sites.len()];
577        let amplitude_dependencies = self
578            .registry
579            .amplitude_use_sites
580            .iter()
581            .map(|use_site| {
582                ir::DependenceClass::from(
583                    self.registry.amplitudes[use_site.amplitude_index].dependence_hint(),
584                )
585            })
586            .collect::<Vec<_>>();
587        let amplitude_realness = self
588            .registry
589            .amplitude_use_sites
590            .iter()
591            .map(|use_site| self.registry.amplitudes[use_site.amplitude_index].real_valued_hint())
592            .collect::<Vec<_>>();
593        let expression_ir = ir::compile_expression_ir_with_real_hints(
594            &self.tree,
595            &active_amplitudes,
596            &amplitude_dependencies,
597            &amplitude_realness,
598        );
599        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
600    }
601
602    /// Fix a parameter used by this expression's evaluator resources.
603    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
604        self.registry.resources.fix_parameter(name, value)
605    }
606
607    /// Mark a parameter used by this expression's evaluator resources as free.
608    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
609        self.registry.resources.free_parameter(name)
610    }
611
612    /// Return a new [`Expression`] with a single parameter renamed.
613    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
614        self.registry.resources.rename_parameter(old, new)
615    }
616
617    /// Return a new [`Expression`] with several parameters renamed.
618    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
619        self.registry.resources.rename_parameters(mapping)
620    }
621
622    /// Load an [`Expression`] against a dataset, binding amplitudes and reserving caches.
623    pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
624        let mut resources = self.registry.resources.clone();
625        let metadata = dataset.metadata();
626        resources.reserve_cache(dataset.n_events_local());
627        resources.refresh_active_indices();
628        let parameter_map = resources.parameter_map.clone();
629        let mut amplitudes: Vec<Box<dyn Amplitude>> = self
630            .registry
631            .amplitudes
632            .iter()
633            .map(|amp| dyn_clone::clone_box(&**amp))
634            .collect();
635        {
636            for amplitude in amplitudes.iter_mut() {
637                amplitude.bind(metadata)?;
638                amplitude.precompute_all(dataset, &mut resources);
639            }
640        }
641        let ir_compile_start = Instant::now();
642        let expression_ir = {
643            let mut active_amplitudes = vec![false; self.registry.amplitude_use_sites.len()];
644            for &index in resources.active_indices() {
645                active_amplitudes[index] = true;
646            }
647            let amplitude_dependencies = self
648                .registry
649                .amplitude_use_sites
650                .iter()
651                .map(|use_site| {
652                    ir::DependenceClass::from(
653                        amplitudes[use_site.amplitude_index].dependence_hint(),
654                    )
655                })
656                .collect::<Vec<_>>();
657            let amplitude_realness = self
658                .registry
659                .amplitude_use_sites
660                .iter()
661                .map(|use_site| amplitudes[use_site.amplitude_index].real_valued_hint())
662                .collect::<Vec<_>>();
663            ir::compile_expression_ir_with_real_hints(
664                &self.tree,
665                &active_amplitudes,
666                &amplitude_dependencies,
667                &amplitude_realness,
668            )
669        };
670        let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
671        let cached_integrals_start = Instant::now();
672        let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
673            &expression_ir,
674            &amplitudes,
675            &self.registry.amplitude_use_sites,
676            &resources,
677            dataset,
678            parameter_map.free().len(),
679        )?;
680        let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
681        let lowering_start = Instant::now();
682        let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
683            &expression_ir,
684            &cached_integrals,
685        )?);
686        let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
687        let execution_sets = expression_ir.normalization_execution_sets().clone();
688        let cached_integral_key =
689            Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
690        let cached_integral_state = Arc::new(CachedIntegralCacheState {
691            key: cached_integral_key.clone(),
692            expression_ir,
693            values: cached_integrals,
694            execution_sets,
695        });
696        let specialization_state = ExpressionSpecializationState {
697            cached_integrals: cached_integral_state.clone(),
698            lowered_artifacts: lowered_artifacts.clone(),
699        };
700        let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
701        let lowered_artifact_cache =
702            HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
703        Ok(Evaluator {
704            amplitudes,
705            amplitude_use_sites: self.registry.amplitude_use_sites.clone(),
706            resources: Arc::new(RwLock::new(resources)),
707            dataset: dataset.clone(),
708            expression: self.tree.clone(),
709            ir_planning: ExpressionIrPlanningState {
710                cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
711                specialization_cache: Arc::new(RwLock::new(specialization_cache)),
712                specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
713                    cache_hits: 0,
714                    cache_misses: 1,
715                })),
716                lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
717                active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
718                specialization_status: Arc::new(RwLock::new(Some(
719                    ExpressionSpecializationStatus {
720                        origin: ExpressionSpecializationOrigin::InitialLoad,
721                    },
722                ))),
723                compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
724                    initial_ir_compile_nanos,
725                    initial_cached_integrals_nanos,
726                    initial_lowering_nanos,
727                    specialization_lowering_cache_misses: 1,
728                    ..Default::default()
729                })),
730            },
731            registry: self.registry.clone(),
732        })
733    }
734
735    /// Takes the real part of the given [`Expression`].
736    pub fn real(&self) -> Self {
737        Self::unary_op(self, ExpressionNode::Real)
738    }
739    /// Takes the imaginary part of the given [`Expression`].
740    pub fn imag(&self) -> Self {
741        Self::unary_op(self, ExpressionNode::Imag)
742    }
743    /// Takes the complex conjugate of the given [`Expression`].
744    pub fn conj(&self) -> Self {
745        Self::unary_op(self, ExpressionNode::Conj)
746    }
747    /// Takes the absolute square of the given [`Expression`].
748    pub fn norm_sqr(&self) -> Self {
749        Self::unary_op(self, ExpressionNode::NormSqr)
750    }
751    /// Takes the square root of the given [`Expression`].
752    pub fn sqrt(&self) -> Self {
753        Self::unary_op(self, ExpressionNode::Sqrt)
754    }
755    /// Raises the given [`Expression`] to an expression-valued power.
756    pub fn pow(&self, power: &Expression) -> Self {
757        Self::binary_op(self, power, ExpressionNode::Pow)
758    }
759    /// Raises the given [`Expression`] to an integer power.
760    pub fn powi(&self, power: i32) -> Self {
761        Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
762    }
763    /// Raises the given [`Expression`] to a real-valued power.
764    pub fn powf(&self, power: f64) -> Self {
765        Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
766    }
767    /// Takes the exponential of the given [`Expression`].
768    pub fn exp(&self) -> Self {
769        Self::unary_op(self, ExpressionNode::Exp)
770    }
771    /// Takes the sine of the given [`Expression`].
772    pub fn sin(&self) -> Self {
773        Self::unary_op(self, ExpressionNode::Sin)
774    }
775    /// Takes the cosine of the given [`Expression`].
776    pub fn cos(&self) -> Self {
777        Self::unary_op(self, ExpressionNode::Cos)
778    }
779    /// Takes the natural logarithm of the given [`Expression`].
780    pub fn log(&self) -> Self {
781        Self::unary_op(self, ExpressionNode::Log)
782    }
783    /// Takes the complex phase factor exp(i * expression).
784    pub fn cis(&self) -> Self {
785        Self::unary_op(self, ExpressionNode::Cis)
786    }
787
788    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
789    fn write_tree(
790        &self,
791        t: &ExpressionNode,
792        f: &mut std::fmt::Formatter<'_>,
793        parent_prefix: &str,
794        immediate_prefix: &str,
795        parent_suffix: &str,
796    ) -> std::fmt::Result {
797        let display_string = match t {
798            ExpressionNode::Amp(idx) => {
799                let name = self
800                    .registry
801                    .amplitude_names
802                    .get(*idx)
803                    .cloned()
804                    .unwrap_or_else(|| "<unregistered>".to_string());
805                format!("{name}(id={idx})")
806            }
807            ExpressionNode::Add(_, _) => "+".to_string(),
808            ExpressionNode::Sub(_, _) => "-".to_string(),
809            ExpressionNode::Mul(_, _) => "×".to_string(),
810            ExpressionNode::Div(_, _) => "÷".to_string(),
811            ExpressionNode::Neg(_) => "-".to_string(),
812            ExpressionNode::Real(_) => "Re".to_string(),
813            ExpressionNode::Imag(_) => "Im".to_string(),
814            ExpressionNode::Conj(_) => "*".to_string(),
815            ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
816            ExpressionNode::Zero => "0 (exact)".to_string(),
817            ExpressionNode::One => "1 (exact)".to_string(),
818            ExpressionNode::Constant(v) => v.to_string(),
819            ExpressionNode::ComplexConstant(v) => v.to_string(),
820            ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
821            ExpressionNode::Pow(_, _) => "Pow".to_string(),
822            ExpressionNode::PowI(_, power) => format!("PowI({power})"),
823            ExpressionNode::PowF(_, power) => format!("PowF({power})"),
824            ExpressionNode::Exp(_) => "Exp".to_string(),
825            ExpressionNode::Sin(_) => "Sin".to_string(),
826            ExpressionNode::Cos(_) => "Cos".to_string(),
827            ExpressionNode::Log(_) => "Log".to_string(),
828            ExpressionNode::Cis(_) => "Cis".to_string(),
829        };
830        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
831        match t {
832            ExpressionNode::Amp(_)
833            | ExpressionNode::Zero
834            | ExpressionNode::One
835            | ExpressionNode::Constant(_)
836            | ExpressionNode::ComplexConstant(_) => {}
837            ExpressionNode::Add(a, b)
838            | ExpressionNode::Sub(a, b)
839            | ExpressionNode::Mul(a, b)
840            | ExpressionNode::Div(a, b)
841            | ExpressionNode::Pow(a, b) => {
842                let terms = [a, b];
843                let mut it = terms.iter().peekable();
844                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
845                while let Some(child) = it.next() {
846                    match it.peek() {
847                        Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│  "),
848                        None => self.write_tree(child, f, &child_prefix, "└─ ", "   "),
849                    }?;
850                }
851            }
852            ExpressionNode::Neg(a)
853            | ExpressionNode::Real(a)
854            | ExpressionNode::Imag(a)
855            | ExpressionNode::Conj(a)
856            | ExpressionNode::NormSqr(a)
857            | ExpressionNode::Sqrt(a)
858            | ExpressionNode::PowI(a, _)
859            | ExpressionNode::PowF(a, _)
860            | ExpressionNode::Exp(a)
861            | ExpressionNode::Sin(a)
862            | ExpressionNode::Cos(a)
863            | ExpressionNode::Log(a)
864            | ExpressionNode::Cis(a) => {
865                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
866                self.write_tree(a, f, &child_prefix, "└─ ", "   ")?;
867            }
868        }
869        Ok(())
870    }
871}
872
873impl Debug for Expression {
874    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875        self.write_tree(&self.tree, f, "", "", "")
876    }
877}
878
879impl Display for Expression {
880    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881        self.write_tree(&self.tree, f, "", "", "")
882    }
883}
884
885#[rustfmt::skip]
886impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
887    Expression::binary_op(a, b, ExpressionNode::Add)
888});
889#[rustfmt::skip]
890impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
891    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
892});
893#[rustfmt::skip]
894impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
895    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
896});
897#[rustfmt::skip]
898impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
899    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
900});
901#[rustfmt::skip]
902impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
903    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
904});
905
906#[rustfmt::skip]
907impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
908    Expression::binary_op(a, b, ExpressionNode::Sub)
909});
910#[rustfmt::skip]
911impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
912    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
913});
914#[rustfmt::skip]
915impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
916    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
917});
918#[rustfmt::skip]
919impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
920    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
921});
922#[rustfmt::skip]
923impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
924    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
925});
926
927#[rustfmt::skip]
928impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
929    Expression::binary_op(a, b, ExpressionNode::Mul)
930});
931#[rustfmt::skip]
932impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
933    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
934});
935#[rustfmt::skip]
936impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
937    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
938});
939#[rustfmt::skip]
940impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
941    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
942});
943#[rustfmt::skip]
944impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
945    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
946});
947
948#[rustfmt::skip]
949impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
950    Expression::binary_op(a, b, ExpressionNode::Div)
951});
952#[rustfmt::skip]
953impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
954    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
955});
956#[rustfmt::skip]
957impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
958    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
959});
960#[rustfmt::skip]
961impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
962    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
963});
964#[rustfmt::skip]
965impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
966    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
967});
968
969#[rustfmt::skip]
970impl_op_ex!(- |a: &Expression| -> Expression {
971    Expression::unary_op(a, ExpressionNode::Neg)
972});
973// NOTE: no need to add an impl for negating f64 or complex!
974
975#[derive(Clone, Debug)]
976#[doc(hidden)]
977pub struct ExpressionValueProgramSnapshot {
978    lowered_program: lowered::LoweredProgram,
979}
980
981#[derive(Clone, Debug, PartialEq)]
982/// A node in a compiled expression diagnostic snapshot.
983pub enum CompiledExpressionNode {
984    /// A complex constant.
985    Constant(Complex64),
986    /// A registered amplitude use-site by index and display label.
987    Amplitude {
988        /// The amplitude index used by the compiled evaluator.
989        index: usize,
990        /// The registered amplitude display label.
991        name: String,
992    },
993    /// A unary operation and its input node.
994    Unary {
995        /// The display label for the operation.
996        op: String,
997        /// The input node index.
998        input: usize,
999    },
1000    /// A binary operation and its input nodes.
1001    Binary {
1002        /// The display label for the operation.
1003        op: String,
1004        /// The left input node index.
1005        left: usize,
1006        /// The right input node index.
1007        right: usize,
1008    },
1009}
1010
1011#[derive(Clone, Debug, PartialEq)]
1012/// Tree-like diagnostic view of the compiled expression DAG.
1013///
1014/// The compiled expression is a directed acyclic graph because common subexpressions can be
1015/// deduplicated during compilation. The display format expands the graph from the root once and
1016/// marks later visits to the same node with `(ref)`.
1017pub struct CompiledExpression {
1018    nodes: Vec<CompiledExpressionNode>,
1019    root: usize,
1020}
1021
1022impl CompiledExpression {
1023    fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
1024        let nodes = ir
1025            .nodes()
1026            .iter()
1027            .map(|node| match node {
1028                ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
1029                ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
1030                    index: *index,
1031                    name: amplitude_names
1032                        .get(*index)
1033                        .cloned()
1034                        .unwrap_or_else(|| "<unregistered>".to_string()),
1035                },
1036                ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
1037                    op: compiled_unary_op_label(*op),
1038                    input: *input,
1039                },
1040                ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
1041                    op: compiled_binary_op_label(*op),
1042                    left: *left,
1043                    right: *right,
1044                },
1045            })
1046            .collect();
1047        Self {
1048            nodes,
1049            root: ir.root(),
1050        }
1051    }
1052
1053    /// Returns the compiled expression node list in evaluator execution order.
1054    pub fn nodes(&self) -> &[CompiledExpressionNode] {
1055        &self.nodes
1056    }
1057
1058    /// Returns the root node index.
1059    pub fn root(&self) -> usize {
1060        self.root
1061    }
1062
1063    fn node_label(&self, index: usize) -> String {
1064        let Some(node) = self.nodes.get(index) else {
1065            return format!("#{index} <missing>");
1066        };
1067        let label = match node {
1068            CompiledExpressionNode::Constant(value) => format!("const {value}"),
1069            CompiledExpressionNode::Amplitude { index, name } => {
1070                format!("{name}(id={index})")
1071            }
1072            CompiledExpressionNode::Unary { op, .. }
1073            | CompiledExpressionNode::Binary { op, .. } => op.clone(),
1074        };
1075        format!("#{index} {label}")
1076    }
1077
1078    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
1079    fn write_tree(
1080        &self,
1081        index: usize,
1082        f: &mut std::fmt::Formatter<'_>,
1083        parent_prefix: &str,
1084        immediate_prefix: &str,
1085        parent_suffix: &str,
1086        expanded: &mut [bool],
1087    ) -> std::fmt::Result {
1088        let already_expanded = expanded.get(index).copied().unwrap_or(false);
1089        if let Some(slot) = expanded.get_mut(index) {
1090            *slot = true;
1091        }
1092        let ref_suffix = if already_expanded { " (ref)" } else { "" };
1093        writeln!(
1094            f,
1095            "{}{}{}{}",
1096            parent_prefix,
1097            immediate_prefix,
1098            self.node_label(index),
1099            ref_suffix
1100        )?;
1101        if already_expanded {
1102            return Ok(());
1103        }
1104        let Some(node) = self.nodes.get(index) else {
1105            return Ok(());
1106        };
1107        match node {
1108            CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
1109            CompiledExpressionNode::Unary { input, .. } => {
1110                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1111                self.write_tree(*input, f, &child_prefix, "└─ ", "   ", expanded)?;
1112            }
1113            CompiledExpressionNode::Binary { left, right, .. } => {
1114                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1115                self.write_tree(*left, f, &child_prefix, "├─ ", "│  ", expanded)?;
1116                self.write_tree(*right, f, &child_prefix, "└─ ", "   ", expanded)?;
1117            }
1118        }
1119        Ok(())
1120    }
1121}
1122
1123impl Display for CompiledExpression {
1124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1125        if self.nodes.is_empty() {
1126            return writeln!(f, "<empty>");
1127        }
1128        let mut expanded = vec![false; self.nodes.len()];
1129        self.write_tree(self.root, f, "", "", "", &mut expanded)
1130    }
1131}
1132
1133fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
1134    match op {
1135        ir::IrUnaryOp::Neg => "-".to_string(),
1136        ir::IrUnaryOp::Real => "Re".to_string(),
1137        ir::IrUnaryOp::Imag => "Im".to_string(),
1138        ir::IrUnaryOp::Conj => "*".to_string(),
1139        ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
1140        ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
1141        ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
1142        ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
1143        ir::IrUnaryOp::Exp => "Exp".to_string(),
1144        ir::IrUnaryOp::Sin => "Sin".to_string(),
1145        ir::IrUnaryOp::Cos => "Cos".to_string(),
1146        ir::IrUnaryOp::Log => "Log".to_string(),
1147        ir::IrUnaryOp::Cis => "Cis".to_string(),
1148    }
1149}
1150
1151fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
1152    match op {
1153        ir::IrBinaryOp::Add => "+".to_string(),
1154        ir::IrBinaryOp::Sub => "-".to_string(),
1155        ir::IrBinaryOp::Mul => "×".to_string(),
1156        ir::IrBinaryOp::Div => "÷".to_string(),
1157        ir::IrBinaryOp::Pow => "Pow".to_string(),
1158    }
1159}
/// Records how the evaluator's currently installed expression specialization came to be.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionSpecializationOrigin {
    /// Installed while the evaluator was being constructed.
    InitialLoad,
    /// Rebuilt because no cached entry matched the current state.
    CacheMissRebuild,
    /// Restored from an already-cached entry.
    CacheHitRestore,
}
1170#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1171/// Current specialization status for the evaluator runtime.
1172pub struct ExpressionSpecializationStatus {
1173    /// How the active specialization was obtained most recently.
1174    pub origin: ExpressionSpecializationOrigin,
1175}
1176#[derive(Clone, Debug, PartialEq, Eq)]
1177/// Diagnostic snapshot of the active runtime specialization state.
1178pub struct ExpressionRuntimeDiagnostics {
1179    /// Whether IR planning state is present for the evaluator.
1180    pub ir_planning_enabled: bool,
1181    /// Whether a lowered value-only program is available for the active specialization.
1182    pub lowered_value_program_present: bool,
1183    /// Whether a lowered gradient-only program is available for the active specialization.
1184    pub lowered_gradient_program_present: bool,
1185    /// Whether a lowered fused value+gradient program is available for the active specialization.
1186    pub lowered_value_gradient_program_present: bool,
1187    /// Number of cached parameter-factor descriptors in the active specialization.
1188    pub cached_parameter_factor_count: usize,
1189    /// Number of cached parameter factors with lowered runtimes available.
1190    pub lowered_cached_parameter_factor_count: usize,
1191    /// Whether a lowered residual normalization runtime is available.
1192    pub residual_runtime_present: bool,
1193    /// Number of cached specialization entries currently retained.
1194    pub specialization_cache_entries: usize,
1195    /// Number of cached lowered-artifact entries currently retained.
1196    pub lowered_artifact_cache_entries: usize,
1197    /// Origin of the currently installed specialization, when available.
1198    pub specialization_status: Option<ExpressionSpecializationStatus>,
1199}
1200#[derive(Clone)]
1201/// IR-planning state derived from the semantic expression tree plus the current active mask.
1202///
1203/// Invariants:
1204/// - `expression_ir` is never a source of truth; it is always derived from `Evaluator::expression`.
1205/// - `cached_integrals` are specialization-dependent and must be treated as invalid once the
1206///   active mask or dataset identity changes.
1207struct ExpressionIrPlanningState {
1208    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
1209    specialization_cache:
1210        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
1211    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
1212    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
1213    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
1214    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
1215    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
1216}
/// Evaluator for [`Expression`] that mirrors the existing evaluator behavior.
///
/// Holds the semantic expression tree alongside the amplitudes, resources, and dataset it is
/// evaluated against, plus the specialization/lowering state derived from them.
#[allow(missing_docs)]
#[derive(Clone)]
pub struct Evaluator {
    // Boxed amplitude trait objects; passed to cached-integral precomputation.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    // Per-use-site amplitude records, passed alongside `amplitudes` to precomputation.
    amplitude_use_sites: Vec<AmplitudeUseSite>,
    // Shared resources; the active mask and free-parameter count are read from here.
    pub resources: Arc<RwLock<Resources>>,
    // Dataset the evaluator operates on.
    pub dataset: Arc<Dataset>,
    // Semantic expression tree; all IR in `ir_planning` is derived from it.
    pub expression: ExpressionNode,
    // Specialization, cached-integral, and lowered-artifact planning state.
    ir_planning: ExpressionIrPlanningState,
    // Expression registry; supplies amplitude display names for diagnostics.
    registry: ExpressionRegistry,
}
1229
1230#[allow(missing_docs)]
1231impl Evaluator {
1232    /// Internal benchmarking/debug counters for specialization cache reuse.
1233    pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
1234        *self.ir_planning.specialization_metrics.read()
1235    }
1236    /// Reset specialization cache counters while leaving cached specializations intact.
1237    pub fn reset_expression_specialization_metrics(&self) {
1238        *self.ir_planning.specialization_metrics.write() =
1239            ExpressionSpecializationMetrics::default();
1240    }
1241    /// Internal benchmarking/debug metrics for staged IR compile and lowering costs.
1242    pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
1243        *self.ir_planning.compile_metrics.read()
1244    }
1245    /// Internal diagnostics surface for active runtime specialization state.
1246    pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
1247        let active_artifacts = self.active_lowered_artifacts();
1248        let cached_parameter_factor_count = self
1249            .ir_planning
1250            .cached_integrals
1251            .read()
1252            .as_ref()
1253            .map(|state| state.values.len())
1254            .unwrap_or(0);
1255        let lowered_cached_parameter_factor_count = active_artifacts
1256            .as_ref()
1257            .map(|artifacts| {
1258                artifacts
1259                    .lowered_parameter_factors
1260                    .iter()
1261                    .filter(|factor| factor.is_some())
1262                    .count()
1263            })
1264            .unwrap_or(0);
1265        let residual_runtime_present = active_artifacts
1266            .as_ref()
1267            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
1268            .is_some();
1269        ExpressionRuntimeDiagnostics {
1270            ir_planning_enabled: true,
1271            lowered_value_program_present: true,
1272            lowered_gradient_program_present: true,
1273            lowered_value_gradient_program_present: true,
1274            cached_parameter_factor_count,
1275            lowered_cached_parameter_factor_count,
1276            residual_runtime_present,
1277            specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
1278            lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
1279            specialization_status: *self.ir_planning.specialization_status.read(),
1280        }
1281    }
1282    /// Reset post-load compile/lowering counters while preserving initial-load metrics.
1283    pub fn reset_expression_compile_metrics(&self) {
1284        let mut metrics = self.ir_planning.compile_metrics.write();
1285        metrics.specialization_cache_hits = 0;
1286        metrics.specialization_cache_misses = 0;
1287        metrics.specialization_ir_compile_nanos = 0;
1288        metrics.specialization_cached_integrals_nanos = 0;
1289        metrics.specialization_lowering_nanos = 0;
1290        metrics.specialization_lowering_cache_hits = 0;
1291        metrics.specialization_lowering_cache_misses = 0;
1292        metrics.specialization_cache_restore_nanos = 0;
1293    }
1294    #[cfg(test)]
1295    fn expression_ir(&self) -> ir::ExpressionIR {
1296        self.ir_planning
1297            .cached_integrals
1298            .read()
1299            .as_ref()
1300            .map(|state| state.expression_ir.clone())
1301            .expect("cached integral state should exist for evaluator IR access")
1302    }
1303    fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
1304        self.active_lowered_artifacts()
1305            .expect("active lowered artifacts should exist for the current specialization")
1306            .lowered_runtime
1307            .clone()
1308    }
1309    fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
1310        self.ir_planning.active_lowered_artifacts.read().clone()
1311    }
1312    fn lowered_runtime_slot_count(&self) -> usize {
1313        let runtime = self.lowered_runtime();
1314        [
1315            runtime.value_program().scratch_slots(),
1316            runtime.gradient_program().scratch_slots(),
1317            runtime.value_gradient_program().scratch_slots(),
1318        ]
1319        .into_iter()
1320        .max()
1321        .unwrap_or(0)
1322    }
1323    fn lowered_value_runtime_slot_count(&self) -> usize {
1324        self.lowered_runtime().value_program().scratch_slots()
1325    }
1326
1327    #[doc(hidden)]
1328    pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
1329        ExpressionValueProgramSnapshot {
1330            lowered_program: self.lowered_runtime().value_program().clone(),
1331        }
1332    }
1333
1334    #[doc(hidden)]
1335    pub fn expression_value_program_snapshot_for_active_mask(
1336        &self,
1337        active_mask: &[bool],
1338    ) -> LadduResult<ExpressionValueProgramSnapshot> {
1339        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1340        let lowered_program =
1341            lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
1342                LadduError::Custom(format!(
1343                    "Failed to lower value-only active-mask runtime: {error:?}"
1344                ))
1345            })?;
1346        Ok(ExpressionValueProgramSnapshot { lowered_program })
1347    }
1348
1349    #[doc(hidden)]
1350    pub fn expression_value_program_snapshot_slot_count(
1351        &self,
1352        snapshot: &ExpressionValueProgramSnapshot,
1353    ) -> usize {
1354        let _ = self;
1355        snapshot.lowered_program.scratch_slots()
1356    }
1357
1358    /// Returns a tree-like diagnostic snapshot of the compiled expression for the evaluator's
1359    /// current active-amplitude mask.
1360    pub fn compiled_expression(&self) -> CompiledExpression {
1361        let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
1362        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
1363    }
1364
1365    /// Returns the expression represented by this evaluator.
1366    pub fn expression(&self) -> Expression {
1367        Expression {
1368            tree: self.expression.clone(),
1369            registry: self.registry.clone(),
1370        }
1371    }
1372    fn lowered_gradient_runtime_slot_count(&self) -> usize {
1373        self.lowered_runtime().gradient_program().scratch_slots()
1374    }
1375    fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
1376        self.lowered_runtime()
1377            .value_gradient_program()
1378            .scratch_slots()
1379    }
1380
1381    fn expression_value_slot_count(&self) -> usize {
1382        self.lowered_value_runtime_slot_count()
1383    }
1384    fn expression_gradient_slot_count(&self) -> usize {
1385        self.lowered_gradient_runtime_slot_count()
1386    }
1387    fn expression_value_gradient_slot_count(&self) -> usize {
1388        self.lowered_value_gradient_runtime_slot_count()
1389    }
1390
1391    #[doc(hidden)]
1392    pub fn expression_value_gradient_slot_count_public(&self) -> usize {
1393        self.expression_value_gradient_slot_count()
1394    }
1395    #[cfg(test)]
1396    fn specialization_cache_len(&self) -> usize {
1397        self.ir_planning.specialization_cache.read().len()
1398    }
1399    #[cfg(test)]
1400    fn lowered_artifact_cache_len(&self) -> usize {
1401        self.ir_planning.lowered_artifact_cache.read().len()
1402    }
1403    fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
1404        debug_assert!(Self::lowered_artifact_signature_matches(
1405            &specialization.lowered_artifacts,
1406            &specialization.cached_integrals.values,
1407        ));
1408        *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
1409        *self.ir_planning.active_lowered_artifacts.write() =
1410            Some(specialization.lowered_artifacts.clone());
1411        debug_assert_eq!(
1412            self.active_lowered_artifacts()
1413                .as_ref()
1414                .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
1415            Some(true)
1416        );
1417        debug_assert_eq!(
1418            self.lowered_runtime().value_program().scratch_slots(),
1419            specialization
1420                .lowered_artifacts
1421                .lowered_runtime
1422                .value_program()
1423                .scratch_slots()
1424        );
1425    }
1426    fn lower_expression_runtime_artifacts(
1427        expression_ir: &ir::ExpressionIR,
1428        values: &[PrecomputedCachedIntegral],
1429    ) -> LadduResult<LoweredArtifactCacheState> {
1430        let parameter_node_indices = values
1431            .iter()
1432            .map(|value| value.parameter_node_index)
1433            .collect();
1434        let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
1435        let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
1436        let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
1437        let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
1438            expression_ir,
1439        )
1440        .map_err(|error| {
1441            LadduError::Custom(format!(
1442                "Failed to lower expression runtime for specialized IR: {error:?}"
1443            ))
1444        })?;
1445        Ok(LoweredArtifactCacheState {
1446            parameter_node_indices,
1447            mul_node_indices,
1448            lowered_parameter_factors,
1449            residual_runtime,
1450            lowered_runtime,
1451        })
1452    }
1453    fn lowered_artifact_signature_matches(
1454        artifacts: &LoweredArtifactCacheState,
1455        values: &[PrecomputedCachedIntegral],
1456    ) -> bool {
1457        artifacts.parameter_node_indices.len() == values.len()
1458            && artifacts.mul_node_indices.len() == values.len()
1459            && artifacts
1460                .parameter_node_indices
1461                .iter()
1462                .copied()
1463                .eq(values.iter().map(|value| value.parameter_node_index))
1464            && artifacts
1465                .mul_node_indices
1466                .iter()
1467                .copied()
1468                .eq(values.iter().map(|value| value.mul_node_index))
1469    }
1470    fn build_expression_specialization(
1471        &self,
1472        resources: &Resources,
1473        key: CachedIntegralCacheKey,
1474    ) -> LadduResult<ExpressionSpecializationState> {
1475        let ir_compile_start = Instant::now();
1476        let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
1477        let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
1478        let cached_integrals_start = Instant::now();
1479        let values = Self::precompute_cached_integrals_at_load(
1480            &expression_ir,
1481            &self.amplitudes,
1482            &self.amplitude_use_sites,
1483            resources,
1484            &self.dataset,
1485            self.resources.read().n_free_parameters(),
1486        )?;
1487        let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
1488        let execution_sets = expression_ir.normalization_execution_sets().clone();
1489        let active_mask_key = resources.active.clone();
1490        let cached_lowered_artifacts = {
1491            let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
1492            lowered_artifact_cache
1493                .get(&active_mask_key)
1494                .cloned()
1495                .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
1496        };
1497        let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
1498            self.ir_planning
1499                .compile_metrics
1500                .write()
1501                .specialization_lowering_cache_hits += 1;
1502            artifacts
1503        } else {
1504            let lowering_start = Instant::now();
1505            let artifacts = Arc::new(
1506                Self::lower_expression_runtime_artifacts(&expression_ir, &values)
1507                    .expect("specialized lowered runtime should build"),
1508            );
1509            let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
1510            self.ir_planning
1511                .lowered_artifact_cache
1512                .write()
1513                .insert(active_mask_key, artifacts.clone());
1514            let mut compile_metrics = self.ir_planning.compile_metrics.write();
1515            compile_metrics.specialization_lowering_cache_misses += 1;
1516            compile_metrics.specialization_lowering_nanos += lowering_nanos;
1517            artifacts
1518        };
1519        let mut compile_metrics = self.ir_planning.compile_metrics.write();
1520        compile_metrics.specialization_cache_misses += 1;
1521        compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
1522        compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
1523        Ok(ExpressionSpecializationState {
1524            cached_integrals: Arc::new(CachedIntegralCacheState {
1525                key,
1526                expression_ir,
1527                values,
1528                execution_sets,
1529            }),
1530            lowered_artifacts,
1531        })
1532    }
1533    fn ensure_expression_specialization(
1534        &self,
1535        resources: &Resources,
1536    ) -> LadduResult<ExpressionSpecializationState> {
1537        let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
1538        if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
1539            if state.key == key {
1540                return Ok(ExpressionSpecializationState {
1541                    cached_integrals: state.clone(),
1542                    lowered_artifacts: self
1543                        .active_lowered_artifacts()
1544                        .expect("active lowered artifacts should exist for cached specialization"),
1545                });
1546            }
1547        }
1548        let cached_specialization = {
1549            let specialization_cache = self.ir_planning.specialization_cache.read();
1550            specialization_cache.get(&key).cloned()
1551        };
1552        if let Some(specialization) = cached_specialization {
1553            let restore_start = Instant::now();
1554            self.ir_planning.specialization_metrics.write().cache_hits += 1;
1555            self.install_expression_specialization(&specialization);
1556            *self.ir_planning.specialization_status.write() =
1557                Some(ExpressionSpecializationStatus {
1558                    origin: ExpressionSpecializationOrigin::CacheHitRestore,
1559                });
1560            let restore_nanos = restore_start.elapsed().as_nanos() as u64;
1561            let mut compile_metrics = self.ir_planning.compile_metrics.write();
1562            compile_metrics.specialization_cache_hits += 1;
1563            compile_metrics.specialization_cache_restore_nanos += restore_nanos;
1564            return Ok(specialization);
1565        }
1566        let specialization = self.build_expression_specialization(resources, key.clone())?;
1567        self.ir_planning.specialization_metrics.write().cache_misses += 1;
1568        self.ir_planning
1569            .specialization_cache
1570            .write()
1571            .insert(key, specialization.clone());
1572        self.install_expression_specialization(&specialization);
1573        let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
1574            ExpressionSpecializationOrigin::InitialLoad
1575        } else {
1576            ExpressionSpecializationOrigin::CacheMissRebuild
1577        };
1578        *self.ir_planning.specialization_status.write() =
1579            Some(ExpressionSpecializationStatus { origin });
1580        Ok(specialization)
1581    }
1582    fn rebuild_runtime_specializations(&self, resources: &Resources) {
1583        let _ = self.ensure_expression_specialization(resources);
1584    }
1585    fn refresh_runtime_specializations(&self) {
1586        let resources = self.resources.read();
1587        self.rebuild_runtime_specializations(&resources);
1588    }
1589    fn cached_integral_cache_key(
1590        active_mask: Vec<bool>,
1591        dataset: &Dataset,
1592    ) -> CachedIntegralCacheKey {
1593        let (weights_ptr, weights_local_len) = dataset.local_weight_cache_key();
1594        CachedIntegralCacheKey {
1595            active_mask,
1596            n_events_local: dataset.n_events_local(),
1597            weights_local_len,
1598            weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
1599            weights_ptr,
1600        }
1601    }
1602    fn precompute_cached_integrals_at_load(
1603        expression_ir: &ir::ExpressionIR,
1604        amplitudes: &[Box<dyn Amplitude>],
1605        amplitude_use_sites: &[AmplitudeUseSite],
1606        resources: &Resources,
1607        dataset: &Dataset,
1608        n_free_parameters: usize,
1609    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
1610        let descriptors = expression_ir.cached_integral_descriptors();
1611        if descriptors.is_empty() {
1612            return Ok(Vec::new());
1613        }
1614        let execution_sets = expression_ir.normalization_execution_sets();
1615        let seed_parameters = vec![0.0; n_free_parameters];
1616        let parameters = resources.parameter_map.assemble(&seed_parameters)?;
1617        let mut amplitude_values = vec![Complex64::ZERO; amplitude_use_sites.len()];
1618        let mut compute_values = vec![Complex64::ZERO; amplitudes.len()];
1619        let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
1620        let active_set = resources.active_indices();
1621        let cache_active_indices = execution_sets
1622            .cached_cache_amplitudes
1623            .iter()
1624            .copied()
1625            .filter(|index| active_set.binary_search(index).is_ok())
1626            .collect::<Vec<_>>();
1627        let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
1628        for (cache, event) in resources.caches.iter().zip(dataset.weights_local().iter()) {
1629            amplitude_values.fill(Complex64::ZERO);
1630            compute_values.fill(Complex64::ZERO);
1631            let mut computed = vec![false; amplitudes.len()];
1632            for &use_site_idx in &cache_active_indices {
1633                let amp_idx = amplitude_use_sites[use_site_idx].amplitude_index;
1634                if !computed[amp_idx] {
1635                    compute_values[amp_idx] = amplitudes[amp_idx].compute(&parameters, cache);
1636                    computed[amp_idx] = true;
1637                }
1638                amplitude_values[use_site_idx] = compute_values[amp_idx];
1639            }
1640            expression_ir.evaluate_into(&amplitude_values, &mut value_slots);
1641            let weight = *event;
1642            for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
1643                weighted_cache_sums[descriptor_index] +=
1644                    value_slots[descriptor.cache_node_index] * weight;
1645            }
1646        }
1647        Ok(descriptors
1648            .iter()
1649            .zip(weighted_cache_sums)
1650            .map(
1651                |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
1652                    mul_node_index: descriptor.mul_node_index,
1653                    parameter_node_index: descriptor.parameter_node_index,
1654                    cache_node_index: descriptor.cache_node_index,
1655                    coefficient: descriptor.coefficient,
1656                    weighted_cache_sum,
1657                },
1658            )
1659            .collect())
1660    }
1661    fn lower_cached_parameter_factors(
1662        expression_ir: &ir::ExpressionIR,
1663    ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
1664        expression_ir
1665            .cached_integral_descriptors()
1666            .iter()
1667            .map(|descriptor| {
1668                lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
1669                    expression_ir,
1670                    descriptor.parameter_node_index,
1671                )
1672                .ok()
1673            })
1674            .collect()
1675    }
1676    fn lower_residual_runtime(
1677        expression_ir: &ir::ExpressionIR,
1678        descriptors: &[PrecomputedCachedIntegral],
1679    ) -> Option<lowered::LoweredExpressionRuntime> {
1680        let mut zeroed_nodes = vec![false; expression_ir.node_count()];
1681        for descriptor in descriptors {
1682            if descriptor.mul_node_index < zeroed_nodes.len() {
1683                zeroed_nodes[descriptor.mul_node_index] = true;
1684            }
1685        }
1686        lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
1687            expression_ir,
1688            &zeroed_nodes,
1689        )
1690        .ok()
1691    }
1692
1693    #[inline]
1694    fn fill_amplitude_values(
1695        &self,
1696        amplitude_values: &mut [Complex64],
1697        active_indices: &[usize],
1698        parameters: &Parameters,
1699        cache: &Cache,
1700    ) {
1701        amplitude_values.fill(Complex64::ZERO);
1702        let mut compute_values = vec![Complex64::ZERO; self.amplitudes.len()];
1703        let mut computed = vec![false; self.amplitudes.len()];
1704        for &use_site_idx in active_indices {
1705            let amp_idx = self.amplitude_use_sites[use_site_idx].amplitude_index;
1706            if !computed[amp_idx] {
1707                compute_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
1708                computed[amp_idx] = true;
1709            }
1710            amplitude_values[use_site_idx] = compute_values[amp_idx];
1711        }
1712    }
1713
1714    #[inline]
1715    fn fill_amplitude_gradients(
1716        &self,
1717        gradient_values: &mut [DVector<Complex64>],
1718        active_mask: &[bool],
1719        parameters: &Parameters,
1720        cache: &Cache,
1721    ) {
1722        let mut compute_gradients = vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1723        let mut computed = vec![false; self.amplitudes.len()];
1724        for ((use_site, active), grad) in self
1725            .amplitude_use_sites
1726            .iter()
1727            .zip(active_mask.iter())
1728            .zip(gradient_values.iter_mut())
1729        {
1730            grad.fill(Complex64::ZERO);
1731            if *active {
1732                let amp_idx = use_site.amplitude_index;
1733                if !computed[amp_idx] {
1734                    self.amplitudes[amp_idx].compute_gradient(
1735                        parameters,
1736                        cache,
1737                        &mut compute_gradients[amp_idx],
1738                    );
1739                    computed[amp_idx] = true;
1740                }
1741                grad.copy_from(&compute_gradients[amp_idx]);
1742            }
1743        }
1744    }
1745
1746    #[inline]
1747    fn fill_amplitude_values_and_gradients(
1748        &self,
1749        amplitude_values: &mut [Complex64],
1750        gradient_values: &mut [DVector<Complex64>],
1751        active_indices: &[usize],
1752        active_mask: &[bool],
1753        parameters: &Parameters,
1754        cache: &Cache,
1755    ) {
1756        self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
1757        self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
1758    }
1759
1760    #[doc(hidden)]
1761    pub fn fill_amplitude_values_and_gradients_public(
1762        &self,
1763        amplitude_values: &mut [Complex64],
1764        gradient_values: &mut [DVector<Complex64>],
1765        active_indices: &[usize],
1766        active_mask: &[bool],
1767        parameters: &Parameters,
1768        cache: &Cache,
1769    ) {
1770        self.fill_amplitude_values_and_gradients(
1771            amplitude_values,
1772            gradient_values,
1773            active_indices,
1774            active_mask,
1775            parameters,
1776            cache,
1777        );
1778    }
1779
1780    #[cfg(feature = "execution-context-prototype")]
1781    #[inline]
1782    fn evaluate_cache_gradient_with_scratch(
1783        &self,
1784        amplitude_values: &mut [Complex64],
1785        gradient_values: &mut [DVector<Complex64>],
1786        value_slots: &mut [Complex64],
1787        gradient_slots: &mut [DVector<Complex64>],
1788        active_indices: &[usize],
1789        active_mask: &[bool],
1790        parameters: &Parameters,
1791        cache: &Cache,
1792    ) -> DVector<Complex64> {
1793        self.fill_amplitude_values_and_gradients(
1794            amplitude_values,
1795            gradient_values,
1796            active_indices,
1797            active_mask,
1798            parameters,
1799            cache,
1800        );
1801        self.evaluate_expression_gradient_with_scratch(
1802            amplitude_values,
1803            gradient_values,
1804            value_slots,
1805            gradient_slots,
1806        )
1807    }
1808
1809    #[cfg(feature = "execution-context-prototype")]
1810    #[allow(dead_code)]
1811    #[inline]
1812    fn evaluate_cache_value_gradient_with_scratch(
1813        &self,
1814        amplitude_values: &mut [Complex64],
1815        gradient_values: &mut [DVector<Complex64>],
1816        value_slots: &mut [Complex64],
1817        gradient_slots: &mut [DVector<Complex64>],
1818        active_indices: &[usize],
1819        active_mask: &[bool],
1820        parameters: &Parameters,
1821        cache: &Cache,
1822    ) -> (Complex64, DVector<Complex64>) {
1823        self.fill_amplitude_values_and_gradients(
1824            amplitude_values,
1825            gradient_values,
1826            active_indices,
1827            active_mask,
1828            parameters,
1829            cache,
1830        );
1831        self.evaluate_expression_value_gradient_with_scratch(
1832            amplitude_values,
1833            gradient_values,
1834            value_slots,
1835            gradient_slots,
1836        )
1837    }
1838
1839    pub fn expression_slot_count(&self) -> usize {
1840        self.lowered_runtime_slot_count()
1841    }
1842    fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
1843        let amplitude_dependencies = self
1844            .amplitude_use_sites
1845            .iter()
1846            .map(|use_site| {
1847                ir::DependenceClass::from(
1848                    self.amplitudes[use_site.amplitude_index].dependence_hint(),
1849                )
1850            })
1851            .collect::<Vec<_>>();
1852        let amplitude_realness = self
1853            .amplitude_use_sites
1854            .iter()
1855            .map(|use_site| self.amplitudes[use_site.amplitude_index].real_valued_hint())
1856            .collect::<Vec<_>>();
1857        ir::compile_expression_ir_with_real_hints(
1858            &self.expression,
1859            active_mask,
1860            &amplitude_dependencies,
1861            &amplitude_realness,
1862        )
1863    }
1864    fn lower_expression_runtime_for_active_mask(
1865        &self,
1866        active_mask: &[bool],
1867    ) -> LadduResult<lowered::LoweredExpressionRuntime> {
1868        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1869        lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
1870            LadduError::Custom(format!(
1871                "Failed to lower active-mask runtime specialization: {error:?}"
1872            ))
1873        })
1874    }
1875    fn ensure_cached_integral_cache_state(
1876        &self,
1877        resources: &Resources,
1878    ) -> LadduResult<Arc<CachedIntegralCacheState>> {
1879        Ok(self
1880            .ensure_expression_specialization(resources)?
1881            .cached_integrals)
1882    }
1883
1884    fn evaluate_expression_runtime_value_with_scratch(
1885        &self,
1886        amplitude_values: &[Complex64],
1887        scratch: &mut [Complex64],
1888    ) -> Complex64 {
1889        let lowered_runtime = self.lowered_runtime();
1890        lowered_runtime
1891            .value_program()
1892            .evaluate_into(amplitude_values, scratch)
1893    }
1894
1895    #[doc(hidden)]
1896    pub fn evaluate_expression_value_with_program_snapshot(
1897        &self,
1898        program_snapshot: &ExpressionValueProgramSnapshot,
1899        amplitude_values: &[Complex64],
1900        scratch: &mut [Complex64],
1901    ) -> Complex64 {
1902        program_snapshot
1903            .lowered_program
1904            .evaluate_into(amplitude_values, scratch)
1905    }
1906
1907    fn evaluate_expression_runtime_gradient_with_scratch(
1908        &self,
1909        amplitude_values: &[Complex64],
1910        gradient_values: &[DVector<Complex64>],
1911        value_scratch: &mut [Complex64],
1912        gradient_scratch: &mut [DVector<Complex64>],
1913    ) -> DVector<Complex64> {
1914        let lowered_runtime = self.lowered_runtime();
1915        lowered_runtime.gradient_program().evaluate_gradient_into(
1916            amplitude_values,
1917            gradient_values,
1918            value_scratch,
1919            gradient_scratch,
1920        )
1921    }
1922
1923    fn evaluate_expression_runtime_value_gradient_with_scratch(
1924        &self,
1925        amplitude_values: &[Complex64],
1926        gradient_values: &[DVector<Complex64>],
1927        value_scratch: &mut [Complex64],
1928        gradient_scratch: &mut [DVector<Complex64>],
1929    ) -> (Complex64, DVector<Complex64>) {
1930        let lowered_runtime = self.lowered_runtime();
1931        lowered_runtime
1932            .value_gradient_program()
1933            .evaluate_value_gradient_into(
1934                amplitude_values,
1935                gradient_values,
1936                value_scratch,
1937                gradient_scratch,
1938            )
1939    }
1940
1941    fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
1942        let lowered_runtime = self.lowered_runtime();
1943        let program = lowered_runtime.value_program();
1944        let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
1945        program.evaluate_into(amplitude_values, &mut scratch)
1946    }
1947
1948    fn evaluate_expression_runtime_gradient(
1949        &self,
1950        amplitude_values: &[Complex64],
1951        gradient_values: &[DVector<Complex64>],
1952    ) -> DVector<Complex64> {
1953        let lowered_runtime = self.lowered_runtime();
1954        let program = lowered_runtime.gradient_program();
1955        let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
1956        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
1957        let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
1958        program.evaluate_gradient_into_flat(
1959            amplitude_values,
1960            gradient_values,
1961            &mut value_scratch,
1962            &mut gradient_scratch,
1963            grad_dim,
1964        )
1965    }
1966    /// Dependence classification for the compiled expression root.
1967    pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
1968        let resources = self.resources.read();
1969        Ok(self
1970            .ensure_cached_integral_cache_state(&resources)?
1971            .expression_ir
1972            .root_dependence()
1973            .into())
1974    }
1975    /// Dependence classification for each compiled expression node.
1976    pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
1977        let resources = self.resources.read();
1978        Ok(self
1979            .ensure_cached_integral_cache_state(&resources)?
1980            .expression_ir
1981            .node_dependence_annotations()
1982            .iter()
1983            .copied()
1984            .map(Into::into)
1985            .collect())
1986    }
1987    /// Warning-level diagnostics for potentially inconsistent dependence hints.
1988    pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
1989        let resources = self.resources.read();
1990        Ok(self
1991            .ensure_cached_integral_cache_state(&resources)?
1992            .expression_ir
1993            .dependence_warnings()
1994            .to_vec())
1995    }
1996    /// Explain/debug view of IR normalization planning decomposition.
1997    pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
1998        let resources = self.resources.read();
1999        Ok(self
2000            .ensure_cached_integral_cache_state(&resources)?
2001            .expression_ir
2002            .normalization_plan_explain()
2003            .into())
2004    }
2005    /// Explain/debug view of amplitude execution sets used by normalization evaluation.
2006    pub fn expression_normalization_execution_sets(
2007        &self,
2008    ) -> LadduResult<NormalizationExecutionSetsExplain> {
2009        let resources = self.resources.read();
2010        Ok(self
2011            .ensure_cached_integral_cache_state(&resources)?
2012            .execution_sets
2013            .clone()
2014            .into())
2015    }
2016    /// Cached integral terms precomputed at evaluator load.
2017    pub fn expression_precomputed_cached_integrals(
2018        &self,
2019    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2020        let resources = self.resources.read();
2021        Ok(self
2022            .ensure_cached_integral_cache_state(&resources)?
2023            .values
2024            .clone())
2025    }
2026    /// Derivative rules for cached separable terms evaluated at the given parameter point.
2027    ///
2028    /// Each returned term corresponds to a cached separable descriptor and contributes
2029    /// `weighted_gradient` to `d(normalization)/dp` prior to residual-term combination.
2030    pub fn expression_precomputed_cached_integral_gradient_terms(
2031        &self,
2032        parameters: &[f64],
2033    ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
2034        let resources = self.resources.read();
2035        let state = self.ensure_cached_integral_cache_state(&resources)?;
2036        if state.values.is_empty() {
2037            return Ok(Vec::new());
2038        }
2039
2040        let Some(cache) = resources.caches.first() else {
2041            return Ok(state
2042                .values
2043                .iter()
2044                .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
2045                    mul_node_index: descriptor.mul_node_index,
2046                    parameter_node_index: descriptor.parameter_node_index,
2047                    cache_node_index: descriptor.cache_node_index,
2048                    coefficient: descriptor.coefficient,
2049                    weighted_gradient: DVector::zeros(parameters.len()),
2050                })
2051                .collect());
2052        };
2053
2054        let parameter_values = resources.parameter_map.assemble(parameters)?;
2055        let mut amplitude_values = vec![Complex64::ZERO; self.amplitude_use_sites.len()];
2056        self.fill_amplitude_values(
2057            &mut amplitude_values,
2058            resources.active_indices(),
2059            &parameter_values,
2060            cache,
2061        );
2062        let mut amplitude_gradients = (0..self.amplitude_use_sites.len())
2063            .map(|_| DVector::zeros(parameters.len()))
2064            .collect::<Vec<_>>();
2065        self.fill_amplitude_gradients(
2066            &mut amplitude_gradients,
2067            &resources.active,
2068            &parameter_values,
2069            cache,
2070        );
2071        let lowered_artifacts = self.active_lowered_artifacts();
2072        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2073        let mut gradient_slots = (0..state.expression_ir.node_count())
2074            .map(|_| DVector::zeros(parameters.len()))
2075            .collect::<Vec<_>>();
2076        let max_lowered_slots = lowered_artifacts
2077            .as_ref()
2078            .map(|artifacts| {
2079                artifacts
2080                    .lowered_parameter_factors
2081                    .iter()
2082                    .filter_map(|runtime| {
2083                        runtime
2084                            .as_ref()
2085                            .and_then(|runtime| runtime.gradient_program())
2086                            .map(|program| program.scratch_slots())
2087                    })
2088                    .max()
2089                    .unwrap_or(0)
2090            })
2091            .unwrap_or(0);
2092        let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
2093        let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
2094        let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
2095            artifacts.lowered_parameter_factors.len() == state.values.len()
2096                && artifacts.lowered_parameter_factors.iter().all(|runtime| {
2097                    runtime
2098                        .as_ref()
2099                        .and_then(|runtime| runtime.gradient_program())
2100                        .is_some()
2101                })
2102        });
2103
2104        if !use_lowered {
2105            let _ = state.expression_ir.evaluate_gradient_into(
2106                &amplitude_values,
2107                &amplitude_gradients,
2108                &mut value_slots,
2109                &mut gradient_slots,
2110            );
2111        }
2112
2113        if use_lowered {
2114            let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
2115            Ok(state
2116                .values
2117                .iter()
2118                .cloned()
2119                .zip(lowered_artifacts.lowered_parameter_factors.iter())
2120                .map(|(descriptor, runtime)| {
2121                    let parameter_gradient = runtime
2122                        .as_ref()
2123                        .and_then(|runtime| runtime.gradient_program())
2124                        .map(|program| {
2125                            program.evaluate_gradient_into(
2126                                &amplitude_values,
2127                                &amplitude_gradients,
2128                                &mut lowered_value_slots[..program.scratch_slots()],
2129                                &mut lowered_gradient_slots[..program.scratch_slots()],
2130                            )
2131                        })
2132                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
2133                    let weighted_gradient = parameter_gradient.map(|value| {
2134                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2135                    });
2136                    PrecomputedCachedIntegralGradientTerm {
2137                        mul_node_index: descriptor.mul_node_index,
2138                        parameter_node_index: descriptor.parameter_node_index,
2139                        cache_node_index: descriptor.cache_node_index,
2140                        coefficient: descriptor.coefficient,
2141                        weighted_gradient,
2142                    }
2143                })
2144                .collect())
2145        } else {
2146            Ok(state
2147                .values
2148                .iter()
2149                .map(|descriptor| {
2150                    let parameter_gradient = gradient_slots
2151                        .get(descriptor.parameter_node_index)
2152                        .cloned()
2153                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
2154                    let weighted_gradient = parameter_gradient.map(|value| {
2155                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2156                    });
2157                    PrecomputedCachedIntegralGradientTerm {
2158                        mul_node_index: descriptor.mul_node_index,
2159                        parameter_node_index: descriptor.parameter_node_index,
2160                        cache_node_index: descriptor.cache_node_index,
2161                        coefficient: descriptor.coefficient,
2162                        weighted_gradient,
2163                    }
2164                })
2165                .collect())
2166        }
2167    }
2168    fn evaluate_cached_weighted_value_sum_ir(
2169        &self,
2170        state: &CachedIntegralCacheState,
2171        amplitude_values: &[Complex64],
2172    ) -> f64 {
2173        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2174        let _ = state
2175            .expression_ir
2176            .evaluate_into(amplitude_values, &mut value_slots);
2177        state
2178            .values
2179            .iter()
2180            .map(|descriptor| {
2181                let parameter_factor = value_slots[descriptor.parameter_node_index];
2182                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2183                    .re
2184            })
2185            .sum()
2186    }
2187    fn evaluate_cached_weighted_value_sum_lowered(
2188        &self,
2189        state: &CachedIntegralCacheState,
2190        lowered_artifacts: &LoweredArtifactCacheState,
2191        amplitude_values: &[Complex64],
2192    ) -> Option<f64> {
2193        let max_slots = lowered_artifacts
2194            .lowered_parameter_factors
2195            .iter()
2196            .filter_map(|runtime| {
2197                runtime
2198                    .as_ref()
2199                    .and_then(|runtime| runtime.value_program())
2200                    .map(|program| program.scratch_slots())
2201            })
2202            .max()
2203            .unwrap_or(0);
2204        let mut value_slots = vec![Complex64::ZERO; max_slots];
2205        let mut total = 0.0;
2206        for (descriptor, runtime) in state
2207            .values
2208            .iter()
2209            .zip(lowered_artifacts.lowered_parameter_factors.iter())
2210        {
2211            let parameter_factor = runtime
2212                .as_ref()
2213                .and_then(|runtime| runtime.value_program())
2214                .map(|program| {
2215                    program.evaluate_into(
2216                        amplitude_values,
2217                        &mut value_slots[..program.scratch_slots()],
2218                    )
2219                })?;
2220            total +=
2221                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2222                    .re;
2223        }
2224        Some(total)
2225    }
2226    fn evaluate_cached_weighted_gradient_sum_ir(
2227        &self,
2228        state: &CachedIntegralCacheState,
2229        amplitude_values: &[Complex64],
2230        amplitude_gradients: &[DVector<Complex64>],
2231        grad_dim: usize,
2232    ) -> DVector<f64> {
2233        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2234        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2235        let _ = state.expression_ir.evaluate_gradient_into(
2236            amplitude_values,
2237            amplitude_gradients,
2238            &mut value_slots,
2239            &mut gradient_slots,
2240        );
2241        state
2242            .values
2243            .iter()
2244            .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
2245                let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
2246                let coefficient = descriptor.coefficient as f64;
2247                for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
2248                    *accum_item +=
2249                        (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2250                }
2251                accum
2252            })
2253    }
2254    fn evaluate_cached_weighted_gradient_sum_lowered(
2255        &self,
2256        state: &CachedIntegralCacheState,
2257        lowered_artifacts: &LoweredArtifactCacheState,
2258        amplitude_values: &[Complex64],
2259        amplitude_gradients: &[DVector<Complex64>],
2260        grad_dim: usize,
2261    ) -> Option<DVector<f64>> {
2262        let max_value_slots = lowered_artifacts
2263            .lowered_parameter_factors
2264            .iter()
2265            .filter_map(|runtime| {
2266                runtime
2267                    .as_ref()
2268                    .and_then(|runtime| runtime.gradient_program())
2269                    .map(|program| program.scratch_slots())
2270            })
2271            .max()
2272            .unwrap_or(0);
2273        let mut value_slots = vec![Complex64::ZERO; max_value_slots];
2274        let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
2275        let mut total = DVector::zeros(grad_dim);
2276        for (descriptor, runtime) in state
2277            .values
2278            .iter()
2279            .zip(lowered_artifacts.lowered_parameter_factors.iter())
2280        {
2281            let parameter_gradient = runtime
2282                .as_ref()
2283                .and_then(|runtime| runtime.gradient_program())
2284                .map(|program| {
2285                    program.evaluate_gradient_into_flat(
2286                        amplitude_values,
2287                        amplitude_gradients,
2288                        &mut value_slots[..program.scratch_slots()],
2289                        &mut gradient_slots[..program.scratch_slots() * grad_dim],
2290                        grad_dim,
2291                    )
2292                })?;
2293            let coefficient = descriptor.coefficient as f64;
2294            for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
2295                *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2296            }
2297        }
2298        Some(total)
2299    }
2300    fn evaluate_residual_value_ir(
2301        &self,
2302        state: &CachedIntegralCacheState,
2303        amplitude_values: &[Complex64],
2304    ) -> Complex64 {
2305        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2306        for descriptor in &state.values {
2307            if descriptor.mul_node_index < zeroed_nodes.len() {
2308                zeroed_nodes[descriptor.mul_node_index] = true;
2309            }
2310        }
2311        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2312        state.expression_ir.evaluate_into_with_zeroed_nodes(
2313            amplitude_values,
2314            &mut value_slots,
2315            &zeroed_nodes,
2316        )
2317    }
2318    fn evaluate_residual_gradient_ir(
2319        &self,
2320        state: &CachedIntegralCacheState,
2321        amplitude_values: &[Complex64],
2322        amplitude_gradients: &[DVector<Complex64>],
2323        grad_dim: usize,
2324    ) -> DVector<Complex64> {
2325        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2326        for descriptor in &state.values {
2327            if descriptor.mul_node_index < zeroed_nodes.len() {
2328                zeroed_nodes[descriptor.mul_node_index] = true;
2329            }
2330        }
2331        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2332        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2333        state
2334            .expression_ir
2335            .evaluate_gradient_into_with_zeroed_nodes(
2336                amplitude_values,
2337                amplitude_gradients,
2338                &mut value_slots,
2339                &mut gradient_slots,
2340                &zeroed_nodes,
2341            )
2342    }
2343
2344    fn evaluate_weighted_value_sum_local_components(
2345        &self,
2346        parameters: &[f64],
2347    ) -> LadduResult<(f64, f64)> {
2348        let resources = self.resources.read();
2349        let parameters = resources.parameter_map.assemble(parameters)?;
2350        let amplitude_len = self.amplitude_use_sites.len();
2351        let state = self.ensure_cached_integral_cache_state(&resources)?;
2352        let lowered_artifacts = self.active_lowered_artifacts();
2353        let residual_value_slot_count = lowered_artifacts
2354            .as_ref()
2355            .and_then(|artifacts| {
2356                artifacts
2357                    .residual_runtime
2358                    .as_ref()
2359                    .map(|runtime| runtime.value_program())
2360                    .map(|program| program.scratch_slots())
2361            })
2362            .unwrap_or_else(|| self.expression_slot_count());
2363        let residual_value_program = lowered_artifacts
2364            .as_ref()
2365            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2366            .map(|runtime| runtime.value_program());
2367        let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
2368        let residual_active_indices = &state.execution_sets.residual_amplitudes;
2369        debug_assert!(cached_parameter_indices.iter().all(|&index| resources
2370            .active
2371            .get(index)
2372            .copied()
2373            .unwrap_or(false)));
2374        debug_assert!(residual_active_indices.iter().all(|&index| resources
2375            .active
2376            .get(index)
2377            .copied()
2378            .unwrap_or(false)));
2379        let cached_value_sum = {
2380            if let Some(cache) = resources.caches.first() {
2381                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2382                self.fill_amplitude_values(
2383                    &mut amplitude_values,
2384                    cached_parameter_indices,
2385                    &parameters,
2386                    cache,
2387                );
2388                lowered_artifacts
2389                    .as_ref()
2390                    .and_then(|artifacts| {
2391                        self.evaluate_cached_weighted_value_sum_lowered(
2392                            &state,
2393                            artifacts,
2394                            &amplitude_values,
2395                        )
2396                    })
2397                    .unwrap_or_else(|| {
2398                        self.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values)
2399                    })
2400            } else {
2401                0.0
2402            }
2403        };
2404
2405        #[cfg(feature = "rayon")]
2406        let residual_sum: f64 = {
2407            resources
2408                .caches
2409                .par_iter()
2410                .zip(self.dataset.weights_local().par_iter())
2411                .map_init(
2412                    || {
2413                        (
2414                            vec![Complex64::ZERO; amplitude_len],
2415                            vec![Complex64::ZERO; residual_value_slot_count],
2416                        )
2417                    },
2418                    |(amplitude_values, value_slots), (cache, event)| {
2419                        self.fill_amplitude_values(
2420                            amplitude_values,
2421                            residual_active_indices,
2422                            &parameters,
2423                            cache,
2424                        );
2425                        {
2426                            let value = residual_value_program
2427                                .as_ref()
2428                                .map(|program| {
2429                                    program.evaluate_into(
2430                                        amplitude_values,
2431                                        &mut value_slots[..program.scratch_slots()],
2432                                    )
2433                                })
2434                                .unwrap_or_else(|| {
2435                                    self.evaluate_residual_value_ir(&state, amplitude_values)
2436                                });
2437                            *event * value.re
2438                        }
2439                    },
2440                )
2441                .sum()
2442        };
2443
2444        #[cfg(not(feature = "rayon"))]
2445        let residual_sum: f64 = {
2446            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2447            let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
2448            resources
2449                .caches
2450                .iter()
2451                .zip(self.dataset.weights_local().iter())
2452                .map(|(cache, event)| {
2453                    self.fill_amplitude_values(
2454                        &mut amplitude_values,
2455                        &residual_active_indices,
2456                        &parameters,
2457                        cache,
2458                    );
2459                    {
2460                        let value = residual_value_program
2461                            .as_ref()
2462                            .map(|program| {
2463                                program.evaluate_into(
2464                                    &amplitude_values,
2465                                    &mut value_slots[..program.scratch_slots()],
2466                                )
2467                            })
2468                            .unwrap_or_else(|| {
2469                                self.evaluate_residual_value_ir(&state, &amplitude_values)
2470                            });
2471                        *event * value.re
2472                    }
2473                })
2474                .sum()
2475        };
2476        Ok((residual_sum, cached_value_sum))
2477    }
2478
2479    /// Weighted sum over local events of the real expression value.
2480    ///
2481    /// This returns `sum_e(weight_e * Re(L_e))`.
2482    pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
2483        let (residual_sum, cached_value_sum) =
2484            self.evaluate_weighted_value_sum_local_components(parameters)?;
2485        Ok(residual_sum + cached_value_sum)
2486    }
2487
2488    #[cfg(feature = "mpi")]
2489    /// Weighted sum over all ranks of the real expression value.
2490    ///
2491    /// This returns `sum_{r,e}(weight_{r,e} * Re(L_{r,e}))`.
2492    pub fn evaluate_weighted_value_sum_mpi(
2493        &self,
2494        parameters: &[f64],
2495        world: &SimpleCommunicator,
2496    ) -> LadduResult<f64> {
2497        let (residual_sum_local, cached_value_sum_local) =
2498            self.evaluate_weighted_value_sum_local_components(parameters)?;
2499        let mut residual_sum = 0.0;
2500        world.all_reduce_into(
2501            &residual_sum_local,
2502            &mut residual_sum,
2503            mpi::collective::SystemOperation::sum(),
2504        );
2505        let mut cached_value_sum = 0.0;
2506        world.all_reduce_into(
2507            &cached_value_sum_local,
2508            &mut cached_value_sum,
2509            mpi::collective::SystemOperation::sum(),
2510        );
2511        Ok(residual_sum + cached_value_sum)
2512    }
2513
2514    /// Weighted sum over local events of the real gradient of the expression.
2515    ///
2516    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
2517    fn evaluate_weighted_gradient_sum_local_components(
2518        &self,
2519        parameters: &[f64],
2520    ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
2521        let resources = self.resources.read();
2522        let parameters = resources.parameter_map.assemble(parameters)?;
2523        let amplitude_len = self.amplitude_use_sites.len();
2524        let grad_dim = parameters.len();
2525        let state = self.ensure_cached_integral_cache_state(&resources)?;
2526        let lowered_artifacts = self.active_lowered_artifacts();
2527        let active_index_set = resources.active_indices();
2528        let cached_parameter_indices = state
2529            .execution_sets
2530            .cached_parameter_amplitudes
2531            .iter()
2532            .copied()
2533            .filter(|index| active_index_set.binary_search(index).is_ok())
2534            .collect::<Vec<_>>();
2535        let residual_active_indices = state
2536            .execution_sets
2537            .residual_amplitudes
2538            .iter()
2539            .copied()
2540            .filter(|index| active_index_set.binary_search(index).is_ok())
2541            .collect::<Vec<_>>();
2542        let mut cached_parameter_mask = vec![false; amplitude_len];
2543        for &index in &cached_parameter_indices {
2544            cached_parameter_mask[index] = true;
2545        }
2546        let mut residual_active_mask = vec![false; amplitude_len];
2547        for &index in &residual_active_indices {
2548            residual_active_mask[index] = true;
2549        }
2550        let residual_gradient_program = lowered_artifacts
2551            .as_ref()
2552            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2553            .map(|runtime| runtime.gradient_program());
2554        let residual_gradient_slot_count = residual_gradient_program
2555            .as_ref()
2556            .map(|program| program.scratch_slots())
2557            .unwrap_or_else(|| state.expression_ir.node_count());
2558        let cached_term_sum = {
2559            if let Some(cache) = resources.caches.first() {
2560                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2561                self.fill_amplitude_values(
2562                    &mut amplitude_values,
2563                    &cached_parameter_indices,
2564                    &parameters,
2565                    cache,
2566                );
2567                let mut amplitude_gradients = (0..amplitude_len)
2568                    .map(|_| DVector::zeros(grad_dim))
2569                    .collect::<Vec<_>>();
2570                self.fill_amplitude_gradients(
2571                    &mut amplitude_gradients,
2572                    &cached_parameter_mask,
2573                    &parameters,
2574                    cache,
2575                );
2576                lowered_artifacts
2577                    .as_ref()
2578                    .and_then(|artifacts| {
2579                        self.evaluate_cached_weighted_gradient_sum_lowered(
2580                            &state,
2581                            artifacts,
2582                            &amplitude_values,
2583                            &amplitude_gradients,
2584                            grad_dim,
2585                        )
2586                    })
2587                    .unwrap_or_else(|| {
2588                        self.evaluate_cached_weighted_gradient_sum_ir(
2589                            &state,
2590                            &amplitude_values,
2591                            &amplitude_gradients,
2592                            grad_dim,
2593                        )
2594                    })
2595            } else {
2596                DVector::zeros(grad_dim)
2597            }
2598        };
2599
2600        #[cfg(feature = "rayon")]
2601        let residual_sum = {
2602            resources
2603                .caches
2604                .par_iter()
2605                .zip(self.dataset.weights_local().par_iter())
2606                .map_init(
2607                    || {
2608                        (
2609                            vec![Complex64::ZERO; amplitude_len],
2610                            vec![DVector::zeros(grad_dim); amplitude_len],
2611                            vec![Complex64::ZERO; residual_gradient_slot_count],
2612                            vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
2613                        )
2614                    },
2615                    |(amplitude_values, gradient_values, value_slots, gradient_slots),
2616                     (cache, event)| {
2617                        self.fill_amplitude_values_and_gradients(
2618                            amplitude_values,
2619                            gradient_values,
2620                            &residual_active_indices,
2621                            &residual_active_mask,
2622                            &parameters,
2623                            cache,
2624                        );
2625                        let gradient = residual_gradient_program
2626                            .as_ref()
2627                            .map(|program| {
2628                                program.evaluate_gradient_into_flat(
2629                                    amplitude_values,
2630                                    gradient_values,
2631                                    value_slots,
2632                                    gradient_slots,
2633                                    grad_dim,
2634                                )
2635                            })
2636                            .unwrap_or_else(|| {
2637                                self.evaluate_residual_gradient_ir(
2638                                    &state,
2639                                    amplitude_values,
2640                                    gradient_values,
2641                                    grad_dim,
2642                                )
2643                            });
2644                        gradient.map(|value| value.re).scale(*event)
2645                    },
2646                )
2647                .reduce(
2648                    || DVector::zeros(grad_dim),
2649                    |mut accum, value| {
2650                        accum += value;
2651                        accum
2652                    },
2653                )
2654        };
2655
2656        #[cfg(not(feature = "rayon"))]
2657        let residual_sum = {
2658            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2659            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
2660            let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
2661            let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
2662            resources
2663                .caches
2664                .iter()
2665                .zip(self.dataset.weights_local().iter())
2666                .map(|(cache, event)| {
2667                    self.fill_amplitude_values_and_gradients(
2668                        &mut amplitude_values,
2669                        &mut gradient_values,
2670                        &residual_active_indices,
2671                        &residual_active_mask,
2672                        &parameters,
2673                        cache,
2674                    );
2675                    let gradient = residual_gradient_program
2676                        .as_ref()
2677                        .map(|program| {
2678                            program.evaluate_gradient_into_flat(
2679                                &amplitude_values,
2680                                &gradient_values,
2681                                &mut value_slots,
2682                                &mut gradient_slots,
2683                                grad_dim,
2684                            )
2685                        })
2686                        .unwrap_or_else(|| {
2687                            self.evaluate_residual_gradient_ir(
2688                                &state,
2689                                &amplitude_values,
2690                                &gradient_values,
2691                                grad_dim,
2692                            )
2693                        });
2694                    gradient.map(|value| value.re).scale(*event)
2695                })
2696                .sum()
2697        };
2698        Ok((residual_sum, cached_term_sum))
2699    }
2700
2701    /// Weighted sum over local events of the real gradient of the expression.
2702    ///
2703    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
2704    pub fn evaluate_weighted_gradient_sum_local(
2705        &self,
2706        parameters: &[f64],
2707    ) -> LadduResult<DVector<f64>> {
2708        let (residual_sum, cached_term_sum) =
2709            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2710        Ok(residual_sum + cached_term_sum)
2711    }
2712
2713    #[cfg(feature = "mpi")]
2714    /// Weighted sum over all ranks of the real gradient of the expression.
2715    ///
2716    /// This returns `sum_{r,e}(weight_{r,e} * Re(dL_{r,e}/dp))`.
2717    pub fn evaluate_weighted_gradient_sum_mpi(
2718        &self,
2719        parameters: &[f64],
2720        world: &SimpleCommunicator,
2721    ) -> LadduResult<DVector<f64>> {
2722        let (residual_sum_local, cached_term_sum_local) =
2723            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2724        let mut residual_sum = vec![0.0; residual_sum_local.len()];
2725        world.all_reduce_into(
2726            residual_sum_local.as_slice(),
2727            &mut residual_sum,
2728            mpi::collective::SystemOperation::sum(),
2729        );
2730        let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
2731        world.all_reduce_into(
2732            cached_term_sum_local.as_slice(),
2733            &mut cached_term_sum,
2734            mpi::collective::SystemOperation::sum(),
2735        );
2736        let mut total = DVector::from_vec(residual_sum);
2737        total += DVector::from_vec(cached_term_sum);
2738        Ok(total)
2739    }
2740
2741    pub fn evaluate_expression_value_with_scratch(
2742        &self,
2743        amplitude_values: &[Complex64],
2744        scratch: &mut [Complex64],
2745    ) -> Complex64 {
2746        self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
2747    }
2748
2749    pub fn evaluate_expression_gradient_with_scratch(
2750        &self,
2751        amplitude_values: &[Complex64],
2752        gradient_values: &[DVector<Complex64>],
2753        value_scratch: &mut [Complex64],
2754        gradient_scratch: &mut [DVector<Complex64>],
2755    ) -> DVector<Complex64> {
2756        self.evaluate_expression_runtime_gradient_with_scratch(
2757            amplitude_values,
2758            gradient_values,
2759            value_scratch,
2760            gradient_scratch,
2761        )
2762    }
2763
2764    pub fn evaluate_expression_value_gradient_with_scratch(
2765        &self,
2766        amplitude_values: &[Complex64],
2767        gradient_values: &[DVector<Complex64>],
2768        value_scratch: &mut [Complex64],
2769        gradient_scratch: &mut [DVector<Complex64>],
2770    ) -> (Complex64, DVector<Complex64>) {
2771        self.evaluate_expression_runtime_value_gradient_with_scratch(
2772            amplitude_values,
2773            gradient_values,
2774            value_scratch,
2775            gradient_scratch,
2776        )
2777    }
2778
2779    pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
2780        self.evaluate_expression_runtime_value(amplitude_values)
2781    }
2782
2783    pub fn evaluate_expression_gradient(
2784        &self,
2785        amplitude_values: &[Complex64],
2786        gradient_values: &[DVector<Complex64>],
2787    ) -> DVector<Complex64> {
2788        self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
2789    }
2790
2791    /// Get the parameters used by this evaluator.
2792    pub fn parameters(&self) -> ParameterMap {
2793        self.resources.read().parameters()
2794    }
2795
2796    /// Number of free parameters.
2797    pub fn n_free(&self) -> usize {
2798        self.resources.read().n_free_parameters()
2799    }
2800
2801    /// Number of fixed parameters.
2802    pub fn n_fixed(&self) -> usize {
2803        self.resources.read().n_fixed_parameters()
2804    }
2805
2806    /// Total number of parameters.
2807    pub fn n_parameters(&self) -> usize {
2808        self.resources.read().n_parameters()
2809    }
2810
2811    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
2812        self.resources.read().fix_parameter(name, value)
2813    }
2814
2815    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
2816        self.resources.read().free_parameter(name)
2817    }
2818
2819    pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
2820        self.resources.write().rename_parameter(old, new)
2821    }
2822
2823    pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
2824        self.resources.write().rename_parameters(mapping)
2825    }
2826
2827    /// Activate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2828    pub fn activate<T: AsRef<str>>(&self, name: T) {
2829        self.resources.write().activate(name);
2830        self.refresh_runtime_specializations();
2831    }
2832    /// Activate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
2833    pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2834        self.resources.write().activate_strict(name)?;
2835        self.refresh_runtime_specializations();
2836        Ok(())
2837    }
2838
2839    /// Activate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2840    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
2841        self.resources.write().activate_many(names);
2842        self.refresh_runtime_specializations();
2843    }
2844    /// Activate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
2845    pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2846        self.resources.write().activate_many_strict(names)?;
2847        self.refresh_runtime_specializations();
2848        Ok(())
2849    }
2850
2851    /// Activate all registered [`Amplitude`]s.
2852    pub fn activate_all(&self) {
2853        self.resources.write().activate_all();
2854        self.refresh_runtime_specializations();
2855    }
2856
2857    /// Deactivate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2858    pub fn deactivate<T: AsRef<str>>(&self, name: T) {
2859        self.resources.write().deactivate(name);
2860        self.refresh_runtime_specializations();
2861    }
2862
2863    /// Deactivate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
2864    pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2865        self.resources.write().deactivate_strict(name)?;
2866        self.refresh_runtime_specializations();
2867        Ok(())
2868    }
2869
2870    /// Deactivate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2871    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
2872        self.resources.write().deactivate_many(names);
2873        self.refresh_runtime_specializations();
2874    }
2875    /// Deactivate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
2876    pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2877        self.resources.write().deactivate_many_strict(names)?;
2878        self.refresh_runtime_specializations();
2879        Ok(())
2880    }
2881
2882    /// Deactivate all tagged [`Amplitude`] use-sites.
2883    pub fn deactivate_all(&self) {
2884        self.resources.write().deactivate_all();
2885        self.refresh_runtime_specializations();
2886    }
2887
2888    /// Isolate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2889    pub fn isolate<T: AsRef<str>>(&self, name: T) {
2890        self.resources.write().isolate(name);
2891        self.refresh_runtime_specializations();
2892    }
2893
2894    /// Isolate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
2895    pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2896        self.resources.write().isolate_strict(name)?;
2897        self.refresh_runtime_specializations();
2898        Ok(())
2899    }
2900
2901    /// Isolate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
2902    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
2903        self.resources.write().isolate_many(names);
2904        self.refresh_runtime_specializations();
2905    }
2906
2907    /// Isolate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
2908    pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2909        self.resources.write().isolate_many_strict(names)?;
2910        self.refresh_runtime_specializations();
2911        Ok(())
2912    }
2913
2914    /// Return a copy of the current active-amplitude mask.
2915    pub fn active_mask(&self) -> Vec<bool> {
2916        self.resources.read().active.clone()
2917    }
2918
2919    /// Apply a precomputed active-amplitude mask. Untagged use-sites cannot be deactivated.
2920    pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
2921        let resources = {
2922            let mut resources = self.resources.write();
2923            if mask.len() != resources.active.len() {
2924                return Err(LadduError::LengthMismatch {
2925                    context: "active amplitude mask".to_string(),
2926                    expected: resources.active.len(),
2927                    actual: mask.len(),
2928                });
2929            }
2930            resources.apply_active_mask(mask)?;
2931            resources.clone()
2932        };
2933        self.rebuild_runtime_specializations(&resources);
2934        Ok(())
2935    }
2936
2937    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
2938    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
2939    ///
2940    /// # Notes
2941    ///
2942    /// This method is not intended to be called in analyses but rather in writing methods
2943    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
2944    pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
2945        let resources = self.resources.read();
2946        let parameters = resources.parameter_map.assemble(parameters)?;
2947        let amplitude_len = self.amplitude_use_sites.len();
2948        let active_indices = resources.active_indices().to_vec();
2949        let slot_count = self.expression_value_slot_count();
2950        let program_snapshot = self.expression_value_program_snapshot();
2951        #[cfg(feature = "rayon")]
2952        {
2953            Ok(resources
2954                .caches
2955                .par_iter()
2956                .map_init(
2957                    || {
2958                        (
2959                            vec![Complex64::ZERO; amplitude_len],
2960                            vec![Complex64::ZERO; slot_count],
2961                        )
2962                    },
2963                    |(amplitude_values, expr_slots), cache| {
2964                        self.fill_amplitude_values(
2965                            amplitude_values,
2966                            &active_indices,
2967                            &parameters,
2968                            cache,
2969                        );
2970                        self.evaluate_expression_value_with_program_snapshot(
2971                            &program_snapshot,
2972                            amplitude_values,
2973                            expr_slots,
2974                        )
2975                    },
2976                )
2977                .collect())
2978        }
2979        #[cfg(not(feature = "rayon"))]
2980        {
2981            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2982            let mut expr_slots = vec![Complex64::ZERO; slot_count];
2983            Ok(resources
2984                .caches
2985                .iter()
2986                .map(|cache| {
2987                    self.fill_amplitude_values(
2988                        &mut amplitude_values,
2989                        &active_indices,
2990                        &parameters,
2991                        cache,
2992                    );
2993                    self.evaluate_expression_value_with_program_snapshot(
2994                        &program_snapshot,
2995                        &amplitude_values,
2996                        &mut expr_slots,
2997                    )
2998                })
2999                .collect())
3000        }
3001    }
3002
3003    /// Evaluate local events using an explicit active-amplitude mask without mutating evaluator state.
3004    pub fn evaluate_local_with_active_mask(
3005        &self,
3006        parameters: &[f64],
3007        active_mask: &[bool],
3008    ) -> LadduResult<Vec<Complex64>> {
3009        let resources = self.resources.read();
3010        if active_mask.len() != resources.active.len() {
3011            return Err(LadduError::LengthMismatch {
3012                context: "active amplitude mask".to_string(),
3013                expected: resources.active.len(),
3014                actual: active_mask.len(),
3015            });
3016        }
3017        let parameters = resources.parameter_map.assemble(parameters)?;
3018        let amplitude_len = self.amplitude_use_sites.len();
3019        let active_indices = active_mask
3020            .iter()
3021            .enumerate()
3022            .filter_map(|(index, &active)| if active { Some(index) } else { None })
3023            .collect::<Vec<_>>();
3024        let program_snapshot =
3025            self.expression_value_program_snapshot_for_active_mask(active_mask)?;
3026        let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
3027        #[cfg(feature = "rayon")]
3028        {
3029            Ok(resources
3030                .caches
3031                .par_iter()
3032                .map_init(
3033                    || {
3034                        (
3035                            vec![Complex64::ZERO; amplitude_len],
3036                            vec![Complex64::ZERO; slot_count],
3037                        )
3038                    },
3039                    |(amplitude_values, expr_slots), cache| {
3040                        self.fill_amplitude_values(
3041                            amplitude_values,
3042                            &active_indices,
3043                            &parameters,
3044                            cache,
3045                        );
3046                        self.evaluate_expression_value_with_program_snapshot(
3047                            &program_snapshot,
3048                            amplitude_values,
3049                            expr_slots,
3050                        )
3051                    },
3052                )
3053                .collect())
3054        }
3055        #[cfg(not(feature = "rayon"))]
3056        {
3057            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3058            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3059            Ok(resources
3060                .caches
3061                .iter()
3062                .map(|cache| {
3063                    self.fill_amplitude_values(
3064                        &mut amplitude_values,
3065                        &active_indices,
3066                        &parameters,
3067                        cache,
3068                    );
3069                    self.evaluate_expression_value_with_program_snapshot(
3070                        &program_snapshot,
3071                        &amplitude_values,
3072                        &mut expr_slots,
3073                    )
3074                })
3075                .collect())
3076        }
3077    }
3078
3079    /// Evaluate the stored expression over local events using a reusable execution context.
3080    #[cfg(feature = "execution-context-prototype")]
3081    pub fn evaluate_local_with_ctx(
3082        &self,
3083        parameters: &[f64],
3084        execution_context: &ExecutionContext,
3085    ) -> Vec<Complex64> {
3086        let resources = self.resources.read();
3087        let parameters = resources
3088            .parameter_map
3089            .assemble(parameters)
3090            .expect("parameter slice must match evaluator resources");
3091        let amplitude_len = self.amplitude_use_sites.len();
3092        let active_indices = resources.active_indices().to_vec();
3093        let slot_count = self.expression_value_slot_count();
3094        let program_snapshot = self.expression_value_program_snapshot();
3095        #[cfg(feature = "rayon")]
3096        {
3097            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3098                return execution_context.install(|| {
3099                    resources
3100                        .caches
3101                        .par_iter()
3102                        .map_init(
3103                            || {
3104                                (
3105                                    vec![Complex64::ZERO; amplitude_len],
3106                                    vec![Complex64::ZERO; slot_count],
3107                                )
3108                            },
3109                            |(amplitude_values, expr_slots), cache| {
3110                                self.fill_amplitude_values(
3111                                    amplitude_values,
3112                                    &active_indices,
3113                                    &parameters,
3114                                    cache,
3115                                );
3116                                self.evaluate_expression_value_with_program_snapshot(
3117                                    &program_snapshot,
3118                                    amplitude_values,
3119                                    expr_slots,
3120                                )
3121                            },
3122                        )
3123                        .collect()
3124                });
3125            }
3126        }
3127        execution_context.with_scratch(|scratch| {
3128            let (amplitude_values, expr_slots) =
3129                scratch.reserve_value_workspaces(amplitude_len, slot_count);
3130            resources
3131                .caches
3132                .iter()
3133                .map(|cache| {
3134                    self.fill_amplitude_values(
3135                        amplitude_values,
3136                        &active_indices,
3137                        &parameters,
3138                        cache,
3139                    );
3140                    self.evaluate_expression_value_with_program_snapshot(
3141                        &program_snapshot,
3142                        amplitude_values,
3143                        expr_slots,
3144                    )
3145                })
3146                .collect()
3147        })
3148    }
3149
3150    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3151    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
3152    ///
3153    /// # Notes
3154    ///
3155    /// This method is not intended to be called in analyses but rather in writing methods
3156    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
3157    #[cfg(feature = "mpi")]
3158    fn evaluate_mpi(
3159        &self,
3160        parameters: &[f64],
3161        world: &SimpleCommunicator,
3162    ) -> LadduResult<Vec<Complex64>> {
3163        let local_evaluation = self.evaluate_local(parameters)?;
3164        let n_events = self.dataset.n_events();
3165        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3166        let (counts, displs) = world.get_counts_displs(n_events);
3167        {
3168            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
3169            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
3170            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3171            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3172        }
3173        Ok(buffer)
3174    }
3175
3176    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3177    fn evaluate_mpi_with_ctx(
3178        &self,
3179        parameters: &[f64],
3180        world: &SimpleCommunicator,
3181        execution_context: &ExecutionContext,
3182    ) -> Vec<Complex64> {
3183        let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
3184        let n_events = self.dataset.n_events();
3185        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3186        let (counts, displs) = world.get_counts_displs(n_events);
3187        {
3188            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
3189            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
3190            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3191            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3192        }
3193        buffer
3194    }
3195
3196    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3197    /// [`Evaluator`] with the given values for free parameters.
3198    pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3199        #[cfg(feature = "mpi")]
3200        {
3201            if let Some(world) = crate::mpi::get_world() {
3202                return self.evaluate_mpi(parameters, &world);
3203            }
3204        }
3205        self.evaluate_local(parameters)
3206    }
3207
3208    /// Evaluate the stored expression with a reusable execution context.
3209    ///
3210    /// This is intended for repeated calls with the same context instance.
3211    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
3212    /// [`ExecutionContext`](crate::ExecutionContext).
3213    #[cfg(feature = "execution-context-prototype")]
3214    pub fn evaluate_with_ctx(
3215        &self,
3216        parameters: &[f64],
3217        execution_context: &ExecutionContext,
3218    ) -> Vec<Complex64> {
3219        #[cfg(feature = "mpi")]
3220        {
3221            if let Some(world) = crate::mpi::get_world() {
3222                return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
3223            }
3224        }
3225        self.evaluate_local_with_ctx(parameters, execution_context)
3226    }
3227
3228    /// See [`Evaluator::evaluate_local`]. This method evaluates over a subset of events rather
3229    /// than all events in the total dataset.
3230    pub fn evaluate_batch_local(
3231        &self,
3232        parameters: &[f64],
3233        indices: &[usize],
3234    ) -> LadduResult<Vec<Complex64>> {
3235        let resources = self.resources.read();
3236        let parameters = resources.parameter_map.assemble(parameters)?;
3237        let amplitude_len = self.amplitude_use_sites.len();
3238        let active_indices = resources.active_indices().to_vec();
3239        let slot_count = self.expression_value_slot_count();
3240        let program_snapshot = self.expression_value_program_snapshot();
3241        #[cfg(feature = "rayon")]
3242        {
3243            Ok(indices
3244                .par_iter()
3245                .map_init(
3246                    || {
3247                        (
3248                            vec![Complex64::ZERO; amplitude_len],
3249                            vec![Complex64::ZERO; slot_count],
3250                        )
3251                    },
3252                    |(amplitude_values, expr_slots), &idx| {
3253                        let cache = &resources.caches[idx];
3254                        self.fill_amplitude_values(
3255                            amplitude_values,
3256                            &active_indices,
3257                            &parameters,
3258                            cache,
3259                        );
3260                        self.evaluate_expression_value_with_program_snapshot(
3261                            &program_snapshot,
3262                            amplitude_values,
3263                            expr_slots,
3264                        )
3265                    },
3266                )
3267                .collect())
3268        }
3269        #[cfg(not(feature = "rayon"))]
3270        {
3271            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3272            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3273            Ok(indices
3274                .iter()
3275                .map(|&idx| {
3276                    let cache = &resources.caches[idx];
3277                    self.fill_amplitude_values(
3278                        &mut amplitude_values,
3279                        &active_indices,
3280                        &parameters,
3281                        cache,
3282                    );
3283                    self.evaluate_expression_value_with_program_snapshot(
3284                        &program_snapshot,
3285                        &amplitude_values,
3286                        &mut expr_slots,
3287                    )
3288                })
3289                .collect())
3290        }
3291    }
3292
3293    /// See [`Evaluator::evaluate_mpi`]. This method evaluates over a subset of events rather
3294    /// than all events in the total dataset.
3295    #[cfg(feature = "mpi")]
3296    fn evaluate_batch_mpi(
3297        &self,
3298        parameters: &[f64],
3299        indices: &[usize],
3300        world: &SimpleCommunicator,
3301    ) -> LadduResult<Vec<Complex64>> {
3302        let total = self.dataset.n_events();
3303        let locals = world.locals_from_globals(indices, total);
3304        let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
3305        Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
3306    }
3307
3308    /// Evaluate the stored [`Expression`] over a subset of events in the [`Dataset`] stored by the
3309    /// [`Evaluator`] with the given values for free parameters. See also [`Evaluator::evaluate`].
3310    pub fn evaluate_batch(
3311        &self,
3312        parameters: &[f64],
3313        indices: &[usize],
3314    ) -> LadduResult<Vec<Complex64>> {
3315        #[cfg(feature = "mpi")]
3316        {
3317            if let Some(world) = crate::mpi::get_world() {
3318                return self.evaluate_batch_mpi(parameters, indices, &world);
3319            }
3320        }
3321        self.evaluate_batch_local(parameters, indices)
3322    }
3323
3324    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3325    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
3326    ///
3327    /// # Notes
3328    ///
3329    /// This method is not intended to be called in analyses but rather in writing methods
3330    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
3331    pub fn evaluate_gradient_local(
3332        &self,
3333        parameters: &[f64],
3334    ) -> LadduResult<Vec<DVector<Complex64>>> {
3335        let resources = self.resources.read();
3336        let parameters = resources.parameter_map.assemble(parameters)?;
3337        let amplitude_len = self.amplitude_use_sites.len();
3338        let grad_dim = parameters.len();
3339        let active_indices = resources.active_indices().to_vec();
3340        let lowered_runtime = self.lowered_runtime();
3341        let gradient_program = lowered_runtime.gradient_program();
3342        let slot_count = self.expression_gradient_slot_count();
3343        #[cfg(feature = "rayon")]
3344        {
3345            Ok(resources
3346                .caches
3347                .par_iter()
3348                .map_init(
3349                    || {
3350                        (
3351                            vec![Complex64::ZERO; amplitude_len],
3352                            vec![DVector::zeros(grad_dim); amplitude_len],
3353                            vec![Complex64::ZERO; slot_count],
3354                            vec![Complex64::ZERO; slot_count * grad_dim],
3355                        )
3356                    },
3357                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3358                        self.fill_amplitude_values_and_gradients(
3359                            amplitude_values,
3360                            gradient_values,
3361                            &active_indices,
3362                            &resources.active,
3363                            &parameters,
3364                            cache,
3365                        );
3366                        gradient_program.evaluate_gradient_into_flat(
3367                            amplitude_values,
3368                            gradient_values,
3369                            value_slots,
3370                            gradient_slots,
3371                            grad_dim,
3372                        )
3373                    },
3374                )
3375                .collect())
3376        }
3377        #[cfg(not(feature = "rayon"))]
3378        {
3379            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3380            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3381            let mut value_slots = vec![Complex64::ZERO; slot_count];
3382            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3383            Ok(resources
3384                .caches
3385                .iter()
3386                .map(|cache| {
3387                    self.fill_amplitude_values_and_gradients(
3388                        &mut amplitude_values,
3389                        &mut gradient_values,
3390                        &active_indices,
3391                        &resources.active,
3392                        &parameters,
3393                        cache,
3394                    );
3395                    gradient_program.evaluate_gradient_into_flat(
3396                        &amplitude_values,
3397                        &gradient_values,
3398                        &mut value_slots,
3399                        &mut gradient_slots,
3400                        grad_dim,
3401                    )
3402                })
3403                .collect())
3404        }
3405    }
3406
3407    /// Evaluate the gradient over local events using a reusable execution context.
3408    #[cfg(feature = "execution-context-prototype")]
3409    pub fn evaluate_gradient_local_with_ctx(
3410        &self,
3411        parameters: &[f64],
3412        execution_context: &ExecutionContext,
3413    ) -> Vec<DVector<Complex64>> {
3414        let resources = self.resources.read();
3415        let parameters = resources
3416            .parameter_map
3417            .assemble(parameters)
3418            .expect("parameter slice must match evaluator resources");
3419        let amplitude_len = self.amplitude_use_sites.len();
3420        let grad_dim = parameters.len();
3421        let active_indices = resources.active_indices().to_vec();
3422        let slot_count = self.expression_slot_count();
3423        #[cfg(feature = "rayon")]
3424        {
3425            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3426                return execution_context.install(|| {
3427                    resources
3428                        .caches
3429                        .par_iter()
3430                        .map_init(
3431                            || {
3432                                (
3433                                    vec![Complex64::ZERO; amplitude_len],
3434                                    vec![DVector::zeros(grad_dim); amplitude_len],
3435                                    vec![Complex64::ZERO; slot_count],
3436                                    vec![DVector::zeros(grad_dim); slot_count],
3437                                )
3438                            },
3439                            |(amplitude_values, gradient_values, value_slots, gradient_slots),
3440                             cache| {
3441                                self.evaluate_cache_gradient_with_scratch(
3442                                    amplitude_values,
3443                                    gradient_values,
3444                                    value_slots,
3445                                    gradient_slots,
3446                                    &active_indices,
3447                                    &resources.active,
3448                                    &parameters,
3449                                    cache,
3450                                )
3451                            },
3452                        )
3453                        .collect()
3454                });
3455            }
3456        }
3457        execution_context.with_scratch(|scratch| {
3458            let (amplitude_values, value_slots, gradient_values, gradient_slots) =
3459                scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
3460            resources
3461                .caches
3462                .iter()
3463                .map(|cache| {
3464                    self.evaluate_cache_gradient_with_scratch(
3465                        amplitude_values,
3466                        gradient_values,
3467                        value_slots,
3468                        gradient_slots,
3469                        &active_indices,
3470                        &resources.active,
3471                        &parameters,
3472                        cache,
3473                    )
3474                })
3475                .collect()
3476        })
3477    }
3478
3479    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3480    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
3481    ///
3482    /// # Notes
3483    ///
3484    /// This method is not intended to be called in analyses but rather in writing methods
3485    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
3486    #[cfg(feature = "mpi")]
3487    fn evaluate_gradient_mpi(
3488        &self,
3489        parameters: &[f64],
3490        world: &SimpleCommunicator,
3491    ) -> LadduResult<Vec<DVector<Complex64>>> {
3492        let local_evaluation = self.evaluate_gradient_local(parameters)?;
3493        let n_events = self.dataset.n_events();
3494        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
3495        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
3496        {
3497            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
3498            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
3499            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3500            world.all_gather_varcount_into(
3501                &local_evaluation
3502                    .iter()
3503                    .flat_map(|v| v.data.as_vec())
3504                    .copied()
3505                    .collect::<Vec<_>>(),
3506                &mut partitioned_buffer,
3507            );
3508        }
3509        Ok(buffer
3510            .chunks(parameters.len())
3511            .map(DVector::from_row_slice)
3512            .collect())
3513    }
3514
3515    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3516    fn evaluate_gradient_mpi_with_ctx(
3517        &self,
3518        parameters: &[f64],
3519        world: &SimpleCommunicator,
3520        execution_context: &ExecutionContext,
3521    ) -> Vec<DVector<Complex64>> {
3522        let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
3523        let n_events = self.dataset.n_events();
3524        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
3525        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
3526        {
3527            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
3528            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
3529            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3530            world.all_gather_varcount_into(
3531                &local_evaluation
3532                    .iter()
3533                    .flat_map(|v| v.data.as_vec())
3534                    .copied()
3535                    .collect::<Vec<_>>(),
3536                &mut partitioned_buffer,
3537            );
3538        }
3539        buffer
3540            .chunks(parameters.len())
3541            .map(DVector::from_row_slice)
3542            .collect()
3543    }
3544
3545    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3546    /// [`Evaluator`] with the given values for free parameters.
3547    pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
3548        #[cfg(feature = "mpi")]
3549        {
3550            if let Some(world) = crate::mpi::get_world() {
3551                return self.evaluate_gradient_mpi(parameters, &world);
3552            }
3553        }
3554        self.evaluate_gradient_local(parameters)
3555    }
3556
3557    /// Evaluate the gradient with a reusable execution context.
3558    ///
3559    /// This is intended for repeated calls with the same context instance.
3560    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
3561    /// [`ExecutionContext`](crate::ExecutionContext).
3562    #[cfg(feature = "execution-context-prototype")]
3563    pub fn evaluate_gradient_with_ctx(
3564        &self,
3565        parameters: &[f64],
3566        execution_context: &ExecutionContext,
3567    ) -> Vec<DVector<Complex64>> {
3568        #[cfg(feature = "mpi")]
3569        {
3570            if let Some(world) = crate::mpi::get_world() {
3571                return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
3572            }
3573        }
3574        self.evaluate_gradient_local_with_ctx(parameters, execution_context)
3575    }
3576
3577    /// See [`Evaluator::evaluate_gradient_local`]. This method evaluates over a subset
3578    /// of events rather than all events in the total dataset.
3579    pub fn evaluate_gradient_batch_local(
3580        &self,
3581        parameters: &[f64],
3582        indices: &[usize],
3583    ) -> LadduResult<Vec<DVector<Complex64>>> {
3584        let resources = self.resources.read();
3585        let parameters = resources.parameter_map.assemble(parameters)?;
3586        let amplitude_len = self.amplitude_use_sites.len();
3587        let grad_dim = parameters.len();
3588        let active_indices = resources.active_indices().to_vec();
3589        let lowered_runtime = self.lowered_runtime();
3590        let gradient_program = lowered_runtime.gradient_program();
3591        let slot_count = self.expression_gradient_slot_count();
3592        #[cfg(feature = "rayon")]
3593        {
3594            Ok(indices
3595                .par_iter()
3596                .map_init(
3597                    || {
3598                        (
3599                            vec![Complex64::ZERO; amplitude_len],
3600                            vec![DVector::zeros(grad_dim); amplitude_len],
3601                            vec![Complex64::ZERO; slot_count],
3602                            vec![Complex64::ZERO; slot_count * grad_dim],
3603                        )
3604                    },
3605                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3606                        let cache = &resources.caches[idx];
3607                        self.fill_amplitude_values_and_gradients(
3608                            amplitude_values,
3609                            gradient_values,
3610                            &active_indices,
3611                            &resources.active,
3612                            &parameters,
3613                            cache,
3614                        );
3615                        gradient_program.evaluate_gradient_into_flat(
3616                            amplitude_values,
3617                            gradient_values,
3618                            value_slots,
3619                            gradient_slots,
3620                            grad_dim,
3621                        )
3622                    },
3623                )
3624                .collect())
3625        }
3626        #[cfg(not(feature = "rayon"))]
3627        {
3628            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3629            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3630            let mut value_slots = vec![Complex64::ZERO; slot_count];
3631            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3632            Ok(indices
3633                .iter()
3634                .map(|&idx| {
3635                    let cache = &resources.caches[idx];
3636                    self.fill_amplitude_values_and_gradients(
3637                        &mut amplitude_values,
3638                        &mut gradient_values,
3639                        &active_indices,
3640                        &resources.active,
3641                        &parameters,
3642                        cache,
3643                    );
3644                    gradient_program.evaluate_gradient_into_flat(
3645                        &amplitude_values,
3646                        &gradient_values,
3647                        &mut value_slots,
3648                        &mut gradient_slots,
3649                        grad_dim,
3650                    )
3651                })
3652                .collect())
3653        }
3654    }
3655
3656    /// See [`Evaluator::evaluate_gradient_mpi`]. This method evaluates over a subset
3657    /// of events rather than all events in the total dataset.
3658    #[cfg(feature = "mpi")]
3659    fn evaluate_gradient_batch_mpi(
3660        &self,
3661        parameters: &[f64],
3662        indices: &[usize],
3663        world: &SimpleCommunicator,
3664    ) -> LadduResult<Vec<DVector<Complex64>>> {
3665        let total = self.dataset.n_events();
3666        let locals = world.locals_from_globals(indices, total);
3667        let flattened_local_evaluation = self
3668            .evaluate_gradient_batch_local(parameters, &locals)?
3669            .iter()
3670            .flat_map(|g| g.data.as_vec().to_vec())
3671            .collect::<Vec<Complex64>>();
3672        Ok(world
3673            .all_gather_batched_partitioned(
3674                &flattened_local_evaluation,
3675                indices,
3676                total,
3677                Some(parameters.len()),
3678            )
3679            .chunks(parameters.len())
3680            .map(DVector::from_row_slice)
3681            .collect())
3682    }
3683
3684    /// Evaluate the gradient of the stored [`Expression`] over a subset of the
3685    /// events in the [`Dataset`] stored by the [`Evaluator`] with the given values
3686    /// for free parameters. See also [`Evaluator::evaluate_gradient`].
3687    pub fn evaluate_gradient_batch(
3688        &self,
3689        parameters: &[f64],
3690        indices: &[usize],
3691    ) -> LadduResult<Vec<DVector<Complex64>>> {
3692        #[cfg(feature = "mpi")]
3693        {
3694            if let Some(world) = crate::mpi::get_world() {
3695                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
3696            }
3697        }
3698        self.evaluate_gradient_batch_local(parameters, indices)
3699    }
3700
3701    /// Evaluate the stored expression and its gradient over local events in one fused pass.
3702    pub fn evaluate_with_gradient_local(
3703        &self,
3704        parameters: &[f64],
3705    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3706        let resources = self.resources.read();
3707        let parameters = resources.parameter_map.assemble(parameters)?;
3708        let amplitude_len = self.amplitude_use_sites.len();
3709        let grad_dim = parameters.len();
3710        let active_indices = resources.active_indices().to_vec();
3711        let lowered_runtime = self.lowered_runtime();
3712        let value_gradient_program = lowered_runtime.value_gradient_program();
3713        let slot_count = self.expression_value_gradient_slot_count();
3714        #[cfg(feature = "rayon")]
3715        {
3716            Ok(resources
3717                .caches
3718                .par_iter()
3719                .map_init(
3720                    || {
3721                        (
3722                            vec![Complex64::ZERO; amplitude_len],
3723                            vec![DVector::zeros(grad_dim); amplitude_len],
3724                            vec![Complex64::ZERO; slot_count],
3725                            vec![Complex64::ZERO; slot_count * grad_dim],
3726                        )
3727                    },
3728                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3729                        self.fill_amplitude_values_and_gradients(
3730                            amplitude_values,
3731                            gradient_values,
3732                            &active_indices,
3733                            &resources.active,
3734                            &parameters,
3735                            cache,
3736                        );
3737                        value_gradient_program.evaluate_value_gradient_into_flat(
3738                            amplitude_values,
3739                            gradient_values,
3740                            value_slots,
3741                            gradient_slots,
3742                            grad_dim,
3743                        )
3744                    },
3745                )
3746                .collect())
3747        }
3748        #[cfg(not(feature = "rayon"))]
3749        {
3750            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3751            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3752            let mut value_slots = vec![Complex64::ZERO; slot_count];
3753            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3754            Ok(resources
3755                .caches
3756                .iter()
3757                .map(|cache| {
3758                    self.fill_amplitude_values_and_gradients(
3759                        &mut amplitude_values,
3760                        &mut gradient_values,
3761                        &active_indices,
3762                        &resources.active,
3763                        &parameters,
3764                        cache,
3765                    );
3766                    value_gradient_program.evaluate_value_gradient_into_flat(
3767                        &amplitude_values,
3768                        &gradient_values,
3769                        &mut value_slots,
3770                        &mut gradient_slots,
3771                        grad_dim,
3772                    )
3773                })
3774                .collect())
3775        }
3776    }
3777
3778    /// Evaluate local events and gradients with an explicit active-amplitude mask without mutating evaluator state.
3779    pub fn evaluate_with_gradient_local_with_active_mask(
3780        &self,
3781        parameters: &[f64],
3782        active_mask: &[bool],
3783    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3784        let resources = self.resources.read();
3785        if active_mask.len() != resources.active.len() {
3786            return Err(LadduError::LengthMismatch {
3787                context: "active amplitude mask".to_string(),
3788                expected: resources.active.len(),
3789                actual: active_mask.len(),
3790            });
3791        }
3792        let parameters = resources.parameter_map.assemble(parameters)?;
3793        let amplitude_len = self.amplitude_use_sites.len();
3794        let grad_dim = parameters.len();
3795        let active_indices = active_mask
3796            .iter()
3797            .enumerate()
3798            .filter_map(|(index, &active)| if active { Some(index) } else { None })
3799            .collect::<Vec<_>>();
3800        let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
3801        let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
3802        #[cfg(feature = "rayon")]
3803        {
3804            Ok(resources
3805                .caches
3806                .par_iter()
3807                .map_init(
3808                    || {
3809                        (
3810                            vec![Complex64::ZERO; amplitude_len],
3811                            vec![DVector::zeros(grad_dim); amplitude_len],
3812                            vec![Complex64::ZERO; slot_count],
3813                            vec![Complex64::ZERO; slot_count * grad_dim],
3814                        )
3815                    },
3816                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3817                        self.fill_amplitude_values_and_gradients(
3818                            amplitude_values,
3819                            gradient_values,
3820                            &active_indices,
3821                            active_mask,
3822                            &parameters,
3823                            cache,
3824                        );
3825                        lowered_runtime
3826                            .value_gradient_program()
3827                            .evaluate_value_gradient_into_flat(
3828                                amplitude_values,
3829                                gradient_values,
3830                                value_slots,
3831                                gradient_slots,
3832                                grad_dim,
3833                            )
3834                    },
3835                )
3836                .collect())
3837        }
3838        #[cfg(not(feature = "rayon"))]
3839        {
3840            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3841            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3842            let mut value_slots = vec![Complex64::ZERO; slot_count];
3843            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3844            Ok(resources
3845                .caches
3846                .iter()
3847                .map(|cache| {
3848                    self.fill_amplitude_values_and_gradients(
3849                        &mut amplitude_values,
3850                        &mut gradient_values,
3851                        &active_indices,
3852                        active_mask,
3853                        &parameters,
3854                        cache,
3855                    );
3856                    lowered_runtime
3857                        .value_gradient_program()
3858                        .evaluate_value_gradient_into_flat(
3859                            &amplitude_values,
3860                            &gradient_values,
3861                            &mut value_slots,
3862                            &mut gradient_slots,
3863                            grad_dim,
3864                        )
3865                })
3866                .collect())
3867        }
3868    }
3869
3870    /// Evaluate the stored expression and its gradient over a local subset of events in one fused pass.
3871    pub fn evaluate_with_gradient_batch_local(
3872        &self,
3873        parameters: &[f64],
3874        indices: &[usize],
3875    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3876        let resources = self.resources.read();
3877        let parameters = resources.parameter_map.assemble(parameters)?;
3878        let amplitude_len = self.amplitude_use_sites.len();
3879        let grad_dim = parameters.len();
3880        let active_indices = resources.active_indices().to_vec();
3881        let lowered_runtime = self.lowered_runtime();
3882        let value_gradient_program = lowered_runtime.value_gradient_program();
3883        let slot_count = self.expression_value_gradient_slot_count();
3884        #[cfg(feature = "rayon")]
3885        {
3886            Ok(indices
3887                .par_iter()
3888                .map_init(
3889                    || {
3890                        (
3891                            vec![Complex64::ZERO; amplitude_len],
3892                            vec![DVector::zeros(grad_dim); amplitude_len],
3893                            vec![Complex64::ZERO; slot_count],
3894                            vec![Complex64::ZERO; slot_count * grad_dim],
3895                        )
3896                    },
3897                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3898                        let cache = &resources.caches[idx];
3899                        self.fill_amplitude_values_and_gradients(
3900                            amplitude_values,
3901                            gradient_values,
3902                            &active_indices,
3903                            &resources.active,
3904                            &parameters,
3905                            cache,
3906                        );
3907                        value_gradient_program.evaluate_value_gradient_into_flat(
3908                            amplitude_values,
3909                            gradient_values,
3910                            value_slots,
3911                            gradient_slots,
3912                            grad_dim,
3913                        )
3914                    },
3915                )
3916                .collect())
3917        }
3918        #[cfg(not(feature = "rayon"))]
3919        {
3920            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3921            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3922            let mut value_slots = vec![Complex64::ZERO; slot_count];
3923            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3924            Ok(indices
3925                .iter()
3926                .map(|&idx| {
3927                    let cache = &resources.caches[idx];
3928                    self.fill_amplitude_values_and_gradients(
3929                        &mut amplitude_values,
3930                        &mut gradient_values,
3931                        &active_indices,
3932                        &resources.active,
3933                        &parameters,
3934                        cache,
3935                    );
3936                    value_gradient_program.evaluate_value_gradient_into_flat(
3937                        &amplitude_values,
3938                        &gradient_values,
3939                        &mut value_slots,
3940                        &mut gradient_slots,
3941                        grad_dim,
3942                    )
3943                })
3944                .collect())
3945        }
3946    }
3947}
3948
3949#[cfg(test)]
3950mod tests {
3951    use approx::assert_relative_eq;
3952    #[cfg(feature = "mpi")]
3953    use mpi_test::mpi_test;
3954    use serde::{Deserialize, Serialize};
3955
3956    use super::*;
3957    use crate::{
3958        amplitude::{AmplitudeID, Tags, TestAmplitude},
3959        data::{test_dataset, test_event, DatasetMetadata, Event, EventData},
3960        parameter,
3961        parameters::Parameter,
3962        resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
3963        vectors::Vec4,
3964    };
3965
3966    #[derive(Clone, Serialize, Deserialize)]
3967    pub struct ComplexScalar {
3968        name: String,
3969        re: Parameter,
3970        pid_re: ParameterID,
3971        im: Parameter,
3972        pid_im: ParameterID,
3973    }
3974
3975    impl ComplexScalar {
3976        #[allow(clippy::new_ret_no_self)]
3977        pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
3978            Self {
3979                name: name.to_string(),
3980                re,
3981                pid_re: Default::default(),
3982                im,
3983                pid_im: Default::default(),
3984            }
3985            .into_expression()
3986        }
3987    }
3988
3989    #[typetag::serde]
3990    impl Amplitude for ComplexScalar {
3991        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
3992            self.pid_re = resources.register_parameter(&self.re)?;
3993            self.pid_im = resources.register_parameter(&self.im)?;
3994            resources.register_amplitude(&self.name)
3995        }
3996
3997        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
3998            Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
3999        }
4000
4001        fn compute_gradient(
4002            &self,
4003            parameters: &Parameters,
4004            _cache: &Cache,
4005            gradient: &mut DVector<Complex64>,
4006        ) {
4007            if let Some(ind) = parameters.free_index(self.pid_re) {
4008                gradient[ind] = Complex64::ONE;
4009            }
4010            if let Some(ind) = parameters.free_index(self.pid_im) {
4011                gradient[ind] = Complex64::I;
4012            }
4013        }
4014    }
4015
4016    #[derive(Clone, Serialize, Deserialize)]
4017    pub struct ParameterOnlyScalar {
4018        name: String,
4019        value: Parameter,
4020        pid: ParameterID,
4021    }
4022
4023    impl ParameterOnlyScalar {
4024        #[allow(clippy::new_ret_no_self)]
4025        pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
4026            Self {
4027                name: name.to_string(),
4028                value,
4029                pid: Default::default(),
4030            }
4031            .into_expression()
4032        }
4033    }
4034
4035    #[typetag::serde]
4036    impl Amplitude for ParameterOnlyScalar {
4037        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4038            self.pid = resources.register_parameter(&self.value)?;
4039            resources.register_amplitude(&self.name)
4040        }
4041
4042        fn dependence_hint(&self) -> ExpressionDependence {
4043            ExpressionDependence::ParameterOnly
4044        }
4045
4046        fn real_valued_hint(&self) -> bool {
4047            true
4048        }
4049
4050        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4051            Complex64::new(parameters.get(self.pid), 0.0)
4052        }
4053    }
4054
4055    #[derive(Clone, Serialize, Deserialize)]
4056    pub struct CacheOnlyScalar {
4057        name: String,
4058        beam_energy: ScalarID,
4059    }
4060
4061    impl CacheOnlyScalar {
4062        #[allow(clippy::new_ret_no_self)]
4063        pub fn new(name: &str) -> LadduResult<Expression> {
4064            Self {
4065                name: name.to_string(),
4066                beam_energy: Default::default(),
4067            }
4068            .into_expression()
4069        }
4070    }
4071
4072    #[typetag::serde]
4073    impl Amplitude for CacheOnlyScalar {
4074        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4075            self.beam_energy =
4076                resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
4077            resources.register_amplitude(&self.name)
4078        }
4079
4080        fn dependence_hint(&self) -> ExpressionDependence {
4081            ExpressionDependence::CacheOnly
4082        }
4083
4084        fn real_valued_hint(&self) -> bool {
4085            true
4086        }
4087
4088        fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
4089            cache.store_scalar(self.beam_energy, event.p4_at(0).e());
4090        }
4091
4092        fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
4093            Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
4094        }
4095    }
4096
    /// Selects which deterministic test model `make_deterministic_fixture` builds.
    #[derive(Clone, Copy)]
    enum DeterministicFixtureKind {
        /// `p1 * c1 + p2 * c2`: sums of parameter-only times cache-only scalars.
        Separable,
        /// `p * c + m`: one parameter-only × cache-only product plus a `TestAmplitude`.
        Partial,
        /// `m1 * m2`: product of two `TestAmplitude`s.
        NonSeparable,
    }

    /// A ready-to-load test model: its expression, the dataset to load it on, and the
    /// parameter values to evaluate it with.
    struct DeterministicFixture {
        expression: Expression,
        dataset: Arc<Dataset>,
        parameters: Vec<f64>,
    }

    /// Absolute tolerance for comparisons between evaluation paths expected to agree up to rounding.
    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
    /// Relative tolerance paired with `DETERMINISTIC_STRICT_ABS_TOL`.
    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
4112
4113    fn deterministic_fixture_dataset() -> Arc<Dataset> {
4114        let metadata = Arc::new(DatasetMetadata::default());
4115        let events = vec![
4116            Arc::new(EventData {
4117                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
4118                aux: vec![],
4119                weight: 0.5,
4120            }),
4121            Arc::new(EventData {
4122                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
4123                aux: vec![],
4124                weight: -1.25,
4125            }),
4126            Arc::new(EventData {
4127                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
4128                aux: vec![],
4129                weight: 2.0,
4130            }),
4131            Arc::new(EventData {
4132                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
4133                aux: vec![],
4134                weight: 0.75,
4135            }),
4136        ];
4137        Arc::new(Dataset::new_with_metadata(events, metadata))
4138    }
4139
4140    fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
4141        let dataset = deterministic_fixture_dataset();
4142        match kind {
4143            DeterministicFixtureKind::Separable => {
4144                let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
4145                    .expect("separable p1 should build");
4146                let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
4147                    .expect("separable p2 should build");
4148                let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
4149                let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
4150                DeterministicFixture {
4151                    expression: (&p1 * &c1) + &(&p2 * &c2),
4152                    dataset,
4153                    parameters: vec![0.4, -0.3],
4154                }
4155            }
4156            DeterministicFixtureKind::Partial => {
4157                let p =
4158                    ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
4159                let c = CacheOnlyScalar::new("c").expect("partial c should build");
4160                let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
4161                    .expect("partial m should build");
4162                DeterministicFixture {
4163                    expression: (&p * &c) + &m,
4164                    dataset,
4165                    parameters: vec![0.55, 0.2, -0.15],
4166                }
4167            }
4168            DeterministicFixtureKind::NonSeparable => {
4169                let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
4170                    .expect("non-separable m1 should build");
4171                let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
4172                    .expect("non-separable m2 should build");
4173                DeterministicFixture {
4174                    expression: &m1 * &m2,
4175                    dataset,
4176                    parameters: vec![0.25, -0.4, 0.6, 0.1],
4177                }
4178            }
4179        }
4180    }
4181
4182    fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
4183        let evaluator = fixture
4184            .expression
4185            .load(&fixture.dataset)
4186            .expect("fixture evaluator should load");
4187        let expected_value = evaluator
4188            .evaluate_local(&fixture.parameters)
4189            .expect("evaluation should succeed")
4190            .iter()
4191            .zip(fixture.dataset.weights_local().iter())
4192            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
4193        let expected_gradient = evaluator
4194            .evaluate_gradient_local(&fixture.parameters)
4195            .expect("evaluation should succeed")
4196            .iter()
4197            .zip(fixture.dataset.weights_local().iter())
4198            .fold(
4199                DVector::zeros(fixture.parameters.len()),
4200                |mut accum, (gradient, event)| {
4201                    accum += gradient.map(|value| value.re).scale(*event);
4202                    accum
4203                },
4204            );
4205        let actual_value = evaluator
4206            .evaluate_weighted_value_sum_local(&fixture.parameters)
4207            .expect("evaluation should succeed");
4208        let actual_gradient = evaluator
4209            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4210            .expect("evaluation should succeed");
4211        assert_relative_eq!(
4212            actual_value,
4213            expected_value,
4214            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4215            max_relative = DETERMINISTIC_STRICT_REL_TOL
4216        );
4217        for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
4218            assert_relative_eq!(
4219                *actual_item,
4220                *expected_item,
4221                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4222                max_relative = DETERMINISTIC_STRICT_REL_TOL
4223            );
4224        }
4225    }
4226    fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
4227        let evaluator = fixture
4228            .expression
4229            .load(&fixture.dataset)
4230            .expect("fixture evaluator should load");
4231        let state = {
4232            let resources = evaluator.resources.read();
4233            evaluator.ensure_cached_integral_cache_state(&resources)
4234        }
4235        .expect("state should be available");
4236        assert!(
4237            !state.values.is_empty(),
4238            "fixture should exercise cached normalization terms"
4239        );
4240        assert!(
4241            !state.execution_sets.residual_amplitudes.is_empty(),
4242            "fixture should exercise residual normalization amplitudes"
4243        );
4244
4245        let (residual_value_sum, cached_value_sum) = evaluator
4246            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4247            .expect("evaluation should succeed");
4248        assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4249        assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4250        let combined_value = evaluator
4251            .evaluate_weighted_value_sum_local(&fixture.parameters)
4252            .expect("evaluation should succeed");
4253        assert_relative_eq!(
4254            residual_value_sum + cached_value_sum,
4255            combined_value,
4256            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4257            max_relative = DETERMINISTIC_STRICT_REL_TOL
4258        );
4259
4260        let (residual_gradient_sum, cached_gradient_sum) = evaluator
4261            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4262            .expect("evaluation should succeed");
4263        let combined_gradient = evaluator
4264            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4265            .expect("evaluation should succeed");
4266        assert!(residual_gradient_sum
4267            .iter()
4268            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4269        assert!(cached_gradient_sum
4270            .iter()
4271            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4272        for ((residual_item, cached_item), combined_item) in residual_gradient_sum
4273            .iter()
4274            .zip(cached_gradient_sum.iter())
4275            .zip(combined_gradient.iter())
4276        {
4277            assert_relative_eq!(
4278                residual_item + cached_item,
4279                *combined_item,
4280                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4281                max_relative = DETERMINISTIC_STRICT_REL_TOL
4282            );
4283        }
4284    }
4285
4286    #[test]
4287    fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
4288        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4289        let evaluator = fixture
4290            .expression
4291            .load(&fixture.dataset)
4292            .expect("fixture evaluator should load");
4293        let original_mask = evaluator.active_mask();
4294
4295        let original_value = evaluator
4296            .evaluate_weighted_value_sum_local(&fixture.parameters)
4297            .expect("evaluation should succeed");
4298
4299        evaluator.isolate_many(&["p", "c"]);
4300        assert_ne!(evaluator.active_mask(), original_mask);
4301
4302        evaluator
4303            .set_active_mask(&original_mask)
4304            .expect("original fixture active mask should restore");
4305        assert_eq!(evaluator.active_mask(), original_mask);
4306        let actual_value = evaluator
4307            .evaluate_weighted_value_sum_local(&fixture.parameters)
4308            .expect("evaluation should succeed");
4309        assert_relative_eq!(
4310            actual_value,
4311            original_value,
4312            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4313            max_relative = DETERMINISTIC_STRICT_REL_TOL
4314        );
4315    }
4316
4317    #[test]
4318    fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
4319        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4320        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4321        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4322
4323        assert_weighted_sum_matches_eventwise_baseline(&separable);
4324        assert_weighted_sum_matches_eventwise_baseline(&partial);
4325        assert_weighted_sum_matches_eventwise_baseline(&non_separable);
4326    }
4327    #[test]
4328    fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
4329        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4330        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4331        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4332
4333        let separable_evaluator = separable
4334            .expression
4335            .load(&separable.dataset)
4336            .expect("separable evaluator should load");
4337        let partial_evaluator = partial
4338            .expression
4339            .load(&partial.dataset)
4340            .expect("partial evaluator should load");
4341        let non_separable_evaluator = non_separable
4342            .expression
4343            .load(&non_separable.dataset)
4344            .expect("non-separable evaluator should load");
4345
4346        assert_eq!(
4347            separable_evaluator
4348                .expression_precomputed_cached_integrals()
4349                .expect("integrals should be computed")
4350                .len(),
4351            2
4352        );
4353        assert_eq!(
4354            partial_evaluator
4355                .expression_precomputed_cached_integrals()
4356                .expect("integrals should be computed")
4357                .len(),
4358            1
4359        );
4360        assert!(non_separable_evaluator
4361            .expression_precomputed_cached_integrals()
4362            .expect("integrals should be computed")
4363            .is_empty());
4364    }
4365    #[test]
4366    fn test_partial_fixture_combined_normalization_components_match_total() {
4367        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4368        assert_mixed_normalization_components_match_combined_path(&partial);
4369    }
4370    #[test]
4371    fn test_non_separable_fixture_normalization_components_stay_residual_only() {
4372        let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4373        let evaluator = fixture
4374            .expression
4375            .load(&fixture.dataset)
4376            .expect("fixture evaluator should load");
4377        let resources = evaluator.resources.read();
4378        let state = evaluator
4379            .ensure_cached_integral_cache_state(&resources)
4380            .expect("state should be available");
4381        assert!(state.values.is_empty());
4382
4383        let (residual_value_sum, cached_value_sum) = evaluator
4384            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4385            .expect("evaluation should succeed");
4386        assert_relative_eq!(
4387            cached_value_sum,
4388            0.0,
4389            epsilon = DETERMINISTIC_STRICT_ABS_TOL
4390        );
4391        assert_relative_eq!(
4392            residual_value_sum,
4393            evaluator
4394                .evaluate_weighted_value_sum_local(&fixture.parameters)
4395                .expect("evaluation should succeed"),
4396            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4397            max_relative = DETERMINISTIC_STRICT_REL_TOL
4398        );
4399
4400        let (residual_gradient_sum, cached_gradient_sum) = evaluator
4401            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4402            .expect("evaluation should succeed");
4403        assert!(cached_gradient_sum
4404            .iter()
4405            .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
4406        let combined_gradient = evaluator
4407            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4408            .expect("evaluation should succeed");
4409        for (residual_item, combined_item) in
4410            residual_gradient_sum.iter().zip(combined_gradient.iter())
4411        {
4412            assert_relative_eq!(
4413                *residual_item,
4414                *combined_item,
4415                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4416                max_relative = DETERMINISTIC_STRICT_REL_TOL
4417            );
4418        }
4419    }
4420
4421    #[test]
4422    fn test_batch_evaluation() {
4423        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
4424        let mut event1 = test_event();
4425        event1.p4s[0].t = 10.0;
4426        let mut event2 = test_event();
4427        event2.p4s[0].t = 11.0;
4428        let mut event3 = test_event();
4429        event3.p4s[0].t = 12.0;
4430        let dataset = Arc::new(Dataset::new_with_metadata(
4431            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
4432            Arc::new(DatasetMetadata::default()),
4433        ));
4434        let evaluator = expr.load(&dataset).unwrap();
4435        let result = evaluator
4436            .evaluate_batch(&[1.1, 2.2], &[0, 2])
4437            .expect("evaluation should succeed");
4438        assert_eq!(result.len(), 2);
4439        assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
4440        assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
4441        let result_grad = evaluator
4442            .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
4443            .expect("evaluation should succeed");
4444        assert_eq!(result_grad.len(), 2);
4445        assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
4446        assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
4447        assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
4448        assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
4449    }
4450
4451    #[test]
4452    fn test_load_compiles_expression_ir_once() {
4453        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4454            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4455        .norm_sqr();
4456        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4457        let evaluator = expr.load(&dataset).unwrap();
4458        assert!(evaluator.expression_slot_count() > 0);
4459    }
4460    #[test]
4461    fn test_expression_ir_value_matches_lowered_runtime() {
4462        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4463            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4464            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4465        .conj()
4466        .norm_sqr();
4467        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4468        let evaluator = expr.load(&dataset).unwrap();
4469        let resources = evaluator.resources.read();
4470        let parameters = resources
4471            .parameter_map
4472            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4473            .expect("parameters should assemble");
4474        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4475        evaluator.fill_amplitude_values(
4476            &mut amplitude_values,
4477            resources.active_indices(),
4478            &parameters,
4479            &resources.caches[0],
4480        );
4481        let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4482        let lowered_runtime = evaluator.lowered_runtime();
4483        let lowered_program = lowered_runtime.value_program();
4484        let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4485        let lowered_value =
4486            evaluator.evaluate_expression_value_with_scratch(&amplitude_values, &mut ir_slots);
4487        let direct_lowered_value =
4488            lowered_program.evaluate_into(&amplitude_values, &mut lowered_slots);
4489        let ir_value = evaluator
4490            .expression_ir()
4491            .evaluate_into(&amplitude_values, &mut ir_slots);
4492        assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
4493        assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
4494        assert_relative_eq!(lowered_value.re, ir_value.re);
4495        assert_relative_eq!(lowered_value.im, ir_value.im);
4496    }
4497    #[test]
4498    fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
4499        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
4500            .unwrap()
4501            .norm_sqr();
4502        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4503        let evaluator = expr.load(&dataset).unwrap();
4504        let lowered_runtime = evaluator.lowered_runtime();
4505        assert_eq!(
4506            lowered_runtime.value_program().kind(),
4507            lowered::LoweredProgramKind::Value
4508        );
4509        assert_eq!(
4510            lowered_runtime.gradient_program().kind(),
4511            lowered::LoweredProgramKind::Gradient
4512        );
4513        assert_eq!(
4514            lowered_runtime.value_gradient_program().kind(),
4515            lowered::LoweredProgramKind::ValueGradient
4516        );
4517    }
4518    #[test]
4519    fn test_expression_ir_gradient_matches_lowered_runtime() {
4520        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4521            * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4522        .norm_sqr();
4523        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4524        let evaluator = expr.load(&dataset).unwrap();
4525        let resources = evaluator.resources.read();
4526        let parameters = resources
4527            .parameter_map
4528            .assemble(&[1.0, 0.25, -0.8, 0.5])
4529            .expect("parameters should assemble");
4530        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4531        evaluator.fill_amplitude_values(
4532            &mut amplitude_values,
4533            resources.active_indices(),
4534            &parameters,
4535            &resources.caches[0],
4536        );
4537        let mut active_mask = vec![false; evaluator.amplitudes.len()];
4538        for &index in resources.active_indices() {
4539            active_mask[index] = true;
4540        }
4541        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4542            .map(|_| DVector::zeros(parameters.len()))
4543            .collect::<Vec<_>>();
4544        evaluator.fill_amplitude_gradients(
4545            &mut amplitude_gradients,
4546            &active_mask,
4547            &parameters,
4548            &resources.caches[0],
4549        );
4550        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4551        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4552            (0..evaluator.expression_ir().node_count())
4553                .map(|_| DVector::zeros(parameters.len()))
4554                .collect();
4555        let lowered_runtime = evaluator.lowered_runtime();
4556        let lowered_program = lowered_runtime.gradient_program();
4557        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4558        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4559            .scratch_slots())
4560            .map(|_| DVector::zeros(parameters.len()))
4561            .collect();
4562        let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
4563            &amplitude_values,
4564            &amplitude_gradients,
4565            &mut ir_value_slots,
4566            &mut ir_gradient_slots,
4567        );
4568        let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
4569            &amplitude_values,
4570            &amplitude_gradients,
4571            &mut ir_value_slots,
4572            &mut ir_gradient_slots,
4573        );
4574        let lowered_gradient = lowered_program.evaluate_gradient_into(
4575            &amplitude_values,
4576            &amplitude_gradients,
4577            &mut lowered_value_slots,
4578            &mut lowered_gradient_slots,
4579        );
4580        for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
4581            assert_relative_eq!(active.re, lowered.re);
4582            assert_relative_eq!(active.im, lowered.im);
4583        }
4584        for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
4585            assert_relative_eq!(lowered.re, ir.re);
4586            assert_relative_eq!(lowered.im, ir.im);
4587        }
4588    }
4589    #[test]
4590    fn test_expression_ir_value_gradient_matches_lowered_runtime() {
4591        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4592            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4593            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4594        .norm_sqr();
4595        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4596        let evaluator = expr.load(&dataset).unwrap();
4597        let resources = evaluator.resources.read();
4598        let parameters = resources
4599            .parameter_map
4600            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4601            .expect("parameters should assemble");
4602        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4603        evaluator.fill_amplitude_values(
4604            &mut amplitude_values,
4605            resources.active_indices(),
4606            &parameters,
4607            &resources.caches[0],
4608        );
4609        let mut active_mask = vec![false; evaluator.amplitudes.len()];
4610        for &index in resources.active_indices() {
4611            active_mask[index] = true;
4612        }
4613        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4614            .map(|_| DVector::zeros(parameters.len()))
4615            .collect::<Vec<_>>();
4616        evaluator.fill_amplitude_gradients(
4617            &mut amplitude_gradients,
4618            &active_mask,
4619            &parameters,
4620            &resources.caches[0],
4621        );
4622        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4623        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4624            (0..evaluator.expression_ir().node_count())
4625                .map(|_| DVector::zeros(parameters.len()))
4626                .collect();
4627        let lowered_runtime = evaluator.lowered_runtime();
4628        let lowered_program = lowered_runtime.value_gradient_program();
4629        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4630        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4631            .scratch_slots())
4632            .map(|_| DVector::zeros(parameters.len()))
4633            .collect();
4634
4635        let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
4636            &amplitude_values,
4637            &amplitude_gradients,
4638            &mut ir_value_slots,
4639            &mut ir_gradient_slots,
4640        );
4641        let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
4642            &amplitude_values,
4643            &amplitude_gradients,
4644            &mut ir_value_slots,
4645            &mut ir_gradient_slots,
4646        );
4647        let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
4648            &amplitude_values,
4649            &amplitude_gradients,
4650            &mut lowered_value_slots,
4651            &mut lowered_gradient_slots,
4652        );
4653
4654        assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
4655        assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
4656        for (active, lowered) in active_value_gradient
4657            .1
4658            .iter()
4659            .zip(lowered_value_gradient.1.iter())
4660        {
4661            assert_relative_eq!(active.re, lowered.re);
4662            assert_relative_eq!(active.im, lowered.im);
4663        }
4664        assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
4665        assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
4666        for (lowered, ir) in lowered_value_gradient
4667            .1
4668            .iter()
4669            .zip(ir_value_gradient.1.iter())
4670        {
4671            assert_relative_eq!(lowered.re, ir.re);
4672            assert_relative_eq!(lowered.im, ir.im);
4673        }
4674    }
4675    #[test]
4676    fn test_expression_runtime_diagnostics_reports_lowered_programs() {
4677        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4678        let evaluator = fixture
4679            .expression
4680            .load(&fixture.dataset)
4681            .expect("fixture evaluator should load");
4682
4683        let diagnostics = evaluator.expression_runtime_diagnostics();
4684        assert!(diagnostics.ir_planning_enabled);
4685        assert!(diagnostics.lowered_value_program_present);
4686        assert!(diagnostics.lowered_gradient_program_present);
4687        assert!(diagnostics.lowered_value_gradient_program_present);
4688        assert!(diagnostics.residual_runtime_present);
4689        assert_eq!(
4690            diagnostics.specialization_status,
4691            Some(ExpressionSpecializationStatus {
4692                origin: ExpressionSpecializationOrigin::InitialLoad,
4693            })
4694        );
4695    }
4696    #[test]
4697    fn test_expression_runtime_diagnostics_reports_specialization_origin() {
4698        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4699        let evaluator = fixture
4700            .expression
4701            .load(&fixture.dataset)
4702            .expect("fixture evaluator should load");
4703
4704        assert_eq!(
4705            evaluator
4706                .expression_runtime_diagnostics()
4707                .specialization_status,
4708            Some(ExpressionSpecializationStatus {
4709                origin: ExpressionSpecializationOrigin::InitialLoad,
4710            })
4711        );
4712
4713        evaluator.isolate_many(&["p"]);
4714        assert_eq!(
4715            evaluator
4716                .expression_runtime_diagnostics()
4717                .specialization_status,
4718            Some(ExpressionSpecializationStatus {
4719                origin: ExpressionSpecializationOrigin::CacheMissRebuild,
4720            })
4721        );
4722    }
4723    #[test]
4724    fn test_compiled_expression_display_reports_dag_refs() {
4725        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4726        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4727        let term = &a * &b;
4728        let expr = &term + &term;
4729        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4730        let evaluator = expr.load(&dataset).unwrap();
4731
4732        let compiled = evaluator.compiled_expression();
4733        let display = compiled.to_string();
4734
4735        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4736        assert!(display.contains("#"));
4737        assert!(display.contains("+"));
4738        assert!(display.contains("×"));
4739        assert!(display.contains("a(id=0)"));
4740        assert!(display.contains("b(id=1)"));
4741        assert!(display.contains("(ref)"));
4742    }
4743
4744    #[test]
4745    fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
4746        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4747        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4748        let term = &a * &b;
4749        let expr = &term + &term;
4750
4751        let compiled = expr.compiled_expression();
4752        let display = compiled.to_string();
4753
4754        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4755        assert!(display.contains("#"));
4756        assert!(display.contains("+"));
4757        assert!(display.contains("×"));
4758        assert!(display.contains("a(id=0)"));
4759        assert!(display.contains("b(id=1)"));
4760        assert!(display.contains("(ref)"));
4761    }
4762
4763    #[test]
4764    fn test_compiled_expression_display_uses_current_active_mask() {
4765        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4766            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4767        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4768        let evaluator = expr.load(&dataset).unwrap();
4769        evaluator.deactivate("b");
4770
4771        let compiled = evaluator.compiled_expression().to_string();
4772
4773        assert!(compiled.contains("a(id=0)"));
4774        assert!(!compiled.contains("b(id=1)"));
4775        assert!(!compiled.contains("const 0"));
4776        assert!(!compiled.contains("+"));
4777    }
4778
4779    fn assert_compiled_single_amplitude(expr: &Expression, expected_label: &str) {
4780        let compiled = expr.compiled_expression();
4781        assert_eq!(compiled.nodes().len(), 1);
4782        assert_eq!(compiled.root(), 0);
4783        match &compiled.nodes()[0] {
4784            CompiledExpressionNode::Amplitude { index, name } => {
4785                assert_eq!(*index, 0);
4786                assert_eq!(name, expected_label);
4787            }
4788            node => panic!("expected one amplitude node, got {node:?}"),
4789        }
4790    }
4791
4792    fn assert_compiled_constant(expr: &Expression, expected: Complex64) {
4793        let compiled = expr.compiled_expression();
4794        assert_eq!(compiled.nodes().len(), 1);
4795        assert_eq!(compiled.root(), 0);
4796        match compiled.nodes()[0] {
4797            CompiledExpressionNode::Constant(value) => {
4798                assert_relative_eq!(value.re, expected.re);
4799                assert_relative_eq!(value.im, expected.im);
4800            }
4801            ref node => panic!("expected one constant node, got {node:?}"),
4802        }
4803    }
4804
4805    #[test]
4806    fn test_compiled_expression_simplifies_arithmetic_identities() {
4807        let amp = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4808        let zero = Expression::zero();
4809        let one = Expression::one();
4810
4811        assert_compiled_single_amplitude(&(&amp + &zero), "a");
4812        assert_compiled_single_amplitude(&(&zero + &amp), "a");
4813        assert_compiled_single_amplitude(&(&amp - &zero), "a");
4814        assert_compiled_single_amplitude(&(&amp * &one), "a");
4815        assert_compiled_single_amplitude(&(&one * &amp), "a");
4816        assert_compiled_single_amplitude(&(&amp / &one), "a");
4817        assert_compiled_single_amplitude(&amp.pow(&one), "a");
4818        assert_compiled_single_amplitude(&amp.powi(1), "a");
4819        assert_compiled_single_amplitude(&amp.powf(1.0), "a");
4820
4821        let times_zero = &amp * &zero;
4822        assert_compiled_constant(&times_zero, Complex64::ZERO);
4823        assert!(times_zero.parameters().contains_key("ar"));
4824        assert!(times_zero.parameters().contains_key("ai"));
4825
4826        assert_compiled_constant(&(&zero * &amp), Complex64::ZERO);
4827        assert_compiled_constant(&(&zero / &Expression::from(2.0)), Complex64::ZERO);
4828        assert_compiled_constant(&amp.powi(0), Complex64::ONE);
4829        assert_compiled_constant(
4830            &Expression::from(2.0).pow(&Expression::zero()),
4831            Complex64::ONE,
4832        );
4833        assert_compiled_constant(&Expression::from(2.0).powf(0.0), Complex64::ONE);
4834
4835        let unsafe_zero_division = (&zero / &amp).compiled_expression().to_string();
4836        assert!(unsafe_zero_division.contains("÷"));
4837        assert!(unsafe_zero_division.contains("a(id=0)"));
4838    }
4839
4840    #[test]
4841    fn test_compiled_expression_folds_unary_constant_functions() {
4842        assert_compiled_constant(&Expression::from(0.0).exp(), Complex64::ONE);
4843        assert_compiled_constant(&Expression::from(0.0).sin(), Complex64::ZERO);
4844        assert_compiled_constant(&Expression::from(0.0).cos(), Complex64::ONE);
4845        assert_compiled_constant(&Expression::from(1.0).log(), Complex64::ZERO);
4846        assert_compiled_constant(&Expression::from(4.0).sqrt(), Complex64::new(2.0, 0.0));
4847        assert_compiled_constant(&Expression::from(0.0).cis(), Complex64::ONE);
4848    }
4849
4850    #[test]
4851    fn test_evaluator_expression_reconstructs_expression() {
4852        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4853        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4854        let evaluator = expr.load(&dataset).unwrap();
4855
4856        assert_eq!(
4857            evaluator.expression().compiled_expression(),
4858            expr.compiled_expression()
4859        );
4860    }
4861
4862    #[test]
4863    fn test_active_mask_override_ignores_current_ir_specialization() {
4864        let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
4865            .unwrap()
4866            .norm_sqr();
4867        let dataset = Arc::new(test_dataset());
4868        let evaluator = expr.load(&dataset).unwrap();
4869        let params = vec![2.0];
4870
4871        evaluator.deactivate("amp");
4872        assert_eq!(
4873            evaluator
4874                .evaluate(&params)
4875                .expect("evaluation should succeed")[0],
4876            Complex64::new(0.0, 0.0)
4877        );
4878
4879        let overridden = evaluator
4880            .evaluate_local_with_active_mask(&params, &[true])
4881            .unwrap();
4882        assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
4883
4884        let overridden_fused = evaluator
4885            .evaluate_with_gradient_local_with_active_mask(&params, &[true])
4886            .unwrap();
4887        assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
4888        assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
4889    }
4890    #[test]
4891    fn test_expression_ir_dependence_diagnostics_surface() {
4892        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4893            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4894        .norm_sqr();
4895        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4896        let evaluator = expr.load(&dataset).unwrap();
4897        let annotations = evaluator
4898            .expression_node_dependence_annotations()
4899            .expect("annotations should exist");
4900        assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
4901        assert!(annotations
4902            .iter()
4903            .all(|dependence| *dependence == ExpressionDependence::Mixed));
4904        assert_eq!(
4905            evaluator
4906                .expression_root_dependence()
4907                .expect("root dependence should exist"),
4908            ExpressionDependence::Mixed
4909        );
4910    }
4911    #[test]
4912    fn test_expression_ir_default_dependence_hint_is_mixed() {
4913        let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
4914        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4915        let evaluator = expr.load(&dataset).unwrap();
4916        assert_eq!(
4917            evaluator
4918                .expression_root_dependence()
4919                .expect("root dependence should exist"),
4920            ExpressionDependence::Mixed
4921        );
4922    }
4923    #[test]
4924    fn test_expression_ir_parameter_only_dependence_hint_propagates() {
4925        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
4926        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4927        let evaluator = expr.load(&dataset).unwrap();
4928        assert_eq!(
4929            evaluator
4930                .expression_root_dependence()
4931                .expect("root dependence should exist"),
4932            ExpressionDependence::ParameterOnly
4933        );
4934    }
4935    #[test]
4936    fn test_expression_ir_cache_only_dependence_hint_propagates() {
4937        let expr = CacheOnlyScalar::new("k").unwrap();
4938        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4939        let evaluator = expr.load(&dataset).unwrap();
4940        assert_eq!(
4941            evaluator
4942                .expression_root_dependence()
4943                .expect("root dependence should exist"),
4944            ExpressionDependence::CacheOnly
4945        );
4946    }
4947    #[test]
4948    fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
4949        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4950            .unwrap()
4951            .imag();
4952        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4953        let evaluator = expr.load(&dataset).unwrap();
4954        let ir = evaluator.expression_ir();
4955
4956        assert!(matches!(
4957            ir.nodes()[ir.root()],
4958            ir::IrNode::Constant(value) if value == Complex64::ZERO
4959        ));
4960        assert_eq!(
4961            evaluator
4962                .evaluate(&[2.5])
4963                .expect("evaluation should succeed")[0],
4964            Complex64::ZERO
4965        );
4966    }
4967    #[test]
4968    fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
4969        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4970            .unwrap()
4971            .conj();
4972        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4973        let evaluator = expr.load(&dataset).unwrap();
4974        let ir = evaluator.expression_ir();
4975
4976        assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
4977        assert_eq!(
4978            evaluator
4979                .evaluate(&[2.5])
4980                .expect("evaluation should succeed")[0],
4981            Complex64::new(2.5, 0.0)
4982        );
4983    }
4984    #[test]
4985    fn test_expression_ir_dependence_warnings_surface() {
4986        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
4987            + &CacheOnlyScalar::new("k").unwrap();
4988        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4989        let evaluator = expr.load(&dataset).unwrap();
4990        assert!(evaluator
4991            .expression_dependence_warnings()
4992            .expect("warnings should exist")
4993            .iter()
4994            .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
4995    }
4996    #[test]
4997    fn test_expression_ir_normalization_plan_explain_surface() {
4998        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
4999            * &CacheOnlyScalar::new("k").unwrap();
5000        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5001        let evaluator = expr.load(&dataset).unwrap();
5002        let explain = evaluator
5003            .expression_normalization_plan_explain()
5004            .expect("plan should exist");
5005        assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
5006        assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
5007        assert_eq!(
5008            explain.cached_separable_nodes,
5009            explain.separable_mul_candidate_nodes
5010        );
5011        assert!(explain.residual_terms.iter().all(|index| {
5012            !explain
5013                .separable_mul_candidate_nodes
5014                .iter()
5015                .any(|candidate| candidate == index)
5016        }));
5017    }
5018    #[test]
5019    fn test_expression_ir_normalization_execution_sets_surface() {
5020        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5021            * &CacheOnlyScalar::new("k").unwrap();
5022        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5023        let evaluator = expr.load(&dataset).unwrap();
5024        let sets = evaluator
5025            .expression_normalization_execution_sets()
5026            .expect("sets should exist");
5027        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5028        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5029        assert!(sets.residual_amplitudes.is_empty());
5030    }
5031    #[test]
5032    fn test_expression_ir_normalization_execution_sets_partial_surface() {
5033        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5034            * &CacheOnlyScalar::new("k").unwrap())
5035            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5036        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5037        let evaluator = expr.load(&dataset).unwrap();
5038        let sets = evaluator
5039            .expression_normalization_execution_sets()
5040            .expect("sets should exist");
5041        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5042        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5043        assert_eq!(sets.residual_amplitudes, vec![2]);
5044    }
5045    #[test]
5046    fn test_expression_ir_precomputed_cached_integrals_at_load() {
5047        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5048            * &CacheOnlyScalar::new("k").unwrap();
5049        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5050        let evaluator = expr.load(&dataset).unwrap();
5051        let precomputed = evaluator
5052            .expression_precomputed_cached_integrals()
5053            .expect("integrals should exist");
5054        assert_eq!(precomputed.len(), 1);
5055        let cache_reference = CacheOnlyScalar::new("k_ref")
5056            .unwrap()
5057            .load(&dataset)
5058            .unwrap();
5059        let cache_values = cache_reference
5060            .evaluate_local(&[])
5061            .expect("evaluation should succeed");
5062        let expected_weighted_sum = cache_values
5063            .iter()
5064            .zip(dataset.weights_local().iter())
5065            .fold(Complex64::ZERO, |acc, (value, event)| {
5066                acc + (*value * *event)
5067            });
5068        assert_relative_eq!(
5069            precomputed[0].weighted_cache_sum.re,
5070            expected_weighted_sum.re
5071        );
5072        assert_relative_eq!(
5073            precomputed[0].weighted_cache_sum.im,
5074            expected_weighted_sum.im
5075        );
5076    }
5077    #[test]
5078    fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
5079        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5080            * &CacheOnlyScalar::new("k").unwrap();
5081        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5082        let evaluator = expr.load(&dataset).unwrap();
5083        assert!(evaluator
5084            .expression_precomputed_cached_integrals()
5085            .expect("integrals should exist")
5086            .is_empty());
5087    }
5088    #[test]
5089    fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
5090        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5091            * &CacheOnlyScalar::new("k").unwrap();
5092        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5093        let evaluator = expr.load(&dataset).unwrap();
5094        assert_eq!(
5095            evaluator
5096                .expression_precomputed_cached_integrals()
5097                .expect("integrals should exist")
5098                .len(),
5099            1
5100        );
5101
5102        evaluator.isolate_many(&["p"]);
5103        assert!(evaluator
5104            .expression_precomputed_cached_integrals()
5105            .expect("integrals should exist")
5106            .is_empty());
5107    }
5108    #[test]
5109    fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
5110        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5111            * &CacheOnlyScalar::new("k").unwrap();
5112        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5113        let mut evaluator = expr.load(&dataset).unwrap();
5114        drop(dataset);
5115        let before = evaluator
5116            .expression_precomputed_cached_integrals()
5117            .expect("integrals should exist");
5118        assert_eq!(before.len(), 1);
5119
5120        Arc::get_mut(&mut evaluator.dataset)
5121            .expect("evaluator should own dataset Arc in this test")
5122            .clear_events_local();
5123        let after = evaluator
5124            .expression_precomputed_cached_integrals()
5125            .expect("integrals should exist");
5126        assert_eq!(after.len(), 1);
5127        assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
5128        assert!(before[0].weighted_cache_sum != after[0].weighted_cache_sum);
5129    }
5130    #[test]
5131    fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
5132        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5133            * &CacheOnlyScalar::new("k").unwrap();
5134        let dataset = Arc::new(Dataset::new(vec![
5135            Arc::new(test_event()),
5136            Arc::new(test_event()),
5137        ]));
5138        let evaluator = expr.load(&dataset).unwrap();
5139        let cached_integrals = evaluator
5140            .expression_precomputed_cached_integrals()
5141            .expect("integrals should exist");
5142        assert_eq!(cached_integrals.len(), 1);
5143        let gradient_terms = evaluator
5144            .expression_precomputed_cached_integral_gradient_terms(&[1.25])
5145            .expect("evaluation should succeed");
5146        assert_eq!(gradient_terms.len(), 1);
5147        assert_eq!(gradient_terms[0].weighted_gradient.len(), 1);
5148        assert_relative_eq!(
5149            gradient_terms[0].weighted_gradient[0].re,
5150            cached_integrals[0].weighted_cache_sum.re,
5151            epsilon = 1e-6
5152        );
5153        assert_relative_eq!(
5154            gradient_terms[0].weighted_gradient[0].im,
5155            cached_integrals[0].weighted_cache_sum.im,
5156            epsilon = 1e-6
5157        );
5158    }
5159    #[test]
5160    fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
5161        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5162            * &CacheOnlyScalar::new("k").unwrap();
5163        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5164        let evaluator = expr.load(&dataset).unwrap();
5165        assert!(evaluator
5166            .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
5167            .expect("evaluation should succeed")
5168            .is_empty());
5169    }
5170    #[test]
5171    fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
5172        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5173            * &CacheOnlyScalar::new("k").unwrap())
5174            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5175        let dataset = Arc::new(test_dataset());
5176        let evaluator = expr.load(&dataset).unwrap();
5177        let resources = evaluator.resources.read();
5178        let state = evaluator
5179            .ensure_cached_integral_cache_state(&resources)
5180            .expect("state should be available");
5181        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5182        let parameters = resources
5183            .parameter_map
5184            .assemble(&[0.55, 0.2, -0.15])
5185            .expect("parameters should assemble");
5186
5187        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5188        evaluator.fill_amplitude_values(
5189            &mut amplitude_values,
5190            &state.execution_sets.cached_parameter_amplitudes,
5191            &parameters,
5192            &resources.caches[0],
5193        );
5194        let cached_value_ir =
5195            evaluator.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values);
5196        let cached_value_lowered = evaluator
5197            .evaluate_cached_weighted_value_sum_lowered(
5198                &state,
5199                lowered_artifacts.as_ref(),
5200                &amplitude_values,
5201            )
5202            .expect("cached value lowering should succeed");
5203        assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);
5204
5205        let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
5206        for &index in &state.execution_sets.cached_parameter_amplitudes {
5207            cached_parameter_mask[index] = true;
5208        }
5209        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5210            .map(|_| DVector::zeros(parameters.len()))
5211            .collect::<Vec<_>>();
5212        evaluator.fill_amplitude_gradients(
5213            &mut amplitude_gradients,
5214            &cached_parameter_mask,
5215            &parameters,
5216            &resources.caches[0],
5217        );
5218        let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
5219            &state,
5220            &amplitude_values,
5221            &amplitude_gradients,
5222            parameters.len(),
5223        );
5224        let cached_gradient_lowered = evaluator
5225            .evaluate_cached_weighted_gradient_sum_lowered(
5226                &state,
5227                lowered_artifacts.as_ref(),
5228                &amplitude_values,
5229                &amplitude_gradients,
5230                parameters.len(),
5231            )
5232            .expect("cached gradient lowering should succeed");
5233        for (lowered, ir) in cached_gradient_lowered
5234            .iter()
5235            .zip(cached_gradient_ir.iter())
5236        {
5237            assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
5238        }
5239    }
5240    #[test]
5241    fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
5242        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5243            * &CacheOnlyScalar::new("k").unwrap())
5244            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5245        let dataset = Arc::new(test_dataset());
5246        let evaluator = expr.load(&dataset).unwrap();
5247        let resources = evaluator.resources.read();
5248        let state = evaluator
5249            .ensure_cached_integral_cache_state(&resources)
5250            .expect("state should be available");
5251        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5252        let parameters = resources
5253            .parameter_map
5254            .assemble(&[0.55, 0.2, -0.15])
5255            .expect("parameters should assemble");
5256
5257        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5258        evaluator.fill_amplitude_values(
5259            &mut amplitude_values,
5260            &state.execution_sets.residual_amplitudes,
5261            &parameters,
5262            &resources.caches[0],
5263        );
5264        let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &amplitude_values);
5265        let residual_program = lowered_artifacts
5266            .residual_runtime
5267            .as_ref()
5268            .map(|runtime| runtime.value_program())
5269            .expect("residual value lowering should succeed");
5270        let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
5271        let residual_value_lowered =
5272            residual_program.evaluate_into(&amplitude_values, &mut value_slots);
5273        assert_relative_eq!(
5274            residual_value_lowered.re,
5275            residual_value_ir.re,
5276            epsilon = 1e-12
5277        );
5278        assert_relative_eq!(
5279            residual_value_lowered.im,
5280            residual_value_ir.im,
5281            epsilon = 1e-12
5282        );
5283
5284        let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
5285        for &index in &state.execution_sets.residual_amplitudes {
5286            residual_active_mask[index] = true;
5287        }
5288        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5289            .map(|_| DVector::zeros(parameters.len()))
5290            .collect::<Vec<_>>();
5291        evaluator.fill_amplitude_gradients(
5292            &mut amplitude_gradients,
5293            &residual_active_mask,
5294            &parameters,
5295            &resources.caches[0],
5296        );
5297        let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
5298            &state,
5299            &amplitude_values,
5300            &amplitude_gradients,
5301            parameters.len(),
5302        );
5303
5304        let program = lowered_artifacts
5305            .residual_runtime
5306            .as_ref()
5307            .map(|runtime| runtime.gradient_program())
5308            .expect("gradient lowering should succeed");
5309        let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
5310        let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
5311        let residual_gradient_lowered = program.evaluate_gradient_into_flat(
5312            &amplitude_values,
5313            &amplitude_gradients,
5314            &mut value_slots,
5315            &mut gradient_slots,
5316            parameters.len(),
5317        );
5318
5319        for (lowered, ir) in residual_gradient_lowered
5320            .iter()
5321            .zip(residual_gradient_ir.iter())
5322        {
5323            assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
5324            assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
5325        }
5326    }
5327    #[test]
5328    fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
5329        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5330            * &CacheOnlyScalar::new("k").unwrap())
5331            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5332        let dataset = Arc::new(test_dataset());
5333        let mut evaluator = expr.load(&dataset).unwrap();
5334        drop(dataset);
5335
5336        assert_eq!(evaluator.specialization_cache_len(), 1);
5337        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
5338
5339        evaluator.reset_expression_compile_metrics();
5340        evaluator.reset_expression_specialization_metrics();
5341
5342        Arc::get_mut(&mut evaluator.dataset)
5343            .expect("evaluator should own dataset Arc in this test")
5344            .clear_events_local();
5345
5346        let cached_integrals = evaluator
5347            .expression_precomputed_cached_integrals()
5348            .expect("integrals should exist");
5349        assert_eq!(cached_integrals.len(), 1);
5350        assert_eq!(cached_integrals[0].weighted_cache_sum, Complex64::ZERO);
5351
5352        assert_eq!(evaluator.specialization_cache_len(), 2);
5353        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
5354        assert_eq!(
5355            evaluator.expression_specialization_metrics(),
5356            ExpressionSpecializationMetrics {
5357                cache_hits: 0,
5358                cache_misses: 1,
5359            }
5360        );
5361
5362        let compile_metrics = evaluator.expression_compile_metrics();
5363        assert_eq!(compile_metrics.specialization_cache_hits, 0);
5364        assert_eq!(compile_metrics.specialization_cache_misses, 1);
5365        assert_eq!(compile_metrics.specialization_lowering_cache_hits, 1);
5366        assert_eq!(compile_metrics.specialization_lowering_cache_misses, 0);
5367        assert!(compile_metrics.specialization_ir_compile_nanos > 0);
5368        assert!(compile_metrics.specialization_cached_integrals_nanos > 0);
5369        assert_eq!(compile_metrics.specialization_lowering_nanos, 0);
5370    }
5371
5372    #[test]
5373    fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
5374        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5375        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5376        let c1 = CacheOnlyScalar::new("c1").unwrap();
5377        let c2 = CacheOnlyScalar::new("c2").unwrap();
5378        let c3 = CacheOnlyScalar::new("c3").unwrap();
5379        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5380        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5381        let dataset = Arc::new(test_dataset());
5382        let evaluator = expr.load(&dataset).unwrap();
5383        assert_eq!(
5384            evaluator
5385                .expression_precomputed_cached_integrals()
5386                .expect("integrals should exist")
5387                .len(),
5388            2
5389        );
5390        let params = vec![0.2, -0.3, 1.1, -0.7];
5391        let expected = evaluator
5392            .evaluate_gradient_local(&params)
5393            .expect("evaluation should succeed")
5394            .iter()
5395            .zip(dataset.weights_local().iter())
5396            .fold(
5397                DVector::zeros(params.len()),
5398                |mut accum, (gradient, event)| {
5399                    accum += gradient.map(|value| value.re).scale(*event);
5400                    accum
5401                },
5402            );
5403        let actual = evaluator
5404            .evaluate_weighted_gradient_sum_local(&params)
5405            .expect("evaluation should succeed");
5406        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
5407            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
5408        }
5409    }
5410
5411    #[test]
5412    fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
5413        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5414        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5415        let c1 = CacheOnlyScalar::new("c1").unwrap();
5416        let c2 = CacheOnlyScalar::new("c2").unwrap();
5417        let c3 = CacheOnlyScalar::new("c3").unwrap();
5418        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5419        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5420        let dataset = Arc::new(test_dataset());
5421        let evaluator = expr.load(&dataset).unwrap();
5422        assert_eq!(
5423            evaluator
5424                .expression_precomputed_cached_integrals()
5425                .expect("integrals should exist")
5426                .len(),
5427            2
5428        );
5429        let params = vec![0.2, -0.3, 1.1, -0.7];
5430        let expected = evaluator
5431            .evaluate_local(&params)
5432            .expect("evaluation should succeed")
5433            .iter()
5434            .zip(dataset.weights_local().iter())
5435            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
5436        let actual = evaluator
5437            .evaluate_weighted_value_sum_local(&params)
5438            .expect("evaluation should succeed");
5439        assert_relative_eq!(actual, expected, epsilon = 1e-10);
5440    }
5441
5442    #[test]
5443    fn test_weighted_sums_match_hardcoded_reference_values() {
5444        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5445        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5446        let c1 = CacheOnlyScalar::new("c1").unwrap();
5447        let c2 = CacheOnlyScalar::new("c2").unwrap();
5448        let c3 = CacheOnlyScalar::new("c3").unwrap();
5449        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5450        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5451
5452        let metadata = Arc::new(DatasetMetadata::default());
5453        let dataset = Arc::new(Dataset::new_with_metadata(
5454            vec![
5455                Arc::new(EventData {
5456                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
5457                    aux: vec![],
5458                    weight: 0.5,
5459                }),
5460                Arc::new(EventData {
5461                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
5462                    aux: vec![],
5463                    weight: -1.25,
5464                }),
5465                Arc::new(EventData {
5466                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
5467                    aux: vec![],
5468                    weight: 2.0,
5469                }),
5470            ],
5471            metadata,
5472        ));
5473        let evaluator = expr.load(&dataset).unwrap();
5474        let params = vec![0.7, -1.1, 0.9, -0.4];
5475
5476        let weighted_value_sum = evaluator
5477            .evaluate_weighted_value_sum_local(&params)
5478            .expect("evaluation should succeed");
5479        assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);
5480
5481        let weighted_gradient_sum = evaluator
5482            .evaluate_weighted_gradient_sum_local(&params)
5483            .expect("evaluation should succeed");
5484        let free_parameters = evaluator
5485            .parameters()
5486            .free()
5487            .names()
5488            .into_iter()
5489            .map(|name| name.to_string())
5490            .collect::<Vec<_>>();
5491        assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
5492        let expected_gradient = [43.925, 7.25, 28.525, 0.0];
5493        assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
5494        for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
5495            assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
5496        }
5497    }
5498    #[test]
5499    fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
5500        let expr = Expression::one()
5501            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5502                * &CacheOnlyScalar::new("k").unwrap());
5503        let dataset = Arc::new(test_dataset());
5504        let evaluator = expr.load(&dataset).unwrap();
5505        assert_eq!(
5506            evaluator
5507                .expression_precomputed_cached_integrals()
5508                .expect("integrals should exist")
5509                .len(),
5510            1
5511        );
5512        assert_eq!(
5513            evaluator
5514                .expression_precomputed_cached_integrals()
5515                .expect("integrals should exist")[0]
5516                .coefficient,
5517            -1
5518        );
5519        let params = vec![0.75];
5520        let expected = evaluator
5521            .evaluate_gradient_local(&params)
5522            .expect("evaluation should succeed")
5523            .iter()
5524            .zip(dataset.weights_local().iter())
5525            .fold(
5526                DVector::zeros(params.len()),
5527                |mut accum, (gradient, event)| {
5528                    accum += gradient.map(|value| value.re).scale(*event);
5529                    accum
5530                },
5531            );
5532        let actual = evaluator
5533            .evaluate_weighted_gradient_sum_local(&params)
5534            .expect("evaluation should succeed");
5535        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
5536            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
5537        }
5538    }
5539    #[test]
5540    fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
5541        let expr = Expression::one()
5542            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5543                * &CacheOnlyScalar::new("k").unwrap());
5544        let dataset = Arc::new(test_dataset());
5545        let evaluator = expr.load(&dataset).unwrap();
5546        assert_eq!(
5547            evaluator
5548                .expression_precomputed_cached_integrals()
5549                .expect("integrals should exist")
5550                .len(),
5551            1
5552        );
5553        assert_eq!(
5554            evaluator
5555                .expression_precomputed_cached_integrals()
5556                .expect("integrals should exist")[0]
5557                .coefficient,
5558            -1
5559        );
5560        let params = vec![0.75];
5561        let expected = evaluator
5562            .evaluate_local(&params)
5563            .expect("evaluation should succeed")
5564            .iter()
5565            .zip(dataset.weights_local().iter())
5566            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
5567        let actual = evaluator
5568            .evaluate_weighted_value_sum_local(&params)
5569            .expect("evaluation should succeed");
5570        assert_relative_eq!(actual, expected, epsilon = 1e-10);
5571    }
5572    #[test]
5573    fn test_expression_ir_diagnostics_follow_activation_changes() {
5574        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5575            * &CacheOnlyScalar::new("k").unwrap())
5576            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5577        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5578        let evaluator = expr.load(&dataset).unwrap();
5579
5580        let all_active = evaluator
5581            .expression_normalization_plan_explain()
5582            .expect("plan should exist");
5583        assert_eq!(all_active.cached_separable_nodes.len(), 1);
5584        assert_eq!(
5585            evaluator
5586                .expression_root_dependence()
5587                .expect("root dependence should exist"),
5588            ExpressionDependence::Mixed
5589        );
5590
5591        evaluator.isolate_many(&["p"]);
5592        let param_only = evaluator
5593            .expression_normalization_plan_explain()
5594            .expect("plan should exist");
5595        assert!(param_only.cached_separable_nodes.is_empty());
5596        assert_eq!(
5597            evaluator
5598                .expression_root_dependence()
5599                .expect("root dependence should exist"),
5600            ExpressionDependence::ParameterOnly
5601        );
5602    }
    #[test]
    // Walks the specialization cache through three activation masks:
    //   1. the all-active mask built at load time (counted as a miss by
    //      `expression_specialization_metrics`, but not as a specialization miss in the
    //      compile metrics, which report it under the `initial_*` fields);
    //   2. a new mask after isolating `p` (a second miss, with fresh IR compile, cached
    //      integrals and lowering);
    //   3. re-activating `k` and `m`, which should hit the cache and restore the original
    //      cached integrals without any additional lowering.
    fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();

        // Load-time work is reported under the `initial_*` timings. The initial lowering is
        // recorded as one lowering-cache miss.
        let initial_compile_metrics = evaluator.expression_compile_metrics();
        assert!(initial_compile_metrics.initial_ir_compile_nanos > 0);
        assert!(initial_compile_metrics.initial_cached_integrals_nanos > 0);
        assert!(initial_compile_metrics.initial_lowering_nanos > 0);
        assert_eq!(initial_compile_metrics.specialization_cache_hits, 0);
        assert_eq!(initial_compile_metrics.specialization_cache_misses, 0);
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_misses,
            1
        );

        assert_eq!(evaluator.specialization_cache_len(), 1);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 1,
            }
        );
        // Remember the all-active integrals so the later cache hit can be checked against them.
        let all_active_cached_integrals = evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist");

        // New activation mask: a second specialization-cache entry and a second miss.
        evaluator.isolate_many(&["p"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 2,
            }
        );
        let after_cache_miss_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_miss_metrics.specialization_cache_hits, 0);
        assert_eq!(after_cache_miss_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_misses,
            2
        );
        assert!(after_cache_miss_metrics.specialization_ir_compile_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_cached_integrals_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_lowering_nanos > 0);
        // With only `p` active, nothing is cache-separable.
        assert!(evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .is_empty());

        // Re-activating `k` and `m` should reuse the earlier specialization: no new cache
        // entry, one hit, and the same cached integrals as the all-active state. The
        // lowering-cache counters stay unchanged.
        evaluator.activate_many(&["k", "m"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 1,
                cache_misses: 2,
            }
        );
        assert_eq!(
            evaluator
                .expression_precomputed_cached_integrals()
                .expect("integrals should exist"),
            all_active_cached_integrals
        );
        let after_cache_hit_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_hit_metrics.specialization_cache_hits, 1);
        assert_eq!(after_cache_hit_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_misses,
            2
        );
        assert!(after_cache_hit_metrics.specialization_cache_restore_nanos > 0);
    }
5695
5696    #[test]
5697    fn test_weighted_sums_match_baseline_after_activation_changes() {
5698        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
5699        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
5700        let c1 = CacheOnlyScalar::new("c1").unwrap();
5701        let c2 = CacheOnlyScalar::new("c2").unwrap();
5702        let c3 = CacheOnlyScalar::new("c3").unwrap();
5703        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
5704        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
5705        let dataset = Arc::new(test_dataset());
5706        let evaluator = expr.load(&dataset).unwrap();
5707        let params = vec![0.2, -0.3, 1.1, -0.7];
5708
5709        evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);
5710
5711        let expected_value = evaluator
5712            .evaluate_local(&params)
5713            .expect("evaluation should succeed")
5714            .iter()
5715            .zip(dataset.weights_local().iter())
5716            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
5717        assert_relative_eq!(
5718            evaluator
5719                .evaluate_weighted_value_sum_local(&params)
5720                .expect("evaluation should succeed"),
5721            expected_value,
5722            epsilon = 1e-10
5723        );
5724    }
5725
5726    #[test]
5727    fn test_evaluate_local_does_not_depend_on_dataset_rows() {
5728        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5729            .unwrap()
5730            .norm_sqr();
5731        let mut event1 = test_event();
5732        event1.p4s[0].t = 7.5;
5733        let mut event2 = test_event();
5734        event2.p4s[0].t = 8.25;
5735        let mut event3 = test_event();
5736        event3.p4s[0].t = 9.0;
5737        let dataset = Arc::new(Dataset::new_with_metadata(
5738            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5739            Arc::new(DatasetMetadata::default()),
5740        ));
5741        let mut evaluator = expr.load(&dataset).unwrap();
5742        drop(dataset);
5743        let expected_len = evaluator.resources.read().caches.len();
5744        Arc::get_mut(&mut evaluator.dataset)
5745            .expect("evaluator should own dataset Arc in this test")
5746            .clear_events_local();
5747        let cached = evaluator
5748            .evaluate_local(&[1.25, -0.75])
5749            .expect("evaluation should succeed");
5750        assert_eq!(cached.len(), expected_len);
5751    }
5752
5753    #[test]
5754    fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
5755        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5756            .unwrap()
5757            .norm_sqr();
5758        let mut event1 = test_event();
5759        event1.p4s[0].t = 7.5;
5760        let mut event2 = test_event();
5761        event2.p4s[0].t = 8.25;
5762        let mut event3 = test_event();
5763        event3.p4s[0].t = 9.0;
5764        let dataset = Arc::new(Dataset::new_with_metadata(
5765            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5766            Arc::new(DatasetMetadata::default()),
5767        ));
5768        let mut evaluator = expr.load(&dataset).unwrap();
5769        drop(dataset);
5770        let expected_len = evaluator.resources.read().caches.len();
5771        Arc::get_mut(&mut evaluator.dataset)
5772            .expect("evaluator should own dataset Arc in this test")
5773            .clear_events_local();
5774        let cached = evaluator
5775            .evaluate_gradient_local(&[1.25, -0.75])
5776            .expect("evaluation should succeed");
5777        assert_eq!(cached.len(), expected_len);
5778    }
5779
5780    #[test]
5781    fn test_evaluate_with_gradient_local_matches_separate_paths() {
5782        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5783            .unwrap()
5784            .norm_sqr();
5785        let dataset = Arc::new(Dataset::new(vec![
5786            Arc::new(test_event()),
5787            Arc::new(test_event()),
5788            Arc::new(test_event()),
5789        ]));
5790        let evaluator = expr.load(&dataset).unwrap();
5791        let params = [1.25, -0.75];
5792        let values = evaluator
5793            .evaluate_local(&params)
5794            .expect("evaluation should succeed");
5795        let gradients = evaluator
5796            .evaluate_gradient_local(&params)
5797            .expect("evaluation should succeed");
5798        let fused = evaluator
5799            .evaluate_with_gradient_local(&params)
5800            .expect("evaluation should succeed");
5801        assert_eq!(fused.len(), values.len());
5802        assert_eq!(fused.len(), gradients.len());
5803        for ((value_gradient, value), gradient) in
5804            fused.iter().zip(values.iter()).zip(gradients.iter())
5805        {
5806            let (fused_value, fused_gradient) = value_gradient;
5807            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
5808            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
5809            assert_eq!(fused_gradient.len(), gradient.len());
5810            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
5811                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
5812                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
5813            }
5814        }
5815    }
5816
5817    #[test]
5818    fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
5819        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
5820            .unwrap()
5821            .norm_sqr();
5822        let dataset = Arc::new(Dataset::new(vec![
5823            Arc::new(test_event()),
5824            Arc::new(test_event()),
5825            Arc::new(test_event()),
5826            Arc::new(test_event()),
5827        ]));
5828        let evaluator = expr.load(&dataset).unwrap();
5829        let params = [0.5, -1.25];
5830        let indices = vec![0, 2, 3];
5831        let values = evaluator
5832            .evaluate_batch_local(&params, &indices)
5833            .expect("evaluation should succeed");
5834        let gradients = evaluator
5835            .evaluate_gradient_batch_local(&params, &indices)
5836            .expect("evaluation should succeed");
5837        let fused = evaluator
5838            .evaluate_with_gradient_batch_local(&params, &indices)
5839            .expect("evaluation should succeed");
5840        assert_eq!(fused.len(), values.len());
5841        assert_eq!(fused.len(), gradients.len());
5842        for ((value_gradient, value), gradient) in
5843            fused.iter().zip(values.iter()).zip(gradients.iter())
5844        {
5845            let (fused_value, fused_gradient) = value_gradient;
5846            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
5847            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
5848            assert_eq!(fused_gradient.len(), gradient.len());
5849            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
5850                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
5851                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
5852            }
5853        }
5854    }
5855
5856    #[test]
5857    fn test_precompute_all_columnar_populates_cache() {
5858        let mut event1 = test_event();
5859        event1.p4s[0].t = 7.5;
5860        let mut event2 = test_event();
5861        event2.p4s[0].t = 8.25;
5862        let mut event3 = test_event();
5863        event3.p4s[0].t = 9.0;
5864        let dataset = Dataset::new_with_metadata(
5865            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5866            Arc::new(DatasetMetadata::default()),
5867        );
5868        let mut amplitude = TestAmplitude {
5869            tags: Tags::new(["test"]),
5870            re: parameter!("real"),
5871            pid_re: ParameterID::default(),
5872            im: parameter!("imag"),
5873            pid_im: ParameterID::default(),
5874            beam_energy: Default::default(),
5875        };
5876        let mut resources = Resources::default();
5877        amplitude
5878            .register(&mut resources)
5879            .expect("test amplitude should register");
5880        resources.reserve_cache(dataset.n_events());
5881        amplitude.precompute_all(&dataset, &mut resources);
5882        for cache in &resources.caches {
5883            assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
5884        }
5885    }
5886
5887    #[cfg(feature = "mpi")]
5888    #[mpi_test(np = [2])]
5889    fn test_load_reserves_local_cache_size_in_mpi() {
5890        use crate::mpi::{finalize_mpi, get_world, use_mpi};
5891
5892        use_mpi(true);
5893        assert!(get_world().is_some(), "MPI world should be initialized");
5894
5895        let expr = ComplexScalar::new(
5896            "constant",
5897            parameter!("const_re", 2.0),
5898            parameter!("const_im", 3.0),
5899        )
5900        .expect("constant amplitude should construct");
5901        let events = vec![
5902            Arc::new(test_event()),
5903            Arc::new(test_event()),
5904            Arc::new(test_event()),
5905            Arc::new(test_event()),
5906        ];
5907        let dataset = Arc::new(Dataset::new_with_metadata(
5908            events,
5909            Arc::new(DatasetMetadata::default()),
5910        ));
5911        let evaluator = expr.load(&dataset).expect("evaluator should load");
5912        let local_events = dataset.n_events_local();
5913        let cache_len = evaluator.resources.read().caches.len();
5914
5915        assert_eq!(
5916            cache_len, local_events,
5917            "cache length must match local event count under MPI"
5918        );
5919        finalize_mpi();
5920    }
5921
5922    #[cfg(feature = "mpi")]
5923    #[mpi_test(np = [2])]
5924    fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
5925        use mpi::{collective::SystemOperation, topology::Communicator, traits::*};
5926
5927        use crate::mpi::{finalize_mpi, get_world, use_mpi};
5928
5929        use_mpi(true);
5930        let world = get_world().expect("MPI world should be initialized");
5931
5932        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5933            * &CacheOnlyScalar::new("k").unwrap();
5934        let events = vec![
5935            Arc::new(EventData {
5936                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
5937                aux: vec![],
5938                weight: 0.5,
5939            }),
5940            Arc::new(EventData {
5941                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
5942                aux: vec![],
5943                weight: 1.0,
5944            }),
5945            Arc::new(EventData {
5946                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
5947                aux: vec![],
5948                weight: 1.5,
5949            }),
5950            Arc::new(EventData {
5951                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
5952                aux: vec![],
5953                weight: 2.0,
5954            }),
5955            Arc::new(EventData {
5956                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
5957                aux: vec![],
5958                weight: 2.5,
5959            }),
5960            Arc::new(EventData {
5961                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
5962                aux: vec![],
5963                weight: 3.0,
5964            }),
5965            Arc::new(EventData {
5966                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
5967                aux: vec![],
5968                weight: 3.5,
5969            }),
5970            Arc::new(EventData {
5971                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
5972                aux: vec![],
5973                weight: 4.0,
5974            }),
5975        ];
5976        let dataset = Arc::new(Dataset::new_with_metadata(
5977            events,
5978            Arc::new(DatasetMetadata::default()),
5979        ));
5980        let evaluator = expr.load(&dataset).expect("evaluator should load");
5981        let cached_integrals = evaluator
5982            .expression_precomputed_cached_integrals()
5983            .expect("integrals should exist");
5984        assert_eq!(cached_integrals.len(), 1);
5985
5986        let local_expected =
5987            dataset
5988                .weights_local()
5989                .iter()
5990                .enumerate()
5991                .fold(0.0, |acc, (index, weight)| {
5992                    let event = dataset.event_local(index).expect("event should exist");
5993                    acc + *weight * event.p4_at(0).e()
5994                });
5995        let cached_local = cached_integrals[0].weighted_cache_sum;
5996        assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
5997        assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);
5998
5999        let weighted_value_sum = evaluator
6000            .evaluate_weighted_value_sum_local(&[2.0])
6001            .expect("evaluate should succeed");
6002        assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);
6003
6004        let mut global_expected = 0.0;
6005        world.all_reduce_into(
6006            &local_expected,
6007            &mut global_expected,
6008            SystemOperation::sum(),
6009        );
6010        if world.size() > 1 {
6011            assert!(
6012                (cached_local.re - global_expected).abs() > 1e-12,
6013                "cached integral should remain rank-local before MPI reduction"
6014            );
6015        }
6016        finalize_mpi();
6017    }
6018
6019    #[cfg(feature = "mpi")]
6020    #[mpi_test(np = [2])]
6021    fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
6022        use mpi::{collective::SystemOperation, traits::*};
6023
6024        use crate::mpi::{finalize_mpi, get_world, use_mpi};
6025
6026        use_mpi(true);
6027        let world = get_world().expect("MPI world should be initialized");
6028
6029        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6030        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6031        let c1 = CacheOnlyScalar::new("c1").unwrap();
6032        let c2 = CacheOnlyScalar::new("c2").unwrap();
6033        let c3 = CacheOnlyScalar::new("c3").unwrap();
6034        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6035        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6036        let events = vec![
6037            Arc::new(EventData {
6038                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
6039                aux: vec![],
6040                weight: 0.5,
6041            }),
6042            Arc::new(EventData {
6043                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6044                aux: vec![],
6045                weight: -1.25,
6046            }),
6047            Arc::new(EventData {
6048                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6049                aux: vec![],
6050                weight: 0.75,
6051            }),
6052            Arc::new(EventData {
6053                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
6054                aux: vec![],
6055                weight: 1.5,
6056            }),
6057            Arc::new(EventData {
6058                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6059                aux: vec![],
6060                weight: 2.25,
6061            }),
6062            Arc::new(EventData {
6063                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
6064                aux: vec![],
6065                weight: -0.5,
6066            }),
6067            Arc::new(EventData {
6068                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
6069                aux: vec![],
6070                weight: 3.5,
6071            }),
6072            Arc::new(EventData {
6073                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
6074                aux: vec![],
6075                weight: 1.25,
6076            }),
6077        ];
6078        let dataset = Arc::new(Dataset::new_with_metadata(
6079            events,
6080            Arc::new(DatasetMetadata::default()),
6081        ));
6082        let evaluator = expr.load(&dataset).expect("evaluator should load");
6083        let params = vec![0.2, -0.3, 1.1, -0.7];
6084
6085        let local_expected_value = evaluator
6086            .evaluate_local(&params)
6087            .expect("evaluate should succeed")
6088            .iter()
6089            .zip(dataset.weights_local().iter())
6090            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6091        let mut global_expected_value = 0.0;
6092        world.all_reduce_into(
6093            &local_expected_value,
6094            &mut global_expected_value,
6095            SystemOperation::sum(),
6096        );
6097        let mpi_value = evaluator
6098            .evaluate_weighted_value_sum_mpi(&params, &world)
6099            .expect("evaluate should succeed");
6100        assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);
6101
6102        let local_expected_gradient = evaluator
6103            .evaluate_gradient_local(&params)
6104            .expect("evaluate should succeed")
6105            .iter()
6106            .zip(dataset.weights_local().iter())
6107            .fold(
6108                DVector::zeros(params.len()),
6109                |mut accum, (gradient, event)| {
6110                    accum += gradient.map(|value| value.re).scale(*event);
6111                    accum
6112                },
6113            );
6114        let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
6115        world.all_reduce_into(
6116            local_expected_gradient.as_slice(),
6117            &mut global_expected_gradient,
6118            SystemOperation::sum(),
6119        );
6120        let mpi_gradient = evaluator
6121            .evaluate_weighted_gradient_sum_mpi(&params, &world)
6122            .expect("evaluate should succeed");
6123        for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
6124            assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
6125        }
6126
6127        finalize_mpi();
6128    }
6129
6130    #[test]
6131    fn test_evaluate_local_succeeds_for_constant_amplitude() {
6132        let expr = ComplexScalar::new(
6133            "constant",
6134            parameter!("const_re", 2.0),
6135            parameter!("const_im", 3.0),
6136        )
6137        .unwrap();
6138        let dataset = Arc::new(Dataset::new_with_metadata(
6139            vec![Arc::new(test_event())],
6140            Arc::new(DatasetMetadata::default()),
6141        ));
6142        let evaluator = expr.load(&dataset).unwrap();
6143        let values = evaluator
6144            .evaluate_local(&[])
6145            .expect("evaluation should succeed");
6146        assert_eq!(values.len(), 1);
6147        let gradients = evaluator
6148            .evaluate_gradient_local(&[])
6149            .expect("evaluation should succeed");
6150        assert_eq!(gradients.len(), 1);
6151    }
6152
6153    #[test]
6154    fn test_constant_amplitude() {
6155        let expr = ComplexScalar::new(
6156            "constant",
6157            parameter!("const_re", 2.0),
6158            parameter!("const_im", 3.0),
6159        )
6160        .unwrap();
6161        let dataset = Arc::new(Dataset::new_with_metadata(
6162            vec![Arc::new(test_event())],
6163            Arc::new(DatasetMetadata::default()),
6164        ));
6165        let evaluator = expr.load(&dataset).unwrap();
6166        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6167        assert_eq!(result[0], Complex64::new(2.0, 3.0));
6168    }
6169
6170    #[test]
6171    fn test_parametric_amplitude() {
6172        let expr = ComplexScalar::new(
6173            "parametric",
6174            parameter!("test_param_re"),
6175            parameter!("test_param_im"),
6176        )
6177        .unwrap();
6178        let dataset = Arc::new(test_dataset());
6179        let evaluator = expr.load(&dataset).unwrap();
6180        let result = evaluator
6181            .evaluate(&[2.0, 3.0])
6182            .expect("evaluation should succeed");
6183        assert_eq!(result[0], Complex64::new(2.0, 3.0));
6184    }
6185
6186    #[test]
6187    fn test_expression_operations() {
6188        let expr1 = ComplexScalar::new(
6189            "const1",
6190            parameter!("const1_re", 2.0),
6191            parameter!("const1_im", 0.0),
6192        )
6193        .unwrap();
6194        let expr2 = ComplexScalar::new(
6195            "const2",
6196            parameter!("const2_re", 0.0),
6197            parameter!("const2_im", 1.0),
6198        )
6199        .unwrap();
6200        let expr3 = ComplexScalar::new(
6201            "const3",
6202            parameter!("const3_re", 3.0),
6203            parameter!("const3_im", 4.0),
6204        )
6205        .unwrap();
6206
6207        let dataset = Arc::new(test_dataset());
6208
6209        // Test (amp) addition
6210        let expr_add = &expr1 + &expr2;
6211        let result_add = expr_add
6212            .load(&dataset)
6213            .unwrap()
6214            .evaluate(&[])
6215            .expect("evaluation should succeed");
6216        assert_eq!(result_add[0], Complex64::new(2.0, 1.0));
6217
6218        // Test (amp) subtraction
6219        let expr_sub = &expr1 - &expr2;
6220        let result_sub = expr_sub
6221            .load(&dataset)
6222            .unwrap()
6223            .evaluate(&[])
6224            .expect("evaluation should succeed");
6225        assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));
6226
6227        // Test (amp) multiplication
6228        let expr_mul = &expr1 * &expr2;
6229        let result_mul = expr_mul
6230            .load(&dataset)
6231            .unwrap()
6232            .evaluate(&[])
6233            .expect("evaluation should succeed");
6234        assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));
6235
6236        // Test (amp) division
6237        let expr_div = &expr1 / &expr3;
6238        let result_div = expr_div
6239            .load(&dataset)
6240            .unwrap()
6241            .evaluate(&[])
6242            .expect("evaluation should succeed");
6243        assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));
6244
6245        // Test (amp) neg
6246        let expr_neg = -&expr3;
6247        let result_neg = expr_neg
6248            .load(&dataset)
6249            .unwrap()
6250            .evaluate(&[])
6251            .expect("evaluation should succeed");
6252        assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));
6253
6254        // Test (expr) addition
6255        let expr_add2 = &expr_add + &expr_mul;
6256        let result_add2 = expr_add2
6257            .load(&dataset)
6258            .unwrap()
6259            .evaluate(&[])
6260            .expect("evaluation should succeed");
6261        assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));
6262
6263        // Test (expr) subtraction
6264        let expr_sub2 = &expr_add - &expr_mul;
6265        let result_sub2 = expr_sub2
6266            .load(&dataset)
6267            .unwrap()
6268            .evaluate(&[])
6269            .expect("evaluation should succeed");
6270        assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));
6271
6272        // Test (expr) multiplication
6273        let expr_mul2 = &expr_add * &expr_mul;
6274        let result_mul2 = expr_mul2
6275            .load(&dataset)
6276            .unwrap()
6277            .evaluate(&[])
6278            .expect("evaluation should succeed");
6279        assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));
6280
6281        // Test (expr) division
6282        let expr_div2 = &expr_add / &expr_add2;
6283        let result_div2 = expr_div2
6284            .load(&dataset)
6285            .unwrap()
6286            .evaluate(&[])
6287            .expect("evaluation should succeed");
6288        assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));
6289
6290        // Test (expr) neg
6291        let expr_neg2 = -&expr_mul2;
6292        let result_neg2 = expr_neg2
6293            .load(&dataset)
6294            .unwrap()
6295            .evaluate(&[])
6296            .expect("evaluation should succeed");
6297        assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));
6298
6299        // Test (amp) real
6300        let expr_real = expr3.real();
6301        let result_real = expr_real
6302            .load(&dataset)
6303            .unwrap()
6304            .evaluate(&[])
6305            .expect("evaluation should succeed");
6306        assert_eq!(result_real[0], Complex64::new(3.0, 0.0));
6307
6308        // Test (expr) real
6309        let expr_mul2_real = expr_mul2.real();
6310        let result_mul2_real = expr_mul2_real
6311            .load(&dataset)
6312            .unwrap()
6313            .evaluate(&[])
6314            .expect("evaluation should succeed");
6315        assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));
6316
6317        // Test (amp) imag
6318        let expr_imag = expr3.imag();
6319        let result_imag = expr_imag
6320            .load(&dataset)
6321            .unwrap()
6322            .evaluate(&[])
6323            .expect("evaluation should succeed");
6324        assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));
6325
6326        // Test (expr) imag
6327        let expr_mul2_imag = expr_mul2.imag();
6328        let result_mul2_imag = expr_mul2_imag
6329            .load(&dataset)
6330            .unwrap()
6331            .evaluate(&[])
6332            .expect("evaluation should succeed");
6333        assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));
6334
6335        // Test (amp) conj
6336        let expr_conj = expr3.conj();
6337        let result_conj = expr_conj
6338            .load(&dataset)
6339            .unwrap()
6340            .evaluate(&[])
6341            .expect("evaluation should succeed");
6342        assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));
6343
6344        // Test (expr) conj
6345        let expr_mul2_conj = expr_mul2.conj();
6346        let result_mul2_conj = expr_mul2_conj
6347            .load(&dataset)
6348            .unwrap()
6349            .evaluate(&[])
6350            .expect("evaluation should succeed");
6351        assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));
6352
6353        // Test (amp) norm_sqr
6354        let expr_norm = expr1.norm_sqr();
6355        let result_norm = expr_norm
6356            .load(&dataset)
6357            .unwrap()
6358            .evaluate(&[])
6359            .expect("evaluation should succeed");
6360        assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));
6361
6362        // Test (expr) norm_sqr
6363        let expr_mul2_norm = expr_mul2.norm_sqr();
6364        let result_mul2_norm = expr_mul2_norm
6365            .load(&dataset)
6366            .unwrap()
6367            .evaluate(&[])
6368            .expect("evaluation should succeed");
6369        assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
6370    }
6371
6372    #[test]
6373    fn test_amplitude_activation() {
6374        let expr1 = ComplexScalar::new(
6375            "const1",
6376            parameter!("const1_re_act", 1.0),
6377            parameter!("const1_im_act", 0.0),
6378        )
6379        .unwrap();
6380        let expr2 = ComplexScalar::new(
6381            "const2",
6382            parameter!("const2_re_act", 2.0),
6383            parameter!("const2_im_act", 0.0),
6384        )
6385        .unwrap();
6386
6387        let dataset = Arc::new(test_dataset());
6388        let expr = &expr1 + &expr2;
6389        let evaluator = expr.load(&dataset).unwrap();
6390
6391        // Test initial state (all active)
6392        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6393        assert_eq!(result[0], Complex64::new(3.0, 0.0));
6394
6395        // Test deactivation
6396        evaluator.deactivate_strict("const1").unwrap();
6397        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6398        assert_eq!(result[0], Complex64::new(2.0, 0.0));
6399
6400        // Test isolation
6401        evaluator.isolate_strict("const1").unwrap();
6402        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6403        assert_eq!(result[0], Complex64::new(1.0, 0.0));
6404
6405        // Test reactivation
6406        evaluator.activate_all();
6407        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6408        assert_eq!(result[0], Complex64::new(3.0, 0.0));
6409    }
6410
6411    #[test]
6412    fn test_gradient() {
6413        let expr1 = ComplexScalar::new(
6414            "parametric_1",
6415            parameter!("test_param_re_1"),
6416            parameter!("test_param_im_1"),
6417        )
6418        .unwrap();
6419        let expr2 = ComplexScalar::new(
6420            "parametric_2",
6421            parameter!("test_param_re_2"),
6422            parameter!("test_param_im_2"),
6423        )
6424        .unwrap();
6425
6426        let dataset = Arc::new(test_dataset());
6427        let params = vec![2.0, 3.0, 4.0, 5.0];
6428
6429        let expr = &expr1 + &expr2;
6430        let evaluator = expr.load(&dataset).unwrap();
6431
6432        let gradient = evaluator
6433            .evaluate_gradient(&params)
6434            .expect("evaluation should succeed");
6435
6436        assert_relative_eq!(gradient[0][0].re, 1.0);
6437        assert_relative_eq!(gradient[0][0].im, 0.0);
6438        assert_relative_eq!(gradient[0][1].re, 0.0);
6439        assert_relative_eq!(gradient[0][1].im, 1.0);
6440        assert_relative_eq!(gradient[0][2].re, 1.0);
6441        assert_relative_eq!(gradient[0][2].im, 0.0);
6442        assert_relative_eq!(gradient[0][3].re, 0.0);
6443        assert_relative_eq!(gradient[0][3].im, 1.0);
6444
6445        let expr = &expr1 - &expr2;
6446        let evaluator = expr.load(&dataset).unwrap();
6447
6448        let gradient = evaluator
6449            .evaluate_gradient(&params)
6450            .expect("evaluation should succeed");
6451
6452        assert_relative_eq!(gradient[0][0].re, 1.0);
6453        assert_relative_eq!(gradient[0][0].im, 0.0);
6454        assert_relative_eq!(gradient[0][1].re, 0.0);
6455        assert_relative_eq!(gradient[0][1].im, 1.0);
6456        assert_relative_eq!(gradient[0][2].re, -1.0);
6457        assert_relative_eq!(gradient[0][2].im, 0.0);
6458        assert_relative_eq!(gradient[0][3].re, 0.0);
6459        assert_relative_eq!(gradient[0][3].im, -1.0);
6460
6461        let expr = &expr1 * &expr2;
6462        let evaluator = expr.load(&dataset).unwrap();
6463
6464        let gradient = evaluator
6465            .evaluate_gradient(&params)
6466            .expect("evaluation should succeed");
6467
6468        assert_relative_eq!(gradient[0][0].re, 4.0);
6469        assert_relative_eq!(gradient[0][0].im, 5.0);
6470        assert_relative_eq!(gradient[0][1].re, -5.0);
6471        assert_relative_eq!(gradient[0][1].im, 4.0);
6472        assert_relative_eq!(gradient[0][2].re, 2.0);
6473        assert_relative_eq!(gradient[0][2].im, 3.0);
6474        assert_relative_eq!(gradient[0][3].re, -3.0);
6475        assert_relative_eq!(gradient[0][3].im, 2.0);
6476
6477        let expr = &expr1 / &expr2;
6478        let evaluator = expr.load(&dataset).unwrap();
6479
6480        let gradient = evaluator
6481            .evaluate_gradient(&params)
6482            .expect("evaluation should succeed");
6483
6484        assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
6485        assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
6486        assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
6487        assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
6488        assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
6489        assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
6490        assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
6491        assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
6492
6493        let expr = -(&expr1 * &expr2);
6494        let evaluator = expr.load(&dataset).unwrap();
6495
6496        let gradient = evaluator
6497            .evaluate_gradient(&params)
6498            .expect("evaluation should succeed");
6499
6500        assert_relative_eq!(gradient[0][0].re, -4.0);
6501        assert_relative_eq!(gradient[0][0].im, -5.0);
6502        assert_relative_eq!(gradient[0][1].re, 5.0);
6503        assert_relative_eq!(gradient[0][1].im, -4.0);
6504        assert_relative_eq!(gradient[0][2].re, -2.0);
6505        assert_relative_eq!(gradient[0][2].im, -3.0);
6506        assert_relative_eq!(gradient[0][3].re, 3.0);
6507        assert_relative_eq!(gradient[0][3].im, -2.0);
6508
6509        let expr = (&expr1 * &expr2).real();
6510        let evaluator = expr.load(&dataset).unwrap();
6511
6512        let gradient = evaluator
6513            .evaluate_gradient(&params)
6514            .expect("evaluation should succeed");
6515
6516        assert_relative_eq!(gradient[0][0].re, 4.0);
6517        assert_relative_eq!(gradient[0][0].im, 0.0);
6518        assert_relative_eq!(gradient[0][1].re, -5.0);
6519        assert_relative_eq!(gradient[0][1].im, 0.0);
6520        assert_relative_eq!(gradient[0][2].re, 2.0);
6521        assert_relative_eq!(gradient[0][2].im, 0.0);
6522        assert_relative_eq!(gradient[0][3].re, -3.0);
6523        assert_relative_eq!(gradient[0][3].im, 0.0);
6524
6525        let expr = (&expr1 * &expr2).imag();
6526        let evaluator = expr.load(&dataset).unwrap();
6527
6528        let gradient = evaluator
6529            .evaluate_gradient(&params)
6530            .expect("evaluation should succeed");
6531
6532        assert_relative_eq!(gradient[0][0].re, 5.0);
6533        assert_relative_eq!(gradient[0][0].im, 0.0);
6534        assert_relative_eq!(gradient[0][1].re, 4.0);
6535        assert_relative_eq!(gradient[0][1].im, 0.0);
6536        assert_relative_eq!(gradient[0][2].re, 3.0);
6537        assert_relative_eq!(gradient[0][2].im, 0.0);
6538        assert_relative_eq!(gradient[0][3].re, 2.0);
6539        assert_relative_eq!(gradient[0][3].im, 0.0);
6540
6541        let expr = (&expr1 * &expr2).conj();
6542        let evaluator = expr.load(&dataset).unwrap();
6543
6544        let gradient = evaluator
6545            .evaluate_gradient(&params)
6546            .expect("evaluation should succeed");
6547
6548        assert_relative_eq!(gradient[0][0].re, 4.0);
6549        assert_relative_eq!(gradient[0][0].im, -5.0);
6550        assert_relative_eq!(gradient[0][1].re, -5.0);
6551        assert_relative_eq!(gradient[0][1].im, -4.0);
6552        assert_relative_eq!(gradient[0][2].re, 2.0);
6553        assert_relative_eq!(gradient[0][2].im, -3.0);
6554        assert_relative_eq!(gradient[0][3].re, -3.0);
6555        assert_relative_eq!(gradient[0][3].im, -2.0);
6556
6557        let expr = (&expr1 * &expr2).norm_sqr();
6558        let evaluator = expr.load(&dataset).unwrap();
6559
6560        let gradient = evaluator
6561            .evaluate_gradient(&params)
6562            .expect("evaluation should succeed");
6563
6564        assert_relative_eq!(gradient[0][0].re, 164.0);
6565        assert_relative_eq!(gradient[0][0].im, 0.0);
6566        assert_relative_eq!(gradient[0][1].re, 246.0);
6567        assert_relative_eq!(gradient[0][1].im, 0.0);
6568        assert_relative_eq!(gradient[0][2].re, 104.0);
6569        assert_relative_eq!(gradient[0][2].im, 0.0);
6570        assert_relative_eq!(gradient[0][3].re, 130.0);
6571        assert_relative_eq!(gradient[0][3].im, 0.0);
6572    }
6573
6574    #[test]
6575    fn test_expression_function_gradients() {
6576        let expr1 = ComplexScalar::new(
6577            "function_parametric_1",
6578            parameter!("function_test_param_re_1"),
6579            parameter!("function_test_param_im_1"),
6580        )
6581        .unwrap();
6582        let expr2 = ComplexScalar::new(
6583            "function_parametric_2",
6584            parameter!("function_test_param_re_2"),
6585            parameter!("function_test_param_im_2"),
6586        )
6587        .unwrap();
6588
6589        let sin = expr1.sin();
6590        let cos = expr1.cos();
6591        let trig = &sin * &cos;
6592        let pow = expr1.pow(&expr2);
6593        let mut expr = expr1.sqrt();
6594        expr = &expr + &expr1.exp();
6595        expr = &expr + &expr1.powi(2);
6596        expr = &expr + &expr1.powf(1.7);
6597        expr = &expr + &trig;
6598        expr = &expr + &expr1.log();
6599        expr = &expr + &expr1.cis();
6600        expr = &expr + &pow;
6601
6602        let dataset = Arc::new(test_dataset());
6603        let evaluator = expr.load(&dataset).unwrap();
6604        let params = vec![2.0, 0.5, 1.2, -0.3];
6605        let gradient = evaluator
6606            .evaluate_gradient(&params)
6607            .expect("evaluation should succeed");
6608        let eps = 1e-6;
6609
6610        for param_index in 0..params.len() {
6611            let mut plus = params.clone();
6612            plus[param_index] += eps;
6613            let mut minus = params.clone();
6614            minus[param_index] -= eps;
6615            let finite_diff = (evaluator
6616                .evaluate(&plus)
6617                .expect("evaluation should succeed")[0]
6618                - evaluator
6619                    .evaluate(&minus)
6620                    .expect("evaluation should succeed")[0])
6621                / Complex64::new(2.0 * eps, 0.0);
6622
6623            assert_relative_eq!(
6624                gradient[0][param_index].re,
6625                finite_diff.re,
6626                epsilon = 1e-6,
6627                max_relative = 1e-6
6628            );
6629            assert_relative_eq!(
6630                gradient[0][param_index].im,
6631                finite_diff.im,
6632                epsilon = 1e-6,
6633                max_relative = 1e-6
6634            );
6635        }
6636    }
6637
6638    #[test]
6639    fn test_zeros_and_ones() {
6640        let amp = ComplexScalar::new(
6641            "parametric",
6642            parameter!("test_param_re"),
6643            parameter!("fixed_two", 2.0),
6644        )
6645        .unwrap();
6646        let dataset = Arc::new(test_dataset());
6647        let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
6648        let evaluator = expr.load(&dataset).unwrap();
6649
6650        let params = vec![2.0];
6651        let value = evaluator
6652            .evaluate(&params)
6653            .expect("evaluation should succeed");
6654        let gradient = evaluator
6655            .evaluate_gradient(&params)
6656            .expect("evaluation should succeed");
6657
6658        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the value should be x^2 + 4
6659        assert_relative_eq!(value[0].re, 8.0);
6660        assert_relative_eq!(value[0].im, 0.0);
6661
6662        // For |f(x) * 1 + 0|^2 where f(x) = x+2i, the derivative should be 2x
6663        assert_relative_eq!(gradient[0][0].re, 4.0);
6664        assert_relative_eq!(gradient[0][0].im, 0.0);
6665    }
6666    #[test]
6667    fn test_default_build_uses_lowered_expression_runtime() {
6668        let expr = ComplexScalar::new(
6669            "opt_in_gate",
6670            parameter!("opt_in_gate_re", 2.0),
6671            parameter!("opt_in_gate_im", 0.0),
6672        )
6673        .unwrap()
6674        .norm_sqr();
6675        let dataset = Arc::new(test_dataset());
6676        let evaluator = expr.load(&dataset).unwrap();
6677
6678        let diagnostics = evaluator.expression_runtime_diagnostics();
6679        assert!(diagnostics.ir_planning_enabled);
6680        assert!(diagnostics.lowered_value_program_present);
6681        assert!(diagnostics.lowered_gradient_program_present);
6682        assert!(diagnostics.lowered_value_gradient_program_present);
6683        assert_eq!(
6684            evaluator.evaluate(&[]).expect("evaluation should succeed")[0],
6685            Complex64::new(4.0, 0.0)
6686        );
6687    }
6688
6689    #[test]
6690    fn parameter_name_only_creates_free_parameter() {
6691        let p = parameter!("mass");
6692
6693        assert_eq!(p.name(), "mass");
6694        assert_eq!(p.fixed(), None);
6695        assert_eq!(p.initial(), None);
6696        assert_eq!(p.bounds(), (None, None));
6697        assert_eq!(p.unit(), None);
6698        assert_eq!(p.latex(), None);
6699        assert_eq!(p.description(), None);
6700        assert!(p.is_free());
6701        assert!(!p.is_fixed());
6702    }
6703
6704    #[test]
6705    fn parameter_name_and_value_creates_fixed_parameter() {
6706        let p = parameter!("width", 0.15);
6707
6708        assert_eq!(p.name(), "width");
6709        assert_eq!(p.fixed(), Some(0.15));
6710        assert_eq!(p.initial(), Some(0.15));
6711        assert!(p.is_fixed());
6712        assert!(!p.is_free());
6713    }
6714
6715    #[test]
6716    fn keyword_initial_sets_initial_only() {
6717        let p = parameter!("alpha", initial: 1.25);
6718
6719        assert_eq!(p.name(), "alpha");
6720        assert_eq!(p.fixed(), None);
6721        assert_eq!(p.initial(), Some(1.25));
6722        assert_eq!(p.bounds(), (None, None));
6723        assert!(p.is_free());
6724    }
6725
6726    #[test]
6727    fn keyword_fixed_sets_fixed_and_initial() {
6728        let p = parameter!("beta", fixed: 2.5);
6729
6730        assert_eq!(p.name(), "beta");
6731        assert_eq!(p.fixed(), Some(2.5));
6732        assert_eq!(p.initial(), Some(2.5));
6733        assert!(p.is_fixed());
6734    }
6735
6736    #[test]
6737    fn bounds_accept_plain_numbers() {
6738        let p = parameter!("x", bounds: (0.0, 10.0));
6739
6740        assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
6741    }
6742
6743    #[test]
6744    fn bounds_accept_none_and_number() {
6745        let p = parameter!("x", bounds: (None, 10.0));
6746
6747        assert_eq!(p.bounds(), (None, Some(10.0)));
6748    }
6749
6750    #[test]
6751    fn bounds_accept_number_and_none() {
6752        let p = parameter!("x", bounds: (-1.0, None));
6753
6754        assert_eq!(p.bounds(), (Some(-1.0), None));
6755    }
6756
6757    #[test]
6758    fn bounds_accept_both_none() {
6759        let p = parameter!("x", bounds: (None, None));
6760
6761        assert_eq!(p.bounds(), (None, None));
6762    }
6763
6764    #[test]
6765    fn bounds_accept_arbitrary_expressions() {
6766        let lo = 1.0;
6767        let hi = 2.0 * 3.0;
6768        let p = parameter!("x", bounds: (lo - 0.5, hi));
6769
6770        assert_eq!(p.bounds(), (Some(0.5), Some(6.0)));
6771    }
6772
6773    #[test]
6774    fn multiple_keyword_arguments_work_together() {
6775        let p = parameter!(
6776            "gamma",
6777            initial: 1.0,
6778            bounds: (0.0, 5.0),
6779            unit: "GeV",
6780            latex: r"\gamma",
6781            description: "test parameter",
6782        );
6783
6784        assert_eq!(p.name(), "gamma");
6785        assert_eq!(p.fixed(), None);
6786        assert_eq!(p.initial(), Some(1.0));
6787        assert_eq!(p.bounds(), (Some(0.0), Some(5.0)));
6788        assert_eq!(p.unit().as_deref(), Some("GeV"));
6789        assert_eq!(p.latex().as_deref(), Some(r"\gamma"));
6790        assert_eq!(p.description().as_deref(), Some("test parameter"));
6791    }
6792
6793    #[test]
6794    fn fixed_can_be_combined_with_other_fields() {
6795        let p = parameter!(
6796            "delta",
6797            fixed: 3.0,
6798            bounds: (0.0, 10.0),
6799            unit: "rad",
6800        );
6801
6802        assert_eq!(p.name(), "delta");
6803        assert_eq!(p.fixed(), Some(3.0));
6804        assert_eq!(p.initial(), Some(3.0));
6805        assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
6806        assert_eq!(p.unit().as_deref(), Some("rad"));
6807    }
6808
6809    #[test]
6810    fn trailing_comma_is_accepted() {
6811        let p = parameter!(
6812            "eps",
6813            initial: 0.5,
6814            bounds: (None, 1.0),
6815            unit: "arb",
6816        );
6817
6818        assert_eq!(p.initial(), Some(0.5));
6819        assert_eq!(p.bounds(), (None, Some(1.0)));
6820        assert_eq!(p.unit().as_deref(), Some("arb"));
6821    }
6822
6823    #[test]
6824    fn test_parameter_registration() {
6825        let expr = ComplexScalar::new(
6826            "parametric",
6827            parameter!("test_param_re"),
6828            parameter!("fixed_two", 2.0),
6829        )
6830        .unwrap();
6831        let parameters = expr.parameters().free().names();
6832        assert_eq!(parameters.len(), 1);
6833        assert_eq!(parameters[0], "test_param_re");
6834    }
6835
6836    #[test]
6837    fn test_duplicate_amplitude_tag_registration_is_allowed() {
6838        let amp1 = ComplexScalar::new(
6839            "same_name",
6840            parameter!("dup_re1", 1.0),
6841            parameter!("dup_im1", 0.0),
6842        )
6843        .unwrap();
6844        let amp2 = ComplexScalar::new(
6845            "same_name",
6846            parameter!("dup_re2", 2.0),
6847            parameter!("dup_im2", 0.0),
6848        )
6849        .unwrap();
6850        let expr = amp1 + amp2;
6851        assert_eq!(
6852            expr.parameters().fixed().names(),
6853            vec!["dup_re1", "dup_im1", "dup_re2", "dup_im2"]
6854        );
6855    }
6856
6857    #[test]
6858    fn test_tree_printing() {
6859        let amp1 = ComplexScalar::new(
6860            "parametric_1",
6861            parameter!("test_param_re_1"),
6862            parameter!("test_param_im_1"),
6863        )
6864        .unwrap();
6865        let amp2 = ComplexScalar::new(
6866            "parametric_2",
6867            parameter!("test_param_re_2"),
6868            parameter!("test_param_im_2"),
6869        )
6870        .unwrap();
6871        let expr =
6872            &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
6873                - Expression::zero() / 1.0
6874                + (&amp1 * &amp2).norm_sqr();
6875        assert_eq!(
6876            expr.to_string(),
6877            concat!(
6878                "+\n",
6879                "├─ -\n",
6880                "│  ├─ +\n",
6881                "│  │  ├─ +\n",
6882                "│  │  │  ├─ Re\n",
6883                "│  │  │  │  └─ parametric_1(id=0)\n",
6884                "│  │  │  └─ Im\n",
6885                "│  │  │     └─ *\n",
6886                "│  │  │        └─ parametric_2(id=1)\n",
6887                "│  │  └─ ×\n",
6888                "│  │     ├─ 1 (exact)\n",
6889                "│  │     └─ -1.4+2i\n",
6890                "│  └─ ÷\n",
6891                "│     ├─ 0 (exact)\n",
6892                "│     └─ 1 (exact)\n",
6893                "└─ NormSqr\n",
6894                "   └─ ×\n",
6895                "      ├─ parametric_1(id=0)\n",
6896                "      └─ parametric_2(id=1)\n",
6897            )
6898        );
6899    }
6900}