Skip to main content

laddu_core/expression/
evaluator.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    sync::{
5        atomic::{AtomicU64, Ordering},
6        Arc,
7    },
8    time::Instant,
9};
10
11use auto_ops::*;
12use nalgebra::DVector;
13use num::complex::Complex64;
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
19use super::{ir, lowered};
20
/// Process-wide counter consumed by [`next_compute_amplitude_id`]; starts at zero.
static COMPUTE_AMPLITUDE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Process-wide counter consumed by [`next_amplitude_use_site_id`]; starts at zero.
static AMPLITUDE_USE_SITE_COUNTER: AtomicU64 = AtomicU64::new(0);
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
25struct ComputeAmplitudeId(u64);
26
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
28struct AmplitudeUseSiteId(u64);
29
30fn next_compute_amplitude_id() -> ComputeAmplitudeId {
31    ComputeAmplitudeId(COMPUTE_AMPLITUDE_COUNTER.fetch_add(1, Ordering::Relaxed))
32}
33
34fn next_amplitude_use_site_id() -> AmplitudeUseSiteId {
35    AmplitudeUseSiteId(AMPLITUDE_USE_SITE_COUNTER.fetch_add(1, Ordering::Relaxed))
36}
/// Dependence classification used by expression-IR diagnostics.
///
/// Mirrors `ir::DependenceClass` variant-for-variant; conversions exist in both directions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExpressionDependence {
    /// The value is determined solely by fixed/free parameter values.
    ParameterOnly,
    /// The value is determined solely by event-local cached values.
    CacheOnly,
    /// The value needs both parameter values and cached event values.
    Mixed,
}
47impl From<ir::DependenceClass> for ExpressionDependence {
48    fn from(value: ir::DependenceClass) -> Self {
49        match value {
50            ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
51            ir::DependenceClass::CacheOnly => Self::CacheOnly,
52            ir::DependenceClass::Mixed => Self::Mixed,
53        }
54    }
55}
56#[derive(Clone, Debug, PartialEq, Eq)]
57/// Explain/debug view of the IR normalization planning decomposition.
58pub struct NormalizationPlanExplain {
59    /// Dependence classification at the expression root.
60    pub root_dependence: ExpressionDependence,
61    /// Warning-level diagnostics collected during planning.
62    pub warnings: Vec<String>,
63    /// Candidate multiply node indices identified as separable.
64    pub separable_mul_candidate_nodes: Vec<usize>,
65    /// Candidate separable node indices selected for caching.
66    pub cached_separable_nodes: Vec<usize>,
67    /// Node indices planned for residual per-event evaluation.
68    pub residual_terms: Vec<usize>,
69}
/// Explain/debug view of amplitude execution sets used by normalization evaluation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NormalizationExecutionSetsExplain {
    /// Amplitudes needed to evaluate the parameter factors of cached separable terms.
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitudes needed to evaluate the cache factors of cached separable terms.
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitudes needed for the residual (non-cached) part of normalization evaluation.
    pub residual_amplitudes: Vec<usize>,
}
80#[derive(Clone, Debug, PartialEq)]
81/// Load-time precomputed integral metadata for a separable cached term.
82pub struct PrecomputedCachedIntegral {
83    /// Node index of the separable multiplication term.
84    pub mul_node_index: usize,
85    /// Node index of the parameter-dependent factor.
86    pub parameter_node_index: usize,
87    /// Node index of the cache-dependent factor.
88    pub cache_node_index: usize,
89    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
90    pub coefficient: i32,
91    /// Weighted sum over local events of the cache-dependent factor.
92    pub weighted_cache_sum: Complex64,
93}
94#[derive(Clone, Debug, PartialEq)]
95/// Parameter-gradient contribution for a load-time precomputed cached integral term.
96pub struct PrecomputedCachedIntegralGradientTerm {
97    /// Node index of the separable multiplication term.
98    pub mul_node_index: usize,
99    /// Node index of the parameter-dependent factor.
100    pub parameter_node_index: usize,
101    /// Node index of the cache-dependent factor.
102    pub cache_node_index: usize,
103    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
104    pub coefficient: i32,
105    /// Gradient contribution `(d/dp parameter_factor) * weighted_cache_sum`.
106    pub weighted_gradient: DVector<Complex64>,
107}
/// Hashable key stored in [`CachedIntegralCacheState`] alongside the cached values.
///
/// NOTE(review): the per-field notes below are inferred from field names and types only;
/// confirm against the code that constructs and compares this key.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct CachedIntegralCacheKey {
    // Presumably one flag per amplitude marking whether it is active.
    active_mask: Vec<bool>,
    // Presumably the number of events held locally.
    n_events_local: usize,
    // Presumably the length of the local weights buffer.
    weights_local_len: usize,
    // Presumably the bit pattern of an `f64` weighted sum, stored as `u64` so the key
    // can derive `Eq`/`Hash`.
    weighted_sum_bits: u64,
    // Presumably the address of the weights buffer, stored as an integer.
    weights_ptr: usize,
}
116#[derive(Clone, Debug)]
117struct CachedIntegralCacheState {
118    key: CachedIntegralCacheKey,
119    expression_ir: ir::ExpressionIR,
120    values: Vec<PrecomputedCachedIntegral>,
121    execution_sets: ir::NormalizationExecutionSets,
122}
123#[derive(Clone, Debug)]
124struct LoweredArtifactCacheState {
125    parameter_node_indices: Vec<usize>,
126    mul_node_indices: Vec<usize>,
127    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
128    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
129    lowered_runtime: lowered::LoweredExpressionRuntime,
130}
131#[derive(Clone)]
132struct ExpressionSpecializationState {
133    cached_integrals: Arc<CachedIntegralCacheState>,
134    lowered_artifacts: Arc<LoweredArtifactCacheState>,
135}
/// Debug/benchmark counters for active-mask specialization reuse under `expression-ir`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ExpressionSpecializationMetrics {
    /// Count of specialization cache hits that were served without recompilation.
    pub cache_hits: usize,
    /// Count of specialization cache misses that needed a fresh compile/lower pass.
    pub cache_misses: usize,
}
/// Staged compile/lowering metrics for expression-IR construction and specialization refreshes.
///
/// Fields prefixed `initial_` describe the initial load; fields prefixed
/// `specialization_` accumulate over later active-mask specialization work.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ExpressionCompileMetrics {
    /// Initial load: nanoseconds spent compiling the semantic expression tree into IR.
    pub initial_ir_compile_nanos: u64,
    /// Initial load: nanoseconds spent precomputing cached-integral planning artifacts.
    pub initial_cached_integrals_nanos: u64,
    /// Initial load: nanoseconds spent lowering IR-derived runtimes.
    pub initial_lowering_nanos: u64,
    /// Count of specialization cache hits restored without recompilation.
    pub specialization_cache_hits: usize,
    /// Count of specialization cache misses that required recompilation.
    pub specialization_cache_misses: usize,
    /// Total nanoseconds spent compiling active-mask-specialized IR after initial load.
    pub specialization_ir_compile_nanos: u64,
    /// Total nanoseconds spent recomputing cached-integral planning artifacts after load.
    pub specialization_cached_integrals_nanos: u64,
    /// Total nanoseconds spent lowering specialized runtimes after load.
    pub specialization_lowering_nanos: u64,
    /// Count of specialization rebuilds that reused cached lowered artifacts.
    pub specialization_lowering_cache_hits: usize,
    /// Count of specialization rebuilds that had to lower fresh artifacts.
    pub specialization_lowering_cache_misses: usize,
    /// Total nanoseconds spent restoring specializations from cache.
    pub specialization_cache_restore_nanos: u64,
}
170impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
171    fn from(value: ir::NormalizationPlanExplain) -> Self {
172        Self {
173            root_dependence: value.root_dependence.into(),
174            warnings: value.warnings,
175            separable_mul_candidate_nodes: value
176                .separable_mul_candidates
177                .into_iter()
178                .map(|candidate| candidate.node_index)
179                .collect(),
180            cached_separable_nodes: value.cached_separable_nodes,
181            residual_terms: value.residual_terms,
182        }
183    }
184}
185impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
186    fn from(value: ir::NormalizationExecutionSets) -> Self {
187        Self {
188            cached_parameter_amplitudes: value.cached_parameter_amplitudes,
189            cached_cache_amplitudes: value.cached_cache_amplitudes,
190            residual_amplitudes: value.residual_amplitudes,
191        }
192    }
193}
194impl From<ExpressionDependence> for ir::DependenceClass {
195    fn from(value: ExpressionDependence) -> Self {
196        match value {
197            ExpressionDependence::ParameterOnly => Self::ParameterOnly,
198            ExpressionDependence::CacheOnly => Self::CacheOnly,
199            ExpressionDependence::Mixed => Self::Mixed,
200        }
201    }
202}
203
204#[cfg(feature = "mpi")]
205use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
206
207#[cfg(feature = "mpi")]
208use crate::mpi::LadduMPI;
209#[cfg(feature = "execution-context-prototype")]
210use crate::ExecutionContext;
211#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
212use crate::ThreadPolicy;
213use crate::{
214    amplitude::{Amplitude, AmplitudeUseSite},
215    data::Dataset,
216    parameters::ParameterMap,
217    resources::{Cache, Parameters, Resources},
218    LadduError, LadduResult,
219};
220
221/// A holder struct that owns both an expression tree and the registered amplitudes.
222#[allow(missing_docs)]
223#[derive(Clone, Serialize, Deserialize, Default)]
224pub struct Expression {
225    registry: ExpressionRegistry,
226    tree: ExpressionNode,
227}
228
229#[derive(Clone, Serialize, Deserialize)]
230#[allow(missing_docs)]
231#[derive(Default)]
232pub struct ExpressionRegistry {
233    amplitudes: Vec<Box<dyn Amplitude>>,
234    amplitude_names: Vec<String>,
235    amplitude_ids: Vec<ComputeAmplitudeId>,
236    amplitude_use_sites: Vec<AmplitudeUseSite>,
237    amplitude_use_site_ids: Vec<AmplitudeUseSiteId>,
238    resources: Resources,
239}
240
241impl ExpressionRegistry {
242    fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
243        let mut resources = Resources::default();
244        let aid = amplitude.register(&mut resources)?;
245        let compute_id = next_compute_amplitude_id();
246        let use_site_id = next_amplitude_use_site_id();
247        resources.configure_amplitude_tags(std::slice::from_ref(&aid.0));
248        Ok(Self {
249            amplitudes: vec![amplitude],
250            amplitude_names: vec![aid.0.display_label()],
251            amplitude_ids: vec![compute_id],
252            amplitude_use_sites: vec![AmplitudeUseSite {
253                amplitude_index: 0,
254                tags: aid.0,
255            }],
256            amplitude_use_site_ids: vec![use_site_id],
257            resources,
258        })
259    }
260
261    fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
262        let mut resources = Resources::default();
263        let mut amplitudes = Vec::new();
264        let mut amplitude_ids = Vec::new();
265        let mut compute_id_to_index = HashMap::new();
266        let mut semantic_key_to_index = HashMap::new();
267
268        let mut left_compute_map = Vec::with_capacity(self.amplitudes.len());
269        for (amp_index, (amp, amp_id)) in
270            self.amplitudes.iter().zip(&self.amplitude_ids).enumerate()
271        {
272            let semantic_key = amp.semantic_key();
273            let mut cloned_amp = dyn_clone::clone_box(&**amp);
274            let aid = cloned_amp.register(&mut resources)?;
275            if let Some(key) = semantic_key.clone() {
276                semantic_key_to_index.insert(key, aid.1);
277            }
278            amplitudes.push(cloned_amp);
279            amplitude_ids.push(*amp_id);
280            compute_id_to_index.insert(*amp_id, aid.1);
281            left_compute_map.push(aid.1);
282            debug_assert_eq!(amp_index, aid.1);
283        }
284
285        let mut right_compute_map = Vec::with_capacity(other.amplitudes.len());
286        for (amp, amp_id) in other.amplitudes.iter().zip(&other.amplitude_ids) {
287            if let Some(existing) = compute_id_to_index.get(amp_id) {
288                right_compute_map.push(*existing);
289                continue;
290            }
291            let incoming_semantic_key = amp.semantic_key();
292            if let Some(existing) = incoming_semantic_key
293                .as_ref()
294                .and_then(|key| semantic_key_to_index.get(key))
295            {
296                right_compute_map.push(*existing);
297                continue;
298            }
299            let mut cloned_amp = dyn_clone::clone_box(&**amp);
300            let aid = cloned_amp.register(&mut resources)?;
301            if let Some(key) = incoming_semantic_key.clone() {
302                semantic_key_to_index.insert(key, aid.1);
303            }
304            amplitudes.push(cloned_amp);
305            amplitude_ids.push(*amp_id);
306            compute_id_to_index.insert(*amp_id, aid.1);
307            right_compute_map.push(aid.1);
308        }
309
310        let mut amplitude_use_sites = Vec::new();
311        let mut amplitude_use_site_ids = Vec::new();
312        let mut amplitude_names = Vec::new();
313        let mut use_site_id_to_index = HashMap::new();
314
315        let mut left_map = Vec::with_capacity(self.amplitude_use_sites.len());
316        for (use_site, use_site_id) in self
317            .amplitude_use_sites
318            .iter()
319            .zip(&self.amplitude_use_site_ids)
320        {
321            let mapped_index = left_compute_map[use_site.amplitude_index];
322            let new_index = amplitude_use_sites.len();
323            left_map.push(new_index);
324            use_site_id_to_index.insert(*use_site_id, new_index);
325            amplitude_use_site_ids.push(*use_site_id);
326            amplitude_names.push(use_site.tags.display_label());
327            amplitude_use_sites.push(AmplitudeUseSite {
328                amplitude_index: mapped_index,
329                tags: use_site.tags.clone(),
330            });
331        }
332
333        let mut right_map = Vec::with_capacity(other.amplitude_use_sites.len());
334        for (use_site, use_site_id) in other
335            .amplitude_use_sites
336            .iter()
337            .zip(&other.amplitude_use_site_ids)
338        {
339            if let Some(existing) = use_site_id_to_index.get(use_site_id) {
340                right_map.push(*existing);
341                continue;
342            }
343            let mapped_index = right_compute_map[use_site.amplitude_index];
344            let new_index = amplitude_use_sites.len();
345            right_map.push(new_index);
346            use_site_id_to_index.insert(*use_site_id, new_index);
347            amplitude_use_site_ids.push(*use_site_id);
348            amplitude_names.push(use_site.tags.display_label());
349            amplitude_use_sites.push(AmplitudeUseSite {
350                amplitude_index: mapped_index,
351                tags: use_site.tags.clone(),
352            });
353        }
354        let use_site_tags = amplitude_use_sites
355            .iter()
356            .map(|use_site| use_site.tags.clone())
357            .collect::<Vec<_>>();
358        resources.configure_amplitude_tags(&use_site_tags);
359
360        Ok((
361            Self {
362                amplitudes,
363                amplitude_names,
364                amplitude_ids,
365                amplitude_use_sites,
366                amplitude_use_site_ids,
367                resources,
368            },
369            left_map,
370            right_map,
371        ))
372    }
373}
374
375/// Expression tree used by [`Expression`].
376#[allow(missing_docs)]
377#[derive(Clone, Serialize, Deserialize, Default, Debug)]
378pub enum ExpressionNode {
379    #[default]
380    /// A expression equal to zero.
381    Zero,
382    /// A expression equal to one.
383    One,
384    /// A real-valued constant.
385    Constant(f64),
386    /// A complex-valued constant.
387    ComplexConstant(Complex64),
388    /// A registered [`Amplitude`] referenced by index.
389    Amp(usize),
390    /// The sum of two [`ExpressionNode`]s.
391    Add(Box<ExpressionNode>, Box<ExpressionNode>),
392    /// The difference of two [`ExpressionNode`]s.
393    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
394    /// The product of two [`ExpressionNode`]s.
395    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
396    /// The division of two [`ExpressionNode`]s.
397    Div(Box<ExpressionNode>, Box<ExpressionNode>),
398    /// The additive inverse of an [`ExpressionNode`].
399    Neg(Box<ExpressionNode>),
400    /// The real part of an [`ExpressionNode`].
401    Real(Box<ExpressionNode>),
402    /// The imaginary part of an [`ExpressionNode`].
403    Imag(Box<ExpressionNode>),
404    /// The complex conjugate of an [`ExpressionNode`].
405    Conj(Box<ExpressionNode>),
406    /// The absolute square of an [`ExpressionNode`].
407    NormSqr(Box<ExpressionNode>),
408    Sqrt(Box<ExpressionNode>),
409    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
410    PowI(Box<ExpressionNode>, i32),
411    PowF(Box<ExpressionNode>, f64),
412    Exp(Box<ExpressionNode>),
413    Sin(Box<ExpressionNode>),
414    Cos(Box<ExpressionNode>),
415    Log(Box<ExpressionNode>),
416    Cis(Box<ExpressionNode>),
417}
418
419#[derive(Clone, Debug)]
420/// Standalone bytecode executor compiled directly from the semantic expression tree.
421///
422/// This is retained for direct tree-level helpers on [`ExpressionNode`] and debugging of the
423/// unfactored semantic expression shape. It is intentionally distinct from the lowered runtime:
424/// current lowering carries slot reuse, peephole rewrites, root-specific lowering, and
425/// specialized normalization helpers that would be awkward to force back into this form.
426struct ExpressionProgram {
427    ops: Vec<ExpressionOp>,
428    slot_count: usize,
429    root_slot: usize,
430}
431
432#[derive(Clone, Debug)]
433enum ExpressionOp {
434    LoadZero {
435        dst: usize,
436    },
437    LoadOne {
438        dst: usize,
439    },
440    LoadConstant {
441        dst: usize,
442        value: f64,
443    },
444    LoadComplexConstant {
445        dst: usize,
446        value: Complex64,
447    },
448    LoadAmp {
449        dst: usize,
450        amp_idx: usize,
451    },
452    Add {
453        dst: usize,
454        left: usize,
455        right: usize,
456    },
457    Sub {
458        dst: usize,
459        left: usize,
460        right: usize,
461    },
462    Mul {
463        dst: usize,
464        left: usize,
465        right: usize,
466    },
467    Div {
468        dst: usize,
469        left: usize,
470        right: usize,
471    },
472    Neg {
473        dst: usize,
474        input: usize,
475    },
476    Real {
477        dst: usize,
478        input: usize,
479    },
480    Imag {
481        dst: usize,
482        input: usize,
483    },
484    Conj {
485        dst: usize,
486        input: usize,
487    },
488    NormSqr {
489        dst: usize,
490        input: usize,
491    },
492    Sqrt {
493        dst: usize,
494        input: usize,
495    },
496    Pow {
497        dst: usize,
498        value: usize,
499        power: usize,
500    },
501    PowI {
502        dst: usize,
503        input: usize,
504        power: i32,
505    },
506    PowF {
507        dst: usize,
508        input: usize,
509        power: f64,
510    },
511    Exp {
512        dst: usize,
513        input: usize,
514    },
515    Sin {
516        dst: usize,
517        input: usize,
518    },
519    Cos {
520        dst: usize,
521        input: usize,
522    },
523    Log {
524        dst: usize,
525        input: usize,
526    },
527    Cis {
528        dst: usize,
529        input: usize,
530    },
531}
532
533#[derive(Default)]
534struct ExpressionProgramBuilder {
535    ops: Vec<ExpressionOp>,
536    next_slot: usize,
537}
538
539impl ExpressionProgramBuilder {
540    fn alloc_slot(&mut self) -> usize {
541        let slot = self.next_slot;
542        self.next_slot += 1;
543        slot
544    }
545
546    fn build(self, root: usize) -> ExpressionProgram {
547        ExpressionProgram {
548            ops: self.ops,
549            slot_count: self.next_slot,
550            root_slot: root,
551        }
552    }
553
554    fn emit(&mut self, op: ExpressionOp) {
555        self.ops.push(op);
556    }
557
558    fn compile(&mut self, node: &ExpressionNode) -> usize {
559        match node {
560            ExpressionNode::Zero => {
561                let dst = self.alloc_slot();
562                self.emit(ExpressionOp::LoadZero { dst });
563                dst
564            }
565            ExpressionNode::One => {
566                let dst = self.alloc_slot();
567                self.emit(ExpressionOp::LoadOne { dst });
568                dst
569            }
570            ExpressionNode::Constant(value) => {
571                let dst = self.alloc_slot();
572                self.emit(ExpressionOp::LoadConstant { dst, value: *value });
573                dst
574            }
575            ExpressionNode::ComplexConstant(value) => {
576                let dst = self.alloc_slot();
577                self.emit(ExpressionOp::LoadComplexConstant { dst, value: *value });
578                dst
579            }
580            ExpressionNode::Amp(idx) => {
581                let dst = self.alloc_slot();
582                self.emit(ExpressionOp::LoadAmp { dst, amp_idx: *idx });
583                dst
584            }
585            ExpressionNode::Add(a, b) => {
586                let left = self.compile(a);
587                let right = self.compile(b);
588                let dst = self.alloc_slot();
589                self.emit(ExpressionOp::Add { dst, left, right });
590                dst
591            }
592            ExpressionNode::Sub(a, b) => {
593                let left = self.compile(a);
594                let right = self.compile(b);
595                let dst = self.alloc_slot();
596                self.emit(ExpressionOp::Sub { dst, left, right });
597                dst
598            }
599            ExpressionNode::Mul(a, b) => {
600                let left = self.compile(a);
601                let right = self.compile(b);
602                let dst = self.alloc_slot();
603                self.emit(ExpressionOp::Mul { dst, left, right });
604                dst
605            }
606            ExpressionNode::Div(a, b) => {
607                let left = self.compile(a);
608                let right = self.compile(b);
609                let dst = self.alloc_slot();
610                self.emit(ExpressionOp::Div { dst, left, right });
611                dst
612            }
613            ExpressionNode::Neg(a) => {
614                let input = self.compile(a);
615                let dst = self.alloc_slot();
616                self.emit(ExpressionOp::Neg { dst, input });
617                dst
618            }
619            ExpressionNode::Real(a) => {
620                let input = self.compile(a);
621                let dst = self.alloc_slot();
622                self.emit(ExpressionOp::Real { dst, input });
623                dst
624            }
625            ExpressionNode::Imag(a) => {
626                let input = self.compile(a);
627                let dst = self.alloc_slot();
628                self.emit(ExpressionOp::Imag { dst, input });
629                dst
630            }
631            ExpressionNode::Conj(a) => {
632                let input = self.compile(a);
633                let dst = self.alloc_slot();
634                self.emit(ExpressionOp::Conj { dst, input });
635                dst
636            }
637            ExpressionNode::NormSqr(a) => {
638                let input = self.compile(a);
639                let dst = self.alloc_slot();
640                self.emit(ExpressionOp::NormSqr { dst, input });
641                dst
642            }
643            ExpressionNode::Sqrt(a) => {
644                let input = self.compile(a);
645                let dst = self.alloc_slot();
646                self.emit(ExpressionOp::Sqrt { dst, input });
647                dst
648            }
649            ExpressionNode::Pow(a, b) => {
650                let value = self.compile(a);
651                let power = self.compile(b);
652                let dst = self.alloc_slot();
653                self.emit(ExpressionOp::Pow { dst, value, power });
654                dst
655            }
656            ExpressionNode::PowI(a, power) => {
657                let input = self.compile(a);
658                let dst = self.alloc_slot();
659                self.emit(ExpressionOp::PowI {
660                    dst,
661                    input,
662                    power: *power,
663                });
664                dst
665            }
666            ExpressionNode::PowF(a, power) => {
667                let input = self.compile(a);
668                let dst = self.alloc_slot();
669                self.emit(ExpressionOp::PowF {
670                    dst,
671                    input,
672                    power: *power,
673                });
674                dst
675            }
676            ExpressionNode::Exp(a) => {
677                let input = self.compile(a);
678                let dst = self.alloc_slot();
679                self.emit(ExpressionOp::Exp { dst, input });
680                dst
681            }
682            ExpressionNode::Sin(a) => {
683                let input = self.compile(a);
684                let dst = self.alloc_slot();
685                self.emit(ExpressionOp::Sin { dst, input });
686                dst
687            }
688            ExpressionNode::Cos(a) => {
689                let input = self.compile(a);
690                let dst = self.alloc_slot();
691                self.emit(ExpressionOp::Cos { dst, input });
692                dst
693            }
694            ExpressionNode::Log(a) => {
695                let input = self.compile(a);
696                let dst = self.alloc_slot();
697                self.emit(ExpressionOp::Log { dst, input });
698                dst
699            }
700            ExpressionNode::Cis(a) => {
701                let input = self.compile(a);
702                let dst = self.alloc_slot();
703                self.emit(ExpressionOp::Cis { dst, input });
704                dst
705            }
706        }
707    }
708}
709
710impl ExpressionProgram {
711    fn from_node(node: &ExpressionNode) -> Self {
712        let mut builder = ExpressionProgramBuilder::default();
713        let root = builder.compile(node);
714        builder.build(root)
715    }
716
717    fn fill_values(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) {
718        debug_assert!(slots.len() >= self.slot_count);
719        for op in &self.ops {
720            match *op {
721                ExpressionOp::LoadZero { dst } => slots[dst] = Complex64::ZERO,
722                ExpressionOp::LoadOne { dst } => slots[dst] = Complex64::ONE,
723                ExpressionOp::LoadConstant { dst, value } => slots[dst] = Complex64::from(value),
724                ExpressionOp::LoadComplexConstant { dst, value } => slots[dst] = value,
725                ExpressionOp::LoadAmp { dst, amp_idx } => {
726                    slots[dst] = amplitude_values.get(amp_idx).copied().unwrap_or_default();
727                }
728                ExpressionOp::Add { dst, left, right } => {
729                    slots[dst] = slots[left] + slots[right];
730                }
731                ExpressionOp::Sub { dst, left, right } => {
732                    slots[dst] = slots[left] - slots[right];
733                }
734                ExpressionOp::Mul { dst, left, right } => {
735                    slots[dst] = slots[left] * slots[right];
736                }
737                ExpressionOp::Div { dst, left, right } => {
738                    slots[dst] = slots[left] / slots[right];
739                }
740                ExpressionOp::Neg { dst, input } => {
741                    slots[dst] = -slots[input];
742                }
743                ExpressionOp::Real { dst, input } => {
744                    slots[dst] = Complex64::new(slots[input].re, 0.0);
745                }
746                ExpressionOp::Imag { dst, input } => {
747                    slots[dst] = Complex64::new(slots[input].im, 0.0);
748                }
749                ExpressionOp::Conj { dst, input } => {
750                    slots[dst] = slots[input].conj();
751                }
752                ExpressionOp::NormSqr { dst, input } => {
753                    slots[dst] = Complex64::new(slots[input].norm_sqr(), 0.0);
754                }
755                ExpressionOp::Sqrt { dst, input } => {
756                    slots[dst] = slots[input].sqrt();
757                }
758                ExpressionOp::Pow { dst, value, power } => {
759                    slots[dst] = slots[value].powc(slots[power]);
760                }
761                ExpressionOp::PowI { dst, input, power } => {
762                    slots[dst] = slots[input].powi(power);
763                }
764                ExpressionOp::PowF { dst, input, power } => {
765                    slots[dst] = slots[input].powc(Complex64::new(power, 0.0));
766                }
767                ExpressionOp::Exp { dst, input } => {
768                    slots[dst] = slots[input].exp();
769                }
770                ExpressionOp::Sin { dst, input } => {
771                    slots[dst] = slots[input].sin();
772                }
773                ExpressionOp::Cos { dst, input } => {
774                    slots[dst] = slots[input].cos();
775                }
776                ExpressionOp::Log { dst, input } => {
777                    slots[dst] = slots[input].ln();
778                }
779                ExpressionOp::Cis { dst, input } => {
780                    slots[dst] = (Complex64::new(0.0, 1.0) * slots[input]).exp();
781                }
782            }
783        }
784    }
785
786    fn evaluate_into(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) -> Complex64 {
787        if self.slot_count == 0 {
788            return Complex64::ZERO;
789        }
790        self.fill_values(amplitude_values, slots);
791        slots[self.root_slot]
792    }
793
794    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
795        if self.slot_count == 0 {
796            return Complex64::ZERO;
797        }
798        let mut slots = vec![Complex64::ZERO; self.slot_count];
799        self.evaluate_into(amplitude_values, &mut slots)
800    }
801
802    pub fn evaluate_gradient_into(
803        &self,
804        amplitude_values: &[Complex64],
805        gradient_values: &[DVector<Complex64>],
806        value_slots: &mut [Complex64],
807        gradient_slots: &mut [DVector<Complex64>],
808    ) -> DVector<Complex64> {
809        if self.slot_count == 0 {
810            let dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
811            return DVector::zeros(dim);
812        }
813        self.fill_values(amplitude_values, value_slots);
814        self.fill_gradients(gradient_values, value_slots, gradient_slots);
815        gradient_slots[self.root_slot].clone()
816    }
817
818    pub fn evaluate_gradient(
819        &self,
820        amplitude_values: &[Complex64],
821        gradient_values: &[DVector<Complex64>],
822    ) -> DVector<Complex64> {
823        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
824        let mut value_slots = vec![Complex64::ZERO; self.slot_count];
825        let mut gradient_slots: Vec<DVector<Complex64>> = (0..self.slot_count)
826            .map(|_| DVector::zeros(grad_dim))
827            .collect();
828        self.evaluate_gradient_into(
829            amplitude_values,
830            gradient_values,
831            &mut value_slots,
832            &mut gradient_slots,
833        )
834    }
835
836    fn fill_gradients(
837        &self,
838        amplitude_gradients: &[DVector<Complex64>],
839        values: &[Complex64],
840        gradients: &mut [DVector<Complex64>],
841    ) {
842        debug_assert!(gradients.len() >= self.slot_count);
843        debug_assert!(values.len() >= self.slot_count);
844        fn borrow_dst(
845            gradients: &mut [DVector<Complex64>],
846            dst: usize,
847        ) -> (&[DVector<Complex64>], &mut DVector<Complex64>) {
848            let (before, tail) = gradients.split_at_mut(dst);
849            let (dst_ref, _) = tail.split_first_mut().expect("dst slot should exist");
850            (before, dst_ref)
851        }
852        for op in &self.ops {
853            match *op {
854                ExpressionOp::LoadZero { dst }
855                | ExpressionOp::LoadOne { dst }
856                | ExpressionOp::LoadConstant { dst, .. }
857                | ExpressionOp::LoadComplexConstant { dst, .. } => {
858                    let (_, dst_grad) = borrow_dst(gradients, dst);
859                    for item in dst_grad.iter_mut() {
860                        *item = Complex64::ZERO;
861                    }
862                }
863                ExpressionOp::LoadAmp { dst, amp_idx } => {
864                    let (_, dst_grad) = borrow_dst(gradients, dst);
865                    if let Some(source) = amplitude_gradients.get(amp_idx) {
866                        dst_grad.clone_from(source);
867                    } else {
868                        for item in dst_grad.iter_mut() {
869                            *item = Complex64::ZERO;
870                        }
871                    }
872                }
873                ExpressionOp::Add { dst, left, right } => {
874                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
875                    dst_grad.clone_from(&before_dst[left]);
876                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
877                    {
878                        *dst_item += *right_item;
879                    }
880                }
881                ExpressionOp::Sub { dst, left, right } => {
882                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
883                    dst_grad.clone_from(&before_dst[left]);
884                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
885                    {
886                        *dst_item -= *right_item;
887                    }
888                }
889                ExpressionOp::Mul { dst, left, right } => {
890                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
891                    let f_left = values[left];
892                    let f_right = values[right];
893                    dst_grad.clone_from(&before_dst[right]);
894                    for item in dst_grad.iter_mut() {
895                        *item *= f_left;
896                    }
897                    for (dst_item, left_item) in dst_grad.iter_mut().zip(before_dst[left].iter()) {
898                        *dst_item += *left_item * f_right;
899                    }
900                }
901                ExpressionOp::Div { dst, left, right } => {
902                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
903                    let f_left = values[left];
904                    let f_right = values[right];
905                    let denom = f_right * f_right;
906                    dst_grad.clone_from(&before_dst[left]);
907                    for item in dst_grad.iter_mut() {
908                        *item *= f_right;
909                    }
910                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
911                    {
912                        *dst_item -= *right_item * f_left;
913                    }
914                    for item in dst_grad.iter_mut() {
915                        *item /= denom;
916                    }
917                }
918                ExpressionOp::Neg { dst, input } => {
919                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
920                    dst_grad.clone_from(&before_dst[input]);
921                    for item in dst_grad.iter_mut() {
922                        *item = -*item;
923                    }
924                }
925                ExpressionOp::Real { dst, input } => {
926                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
927                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
928                    {
929                        *dst_item = Complex64::new(input_item.re, 0.0);
930                    }
931                }
932                ExpressionOp::Imag { dst, input } => {
933                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
934                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
935                    {
936                        *dst_item = Complex64::new(input_item.im, 0.0);
937                    }
938                }
939                ExpressionOp::Conj { dst, input } => {
940                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
941                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
942                    {
943                        *dst_item = input_item.conj();
944                    }
945                }
946                ExpressionOp::NormSqr { dst, input } => {
947                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
948                    let conj_value = values[input].conj();
949                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
950                    {
951                        *dst_item = Complex64::new(2.0 * (*input_item * conj_value).re, 0.0);
952                    }
953                }
954                ExpressionOp::Sqrt { dst, input } => {
955                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
956                    let factor = Complex64::new(0.5, 0.0) / values[input].sqrt();
957                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
958                    {
959                        *dst_item = *input_item * factor;
960                    }
961                }
962                ExpressionOp::Pow { dst, value, power } => {
963                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
964                    let base = values[value];
965                    let exponent = values[power];
966                    let output = values[dst];
967                    for ((dst_item, value_item), power_item) in dst_grad
968                        .iter_mut()
969                        .zip(before_dst[value].iter())
970                        .zip(before_dst[power].iter())
971                    {
972                        *dst_item =
973                            output * (*power_item * base.ln() + exponent * *value_item / base);
974                    }
975                }
976                ExpressionOp::PowI { dst, input, power } => {
977                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
978                    let factor = match power {
979                        0 => Complex64::ZERO,
980                        1 => Complex64::ONE,
981                        _ => {
982                            let base = values[input];
983                            let multiplier = Complex64::new(power as f64, 0.0);
984                            if let Some(derivative_power) = power.checked_sub(1) {
985                                multiplier * base.powi(derivative_power)
986                            } else {
987                                multiplier * base.powi(power) / base
988                            }
989                        }
990                    };
991                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
992                    {
993                        *dst_item = *input_item * factor;
994                    }
995                }
996                ExpressionOp::PowF { dst, input, power } => {
997                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
998                    let factor = if power == 0.0 {
999                        Complex64::ZERO
1000                    } else {
1001                        Complex64::new(power, 0.0)
1002                            * values[input].powc(Complex64::new(power - 1.0, 0.0))
1003                    };
1004                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1005                    {
1006                        *dst_item = *input_item * factor;
1007                    }
1008                }
1009                ExpressionOp::Exp { dst, input } => {
1010                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1011                    let output = values[dst];
1012                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1013                    {
1014                        *dst_item = *input_item * output;
1015                    }
1016                }
1017                ExpressionOp::Sin { dst, input } => {
1018                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1019                    let factor = values[input].cos();
1020                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1021                    {
1022                        *dst_item = *input_item * factor;
1023                    }
1024                }
1025                ExpressionOp::Cos { dst, input } => {
1026                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1027                    let factor = -values[input].sin();
1028                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1029                    {
1030                        *dst_item = *input_item * factor;
1031                    }
1032                }
1033                ExpressionOp::Log { dst, input } => {
1034                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1035                    let factor = Complex64::ONE / values[input];
1036                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1037                    {
1038                        *dst_item = *input_item * factor;
1039                    }
1040                }
1041                ExpressionOp::Cis { dst, input } => {
1042                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1043                    let factor = Complex64::new(0.0, 1.0) * values[dst];
1044                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1045                    {
1046                        *dst_item = *input_item * factor;
1047                    }
1048                }
1049            }
1050        }
1051    }
1052}
1053
1054impl ExpressionNode {
1055    fn remap(&self, mapping: &[usize]) -> Self {
1056        match self {
1057            Self::Amp(idx) => Self::Amp(mapping[*idx]),
1058            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1059            Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1060            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1061            Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1062            Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
1063            Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
1064            Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
1065            Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
1066            Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
1067            Self::Zero => Self::Zero,
1068            Self::One => Self::One,
1069            Self::Constant(v) => Self::Constant(*v),
1070            Self::ComplexConstant(v) => Self::ComplexConstant(*v),
1071            Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
1072            Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1073            Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
1074            Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
1075            Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
1076            Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
1077            Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
1078            Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
1079            Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
1080        }
1081    }
1082
1083    fn program(&self) -> ExpressionProgram {
1084        ExpressionProgram::from_node(self)
1085    }
1086
1087    /// Evaluate an [`ExpressionNode`] by compiling it to bytecode on the fly.
1088    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
1089        self.program().evaluate(amplitude_values)
1090    }
1091
1092    /// Evaluate the gradient of an [`ExpressionNode`].
1093    pub fn evaluate_gradient(
1094        &self,
1095        amplitude_values: &[Complex64],
1096        gradient_values: &[DVector<Complex64>],
1097    ) -> DVector<Complex64> {
1098        self.program()
1099            .evaluate_gradient(amplitude_values, gradient_values)
1100    }
1101}
1102
1103impl From<f64> for Expression {
1104    fn from(value: f64) -> Self {
1105        if value == 0.0 {
1106            Self {
1107                registry: ExpressionRegistry::default(),
1108                tree: ExpressionNode::Zero,
1109            }
1110        } else if value == 1.0 {
1111            Self {
1112                registry: ExpressionRegistry::default(),
1113                tree: ExpressionNode::One,
1114            }
1115        } else {
1116            Self {
1117                registry: ExpressionRegistry::default(),
1118                tree: ExpressionNode::Constant(value),
1119            }
1120        }
1121    }
1122}
1123impl From<&f64> for Expression {
1124    fn from(value: &f64) -> Self {
1125        (*value).into()
1126    }
1127}
1128impl From<Complex64> for Expression {
1129    fn from(value: Complex64) -> Self {
1130        if value == Complex64::ZERO {
1131            Self {
1132                registry: ExpressionRegistry::default(),
1133                tree: ExpressionNode::Zero,
1134            }
1135        } else if value == Complex64::ONE {
1136            Self {
1137                registry: ExpressionRegistry::default(),
1138                tree: ExpressionNode::One,
1139            }
1140        } else {
1141            Self {
1142                registry: ExpressionRegistry::default(),
1143                tree: ExpressionNode::ComplexConstant(value),
1144            }
1145        }
1146    }
1147}
1148impl From<&Complex64> for Expression {
1149    fn from(value: &Complex64) -> Self {
1150        (*value).into()
1151    }
1152}
1153
1154impl Expression {
1155    /// Build an [`Expression`] from a single [`Amplitude`].
1156    pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
1157        let registry = ExpressionRegistry::singleton(amplitude)?;
1158        Ok(Self {
1159            tree: ExpressionNode::Amp(0),
1160            registry,
1161        })
1162    }
1163
1164    /// Create an expression representing zero, the additive identity.
1165    pub fn zero() -> Self {
1166        Self {
1167            registry: ExpressionRegistry::default(),
1168            tree: ExpressionNode::Zero,
1169        }
1170    }
1171
1172    /// Create an expression representing one, the multiplicative identity.
1173    pub fn one() -> Self {
1174        Self {
1175            registry: ExpressionRegistry::default(),
1176            tree: ExpressionNode::One,
1177        }
1178    }
1179
1180    fn binary_op(
1181        a: &Expression,
1182        b: &Expression,
1183        build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
1184    ) -> Expression {
1185        let (registry, left_map, right_map) = a
1186            .registry
1187            .merge(&b.registry)
1188            .expect("merging expression registries should not fail");
1189        let left_tree = a.tree.remap(&left_map);
1190        let right_tree = b.tree.remap(&right_map);
1191        Expression {
1192            registry,
1193            tree: build(Box::new(left_tree), Box::new(right_tree)),
1194        }
1195    }
1196
1197    fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
1198        Expression {
1199            registry: a.registry.clone(),
1200            tree: build(Box::new(a.tree.clone())),
1201        }
1202    }
1203
    /// Get the parameters used by this expression.
    ///
    /// Returns the [`ParameterMap`] reported by the registry's resources.
    pub fn parameters(&self) -> ParameterMap {
        self.registry.resources.parameters()
    }
1208
    /// Number of free parameters, as reported by the registry's resources.
    pub fn n_free(&self) -> usize {
        self.registry.resources.n_free_parameters()
    }
1213
    /// Number of fixed parameters, as reported by the registry's resources.
    pub fn n_fixed(&self) -> usize {
        self.registry.resources.n_fixed_parameters()
    }
1218
    /// Total number of parameters, as reported by the registry's resources.
    // NOTE(review): presumably equals `n_free() + n_fixed()` — confirm against the resources type.
    pub fn n_parameters(&self) -> usize {
        self.registry.resources.n_parameters()
    }
1223
1224    /// Returns a tree-like diagnostic snapshot of this expression's compiled form.
1225    ///
1226    /// This compiles the expression on each call with every registered amplitude active. Use
1227    /// [`Evaluator::compiled_expression`] when you need the compiled form for a loaded evaluator's
1228    /// current active-amplitude mask.
1229    pub fn compiled_expression(&self) -> CompiledExpression {
1230        let active_amplitudes = vec![true; self.registry.amplitude_use_sites.len()];
1231        let amplitude_dependencies = self
1232            .registry
1233            .amplitude_use_sites
1234            .iter()
1235            .map(|use_site| {
1236                ir::DependenceClass::from(
1237                    self.registry.amplitudes[use_site.amplitude_index].dependence_hint(),
1238                )
1239            })
1240            .collect::<Vec<_>>();
1241        let amplitude_realness = self
1242            .registry
1243            .amplitude_use_sites
1244            .iter()
1245            .map(|use_site| self.registry.amplitudes[use_site.amplitude_index].real_valued_hint())
1246            .collect::<Vec<_>>();
1247        let expression_ir = ir::compile_expression_ir_with_real_hints(
1248            &self.tree,
1249            &active_amplitudes,
1250            &amplitude_dependencies,
1251            &amplitude_realness,
1252        );
1253        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
1254    }
1255
    /// Fix a parameter used by this expression's evaluator resources.
    ///
    /// The parameter called `name` is fixed to `value` through the registry's resources. The
    /// method takes `&self`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the resources' `fix_parameter` reports.
    // NOTE(review): presumably this fails when `name` is not a known parameter — confirm.
    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
        self.registry.resources.fix_parameter(name, value)
    }
1260
    /// Mark a parameter used by this expression's evaluator resources as free.
    ///
    /// The change is applied through the registry's resources. The method takes `&self`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the resources' `free_parameter` reports.
    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
        self.registry.resources.free_parameter(name)
    }
1265
    /// Rename a single parameter of this [`Expression`] in place.
    ///
    /// The parameter called `old` is renamed to `new` in this expression's resources. No new
    /// expression is created; the method returns `()` on success.
    ///
    /// # Errors
    ///
    /// Returns whatever error the resources' `rename_parameter` reports.
    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
        self.registry.resources.rename_parameter(old, new)
    }
1270
    /// Rename several parameters of this [`Expression`] in place.
    ///
    /// `mapping` is forwarded to the resources' `rename_parameters`. No new expression is
    /// created; the method returns `()` on success.
    // NOTE(review): presumably `mapping` goes from old names to new names — confirm.
    ///
    /// # Errors
    ///
    /// Returns whatever error the resources' `rename_parameters` reports.
    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
        self.registry.resources.rename_parameters(mapping)
    }
1275
    /// Load an [`Expression`] against a dataset, binding amplitudes and reserving caches.
    ///
    /// The method works on a clone of the registry's resources and on cloned amplitudes. It
    /// then proceeds in three timed stages whose durations are stored in the evaluator's
    /// compile metrics:
    ///
    /// 1. compile the expression IR for the currently active amplitudes,
    /// 2. precompute the cached integrals,
    /// 3. lower the IR into runtime artifacts.
    ///
    /// The returned [`Evaluator`] starts with its specialization and lowered-artifact caches
    /// seeded with the results of this initial load.
    ///
    /// # Errors
    ///
    /// Propagates errors from binding an amplitude to the dataset metadata, from precomputing
    /// the cached integrals, and from lowering the expression.
    pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
        let mut resources = self.registry.resources.clone();
        let metadata = dataset.metadata();
        resources.reserve_cache(dataset.n_events_local());
        resources.refresh_active_indices();
        let parameter_map = resources.parameter_map.clone();
        // Clone every registered amplitude so binding operates on the evaluator's own copies.
        let mut amplitudes: Vec<Box<dyn Amplitude>> = self
            .registry
            .amplitudes
            .iter()
            .map(|amp| dyn_clone::clone_box(&**amp))
            .collect();
        {
            // Bind each amplitude to the dataset metadata, then let it precompute into the
            // cloned resources.
            for amplitude in amplitudes.iter_mut() {
                amplitude.bind(metadata)?;
                amplitude.precompute_all(dataset, &mut resources);
            }
        }
        // Stage 1: compile the IR using the active mask and the hints of the bound amplitudes.
        let ir_compile_start = Instant::now();
        let expression_ir = {
            // Start with every use site inactive and enable those listed as active.
            let mut active_amplitudes = vec![false; self.registry.amplitude_use_sites.len()];
            for &index in resources.active_indices() {
                active_amplitudes[index] = true;
            }
            let amplitude_dependencies = self
                .registry
                .amplitude_use_sites
                .iter()
                .map(|use_site| {
                    ir::DependenceClass::from(
                        amplitudes[use_site.amplitude_index].dependence_hint(),
                    )
                })
                .collect::<Vec<_>>();
            let amplitude_realness = self
                .registry
                .amplitude_use_sites
                .iter()
                .map(|use_site| amplitudes[use_site.amplitude_index].real_valued_hint())
                .collect::<Vec<_>>();
            ir::compile_expression_ir_with_real_hints(
                &self.tree,
                &active_amplitudes,
                &amplitude_dependencies,
                &amplitude_realness,
            )
        };
        // `as_nanos` returns `u128`; the cast keeps only the low 64 bits.
        let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
        // Stage 2: precompute the cached integrals for the compiled IR.
        let cached_integrals_start = Instant::now();
        let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
            &expression_ir,
            &amplitudes,
            &self.registry.amplitude_use_sites,
            &resources,
            dataset,
            parameter_map.free().len(),
        )?;
        let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
        // Stage 3: lower the IR and cached integrals into runtime artifacts.
        let lowering_start = Instant::now();
        let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
            &expression_ir,
            &cached_integrals,
        )?);
        let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
        // Clone the execution sets before `expression_ir` is moved into the cache state below.
        let execution_sets = expression_ir.normalization_execution_sets().clone();
        let cached_integral_key =
            Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
        let cached_integral_state = Arc::new(CachedIntegralCacheState {
            key: cached_integral_key.clone(),
            expression_ir,
            values: cached_integrals,
            execution_sets,
        });
        let specialization_state = ExpressionSpecializationState {
            cached_integrals: cached_integral_state.clone(),
            lowered_artifacts: lowered_artifacts.clone(),
        };
        // Seed both caches with the single entry produced by this load.
        let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
        let lowered_artifact_cache =
            HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
        Ok(Evaluator {
            amplitudes,
            amplitude_use_sites: self.registry.amplitude_use_sites.clone(),
            resources: Arc::new(RwLock::new(resources)),
            dataset: dataset.clone(),
            expression: self.tree.clone(),
            ir_planning: ExpressionIrPlanningState {
                cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
                specialization_cache: Arc::new(RwLock::new(specialization_cache)),
                // The initial load is recorded as one cache miss and no hits.
                specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
                    cache_hits: 0,
                    cache_misses: 1,
                })),
                lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
                active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
                specialization_status: Arc::new(RwLock::new(Some(
                    ExpressionSpecializationStatus {
                        origin: ExpressionSpecializationOrigin::InitialLoad,
                    },
                ))),
                compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
                    initial_ir_compile_nanos,
                    initial_cached_integrals_nanos,
                    initial_lowering_nanos,
                    specialization_lowering_cache_misses: 1,
                    ..Default::default()
                })),
            },
            registry: self.registry.clone(),
        })
    }
1388
    /// Takes the real part of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Real` node.
    pub fn real(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Real)
    }
    /// Takes the imaginary part of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in an `Imag` node.
    pub fn imag(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Imag)
    }
    /// Takes the complex conjugate of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Conj` node.
    pub fn conj(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Conj)
    }
    /// Takes the absolute square of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `NormSqr` node.
    pub fn norm_sqr(&self) -> Self {
        Self::unary_op(self, ExpressionNode::NormSqr)
    }
    /// Takes the square root of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Sqrt` node.
    pub fn sqrt(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Sqrt)
    }
    /// Raises the given [`Expression`] to an expression-valued power.
    ///
    /// `self` is the base and `power` the exponent of the resulting `Pow` node. The registries
    /// of both expressions are merged through `binary_op`.
    ///
    /// # Panics
    ///
    /// Panics if merging the two registries fails.
    pub fn pow(&self, power: &Expression) -> Self {
        Self::binary_op(self, power, ExpressionNode::Pow)
    }
    /// Raises the given [`Expression`] to an integer power.
    ///
    /// Returns a new expression that wraps a clone of this tree in a `PowI` node carrying
    /// `power`.
    pub fn powi(&self, power: i32) -> Self {
        Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
    }
    /// Raises the given [`Expression`] to a real-valued power.
    ///
    /// Returns a new expression that wraps a clone of this tree in a `PowF` node carrying
    /// `power`.
    pub fn powf(&self, power: f64) -> Self {
        Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
    }
    /// Takes the exponential of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in an `Exp` node.
    pub fn exp(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Exp)
    }
    /// Takes the sine of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Sin` node.
    pub fn sin(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Sin)
    }
    /// Takes the cosine of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Cos` node.
    pub fn cos(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Cos)
    }
    /// Takes the natural logarithm of the given [`Expression`].
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Log` node.
    pub fn log(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Log)
    }
    /// Takes the complex phase factor exp(i * expression).
    ///
    /// Returns a new expression that wraps a clone of this tree in a `Cis` node.
    pub fn cis(&self) -> Self {
        Self::unary_op(self, ExpressionNode::Cis)
    }
1441
1442    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
1443    fn write_tree(
1444        &self,
1445        t: &ExpressionNode,
1446        f: &mut std::fmt::Formatter<'_>,
1447        parent_prefix: &str,
1448        immediate_prefix: &str,
1449        parent_suffix: &str,
1450    ) -> std::fmt::Result {
1451        let display_string = match t {
1452            ExpressionNode::Amp(idx) => {
1453                let name = self
1454                    .registry
1455                    .amplitude_names
1456                    .get(*idx)
1457                    .cloned()
1458                    .unwrap_or_else(|| "<unregistered>".to_string());
1459                format!("{name}(id={idx})")
1460            }
1461            ExpressionNode::Add(_, _) => "+".to_string(),
1462            ExpressionNode::Sub(_, _) => "-".to_string(),
1463            ExpressionNode::Mul(_, _) => "×".to_string(),
1464            ExpressionNode::Div(_, _) => "÷".to_string(),
1465            ExpressionNode::Neg(_) => "-".to_string(),
1466            ExpressionNode::Real(_) => "Re".to_string(),
1467            ExpressionNode::Imag(_) => "Im".to_string(),
1468            ExpressionNode::Conj(_) => "*".to_string(),
1469            ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
1470            ExpressionNode::Zero => "0 (exact)".to_string(),
1471            ExpressionNode::One => "1 (exact)".to_string(),
1472            ExpressionNode::Constant(v) => v.to_string(),
1473            ExpressionNode::ComplexConstant(v) => v.to_string(),
1474            ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
1475            ExpressionNode::Pow(_, _) => "Pow".to_string(),
1476            ExpressionNode::PowI(_, power) => format!("PowI({power})"),
1477            ExpressionNode::PowF(_, power) => format!("PowF({power})"),
1478            ExpressionNode::Exp(_) => "Exp".to_string(),
1479            ExpressionNode::Sin(_) => "Sin".to_string(),
1480            ExpressionNode::Cos(_) => "Cos".to_string(),
1481            ExpressionNode::Log(_) => "Log".to_string(),
1482            ExpressionNode::Cis(_) => "Cis".to_string(),
1483        };
1484        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
1485        match t {
1486            ExpressionNode::Amp(_)
1487            | ExpressionNode::Zero
1488            | ExpressionNode::One
1489            | ExpressionNode::Constant(_)
1490            | ExpressionNode::ComplexConstant(_) => {}
1491            ExpressionNode::Add(a, b)
1492            | ExpressionNode::Sub(a, b)
1493            | ExpressionNode::Mul(a, b)
1494            | ExpressionNode::Div(a, b)
1495            | ExpressionNode::Pow(a, b) => {
1496                let terms = [a, b];
1497                let mut it = terms.iter().peekable();
1498                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1499                while let Some(child) = it.next() {
1500                    match it.peek() {
1501                        Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│  "),
1502                        None => self.write_tree(child, f, &child_prefix, "└─ ", "   "),
1503                    }?;
1504                }
1505            }
1506            ExpressionNode::Neg(a)
1507            | ExpressionNode::Real(a)
1508            | ExpressionNode::Imag(a)
1509            | ExpressionNode::Conj(a)
1510            | ExpressionNode::NormSqr(a)
1511            | ExpressionNode::Sqrt(a)
1512            | ExpressionNode::PowI(a, _)
1513            | ExpressionNode::PowF(a, _)
1514            | ExpressionNode::Exp(a)
1515            | ExpressionNode::Sin(a)
1516            | ExpressionNode::Cos(a)
1517            | ExpressionNode::Log(a)
1518            | ExpressionNode::Cis(a) => {
1519                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1520                self.write_tree(a, f, &child_prefix, "└─ ", "   ")?;
1521            }
1522        }
1523        Ok(())
1524    }
1525}
1526
1527impl Debug for Expression {
1528    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1529        self.write_tree(&self.tree, f, "", "", "")
1530    }
1531}
1532
1533impl Display for Expression {
1534    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1535        self.write_tree(&self.tree, f, "", "", "")
1536    }
1537}
1538
// Addition operators. `impl_op_ex!` (from `auto_ops`) generates the `+` impls for the listed
// operand types, covering owned and borrowed operands. Every variant builds an
// `ExpressionNode::Add` through `Expression::binary_op`; `f64` and `Complex64` operands are
// first converted with `Expression::from`.
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Add)
});
// Expression + real scalar.
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
});
// Real scalar + expression (operand order is preserved).
#[rustfmt::skip]
impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
});
// Expression + complex scalar.
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
});
// Complex scalar + expression (operand order is preserved).
#[rustfmt::skip]
impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
});
1559
// Subtraction operators. Every variant builds an `ExpressionNode::Sub` through
// `Expression::binary_op`, keeping the left/right operand order; `f64` and `Complex64`
// operands are first converted with `Expression::from`.
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Sub)
});
// Expression - real scalar.
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
});
// Real scalar - expression.
#[rustfmt::skip]
impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
});
// Expression - complex scalar.
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
});
// Complex scalar - expression.
#[rustfmt::skip]
impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
});
1580
// Multiplication operators. Every variant builds an `ExpressionNode::Mul` through
// `Expression::binary_op`; `f64` and `Complex64` operands are first converted with
// `Expression::from`.
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Mul)
});
// Expression * real scalar.
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
});
// Real scalar * expression (operand order is preserved).
#[rustfmt::skip]
impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
});
// Expression * complex scalar.
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
});
// Complex scalar * expression (operand order is preserved).
#[rustfmt::skip]
impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
});
1601
// Division operators. Every variant builds an `ExpressionNode::Div` through
// `Expression::binary_op`, keeping the numerator/denominator order; `f64` and `Complex64`
// operands are first converted with `Expression::from`.
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Div)
});
// Expression / real scalar.
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
});
// Real scalar / expression.
#[rustfmt::skip]
impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
});
// Expression / complex scalar.
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
});
// Complex scalar / expression.
#[rustfmt::skip]
impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
});
1622
// Unary negation: wraps the operand in `ExpressionNode::Neg` through `Expression::unary_op`.
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression| -> Expression {
    Expression::unary_op(a, ExpressionNode::Neg)
});
// NOTE: no need to add an impl for negating f64 or complex!
1628
#[derive(Clone, Debug)]
#[doc(hidden)]
/// Opaque holder for a single lowered program, handed out by the evaluator's
/// `expression_value_program_snapshot*` helpers.
pub struct ExpressionValueProgramSnapshot {
    /// The lowered program captured by this snapshot.
    lowered_program: lowered::LoweredProgram,
}
1634
#[derive(Clone, Debug, PartialEq)]
/// A node in a compiled expression diagnostic snapshot.
///
/// Child references (`input`, `left`, `right`) are positions in the node list returned by
/// [`CompiledExpression::nodes`].
pub enum CompiledExpressionNode {
    /// A complex constant.
    Constant(Complex64),
    /// A registered amplitude use-site by index and display label.
    Amplitude {
        /// The amplitude index used by the compiled evaluator.
        index: usize,
        /// The registered amplitude display label.
        name: String,
    },
    /// A unary operation and its input node.
    Unary {
        /// The display label for the operation.
        op: String,
        /// The input node index.
        input: usize,
    },
    /// A binary operation and its input nodes.
    Binary {
        /// The display label for the operation.
        op: String,
        /// The left input node index.
        left: usize,
        /// The right input node index.
        right: usize,
    },
}
1664
#[derive(Clone, Debug, PartialEq)]
/// Tree-like diagnostic view of the compiled expression DAG.
///
/// The compiled expression is a directed acyclic graph because common subexpressions can be
/// deduplicated during compilation. The display format expands the graph from the root once and
/// marks later visits to the same node with `(ref)`.
pub struct CompiledExpression {
    /// Diagnostic nodes; operation nodes refer to their inputs by index into this list.
    nodes: Vec<CompiledExpressionNode>,
    /// Index into `nodes` at which the tree display starts.
    root: usize,
}
1675
1676impl CompiledExpression {
1677    fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
1678        let nodes = ir
1679            .nodes()
1680            .iter()
1681            .map(|node| match node {
1682                ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
1683                ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
1684                    index: *index,
1685                    name: amplitude_names
1686                        .get(*index)
1687                        .cloned()
1688                        .unwrap_or_else(|| "<unregistered>".to_string()),
1689                },
1690                ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
1691                    op: compiled_unary_op_label(*op),
1692                    input: *input,
1693                },
1694                ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
1695                    op: compiled_binary_op_label(*op),
1696                    left: *left,
1697                    right: *right,
1698                },
1699            })
1700            .collect();
1701        Self {
1702            nodes,
1703            root: ir.root(),
1704        }
1705    }
1706
1707    /// Returns the compiled expression node list in evaluator execution order.
1708    pub fn nodes(&self) -> &[CompiledExpressionNode] {
1709        &self.nodes
1710    }
1711
1712    /// Returns the root node index.
1713    pub fn root(&self) -> usize {
1714        self.root
1715    }
1716
1717    fn node_label(&self, index: usize) -> String {
1718        let Some(node) = self.nodes.get(index) else {
1719            return format!("#{index} <missing>");
1720        };
1721        let label = match node {
1722            CompiledExpressionNode::Constant(value) => format!("const {value}"),
1723            CompiledExpressionNode::Amplitude { index, name } => {
1724                format!("{name}(id={index})")
1725            }
1726            CompiledExpressionNode::Unary { op, .. }
1727            | CompiledExpressionNode::Binary { op, .. } => op.clone(),
1728        };
1729        format!("#{index} {label}")
1730    }
1731
1732    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
1733    fn write_tree(
1734        &self,
1735        index: usize,
1736        f: &mut std::fmt::Formatter<'_>,
1737        parent_prefix: &str,
1738        immediate_prefix: &str,
1739        parent_suffix: &str,
1740        expanded: &mut [bool],
1741    ) -> std::fmt::Result {
1742        let already_expanded = expanded.get(index).copied().unwrap_or(false);
1743        if let Some(slot) = expanded.get_mut(index) {
1744            *slot = true;
1745        }
1746        let ref_suffix = if already_expanded { " (ref)" } else { "" };
1747        writeln!(
1748            f,
1749            "{}{}{}{}",
1750            parent_prefix,
1751            immediate_prefix,
1752            self.node_label(index),
1753            ref_suffix
1754        )?;
1755        if already_expanded {
1756            return Ok(());
1757        }
1758        let Some(node) = self.nodes.get(index) else {
1759            return Ok(());
1760        };
1761        match node {
1762            CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
1763            CompiledExpressionNode::Unary { input, .. } => {
1764                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1765                self.write_tree(*input, f, &child_prefix, "└─ ", "   ", expanded)?;
1766            }
1767            CompiledExpressionNode::Binary { left, right, .. } => {
1768                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1769                self.write_tree(*left, f, &child_prefix, "├─ ", "│  ", expanded)?;
1770                self.write_tree(*right, f, &child_prefix, "└─ ", "   ", expanded)?;
1771            }
1772        }
1773        Ok(())
1774    }
1775}
1776
1777impl Display for CompiledExpression {
1778    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1779        if self.nodes.is_empty() {
1780            return writeln!(f, "<empty>");
1781        }
1782        let mut expanded = vec![false; self.nodes.len()];
1783        self.write_tree(self.root, f, "", "", "", &mut expanded)
1784    }
1785}
1786
1787fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
1788    match op {
1789        ir::IrUnaryOp::Neg => "-".to_string(),
1790        ir::IrUnaryOp::Real => "Re".to_string(),
1791        ir::IrUnaryOp::Imag => "Im".to_string(),
1792        ir::IrUnaryOp::Conj => "*".to_string(),
1793        ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
1794        ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
1795        ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
1796        ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
1797        ir::IrUnaryOp::Exp => "Exp".to_string(),
1798        ir::IrUnaryOp::Sin => "Sin".to_string(),
1799        ir::IrUnaryOp::Cos => "Cos".to_string(),
1800        ir::IrUnaryOp::Log => "Log".to_string(),
1801        ir::IrUnaryOp::Cis => "Cis".to_string(),
1802    }
1803}
1804
1805fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
1806    match op {
1807        ir::IrBinaryOp::Add => "+".to_string(),
1808        ir::IrBinaryOp::Sub => "-".to_string(),
1809        ir::IrBinaryOp::Mul => "×".to_string(),
1810        ir::IrBinaryOp::Div => "÷".to_string(),
1811        ir::IrBinaryOp::Pow => "Pow".to_string(),
1812    }
1813}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Origin of the currently installed expression specialization.
pub enum ExpressionSpecializationOrigin {
    /// The specialization installed during evaluator construction.
    ///
    /// Reported when a freshly built specialization is the only entry in the specialization
    /// cache.
    InitialLoad,
    /// The specialization was rebuilt because no cached entry matched the current state.
    CacheMissRebuild,
    /// The specialization was restored from an existing cache entry.
    CacheHitRestore,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Current specialization status for the evaluator runtime.
///
/// Exposed through [`ExpressionRuntimeDiagnostics::specialization_status`].
pub struct ExpressionSpecializationStatus {
    /// How the active specialization was obtained most recently.
    pub origin: ExpressionSpecializationOrigin,
}
#[derive(Clone, Debug, PartialEq, Eq)]
/// Diagnostic snapshot of the active runtime specialization state.
///
/// Returned by `Evaluator::expression_runtime_diagnostics`; the values describe the evaluator
/// at the moment of that call.
pub struct ExpressionRuntimeDiagnostics {
    /// Whether IR planning state is present for the evaluator.
    pub ir_planning_enabled: bool,
    /// Whether a lowered value-only program is available for the active specialization.
    pub lowered_value_program_present: bool,
    /// Whether a lowered gradient-only program is available for the active specialization.
    pub lowered_gradient_program_present: bool,
    /// Whether a lowered fused value+gradient program is available for the active specialization.
    pub lowered_value_gradient_program_present: bool,
    /// Number of cached parameter-factor descriptors in the active specialization.
    pub cached_parameter_factor_count: usize,
    /// Number of cached parameter factors with lowered runtimes available.
    pub lowered_cached_parameter_factor_count: usize,
    /// Whether a lowered residual normalization runtime is available.
    pub residual_runtime_present: bool,
    /// Number of cached specialization entries currently retained.
    pub specialization_cache_entries: usize,
    /// Number of cached lowered-artifact entries currently retained.
    pub lowered_artifact_cache_entries: usize,
    /// Origin of the currently installed specialization, when available.
    pub specialization_status: Option<ExpressionSpecializationStatus>,
}
#[derive(Clone)]
/// IR-planning state derived from the semantic expression tree plus the current active mask.
///
/// Invariants:
/// - `expression_ir` is never a source of truth; it is always derived from `Evaluator::expression`.
/// - `cached_integrals` are specialization-dependent and must be treated as invalid once the
///   active mask or dataset identity changes.
struct ExpressionIrPlanningState {
    /// Cached-integral state of the currently installed specialization, if any.
    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
    /// Previously built specializations, looked up by their cache key.
    specialization_cache:
        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
    /// Hit/miss counters for `specialization_cache`.
    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
    /// Lowered artifacts keyed by the active-amplitude mask they were built for.
    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
    /// Lowered artifacts of the currently installed specialization, if any.
    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
    /// How the currently installed specialization was obtained, once one has been installed.
    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
    /// Timing and cache counters for IR compilation, cached-integral precomputation and lowering.
    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
}
/// Evaluator for [`Expression`] that mirrors the existing evaluator behavior.
#[allow(missing_docs)]
#[derive(Clone)]
pub struct Evaluator {
    /// Boxed amplitude implementations used by this evaluator.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Amplitude use-site records, passed together with `amplitudes` when cached integrals are
    /// precomputed.
    amplitude_use_sites: Vec<AmplitudeUseSite>,
    /// Shared, lock-protected evaluation resources.
    pub resources: Arc<RwLock<Resources>>,
    /// Dataset the evaluator operates on; it also contributes to the specialization cache key.
    pub dataset: Arc<Dataset>,
    /// Semantic expression tree; the expression IR is always derived from it.
    pub expression: ExpressionNode,
    /// Derived IR-planning state: installed specialization, caches and metrics.
    ir_planning: ExpressionIrPlanningState,
    /// Registry handed to `Expression` values created from this evaluator; also supplies the
    /// amplitude display names.
    registry: ExpressionRegistry,
}
1883
1884#[allow(missing_docs)]
1885impl Evaluator {
1886    /// Internal benchmarking/debug counters for specialization cache reuse.
1887    pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
1888        *self.ir_planning.specialization_metrics.read()
1889    }
1890    /// Reset specialization cache counters while leaving cached specializations intact.
1891    pub fn reset_expression_specialization_metrics(&self) {
1892        *self.ir_planning.specialization_metrics.write() =
1893            ExpressionSpecializationMetrics::default();
1894    }
1895    /// Internal benchmarking/debug metrics for staged IR compile and lowering costs.
1896    pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
1897        *self.ir_planning.compile_metrics.read()
1898    }
1899    /// Internal diagnostics surface for active runtime specialization state.
1900    pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
1901        let active_artifacts = self.active_lowered_artifacts();
1902        let cached_parameter_factor_count = self
1903            .ir_planning
1904            .cached_integrals
1905            .read()
1906            .as_ref()
1907            .map(|state| state.values.len())
1908            .unwrap_or(0);
1909        let lowered_cached_parameter_factor_count = active_artifacts
1910            .as_ref()
1911            .map(|artifacts| {
1912                artifacts
1913                    .lowered_parameter_factors
1914                    .iter()
1915                    .filter(|factor| factor.is_some())
1916                    .count()
1917            })
1918            .unwrap_or(0);
1919        let residual_runtime_present = active_artifacts
1920            .as_ref()
1921            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
1922            .is_some();
1923        ExpressionRuntimeDiagnostics {
1924            ir_planning_enabled: true,
1925            lowered_value_program_present: true,
1926            lowered_gradient_program_present: true,
1927            lowered_value_gradient_program_present: true,
1928            cached_parameter_factor_count,
1929            lowered_cached_parameter_factor_count,
1930            residual_runtime_present,
1931            specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
1932            lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
1933            specialization_status: *self.ir_planning.specialization_status.read(),
1934        }
1935    }
1936    /// Reset post-load compile/lowering counters while preserving initial-load metrics.
1937    pub fn reset_expression_compile_metrics(&self) {
1938        let mut metrics = self.ir_planning.compile_metrics.write();
1939        metrics.specialization_cache_hits = 0;
1940        metrics.specialization_cache_misses = 0;
1941        metrics.specialization_ir_compile_nanos = 0;
1942        metrics.specialization_cached_integrals_nanos = 0;
1943        metrics.specialization_lowering_nanos = 0;
1944        metrics.specialization_lowering_cache_hits = 0;
1945        metrics.specialization_lowering_cache_misses = 0;
1946        metrics.specialization_cache_restore_nanos = 0;
1947    }
1948    #[cfg(test)]
1949    fn expression_ir(&self) -> ir::ExpressionIR {
1950        self.ir_planning
1951            .cached_integrals
1952            .read()
1953            .as_ref()
1954            .map(|state| state.expression_ir.clone())
1955            .expect("cached integral state should exist for evaluator IR access")
1956    }
1957    fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
1958        self.active_lowered_artifacts()
1959            .expect("active lowered artifacts should exist for the current specialization")
1960            .lowered_runtime
1961            .clone()
1962    }
1963    fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
1964        self.ir_planning.active_lowered_artifacts.read().clone()
1965    }
1966    fn lowered_runtime_slot_count(&self) -> usize {
1967        let runtime = self.lowered_runtime();
1968        [
1969            runtime.value_program().scratch_slots(),
1970            runtime.gradient_program().scratch_slots(),
1971            runtime.value_gradient_program().scratch_slots(),
1972        ]
1973        .into_iter()
1974        .max()
1975        .unwrap_or(0)
1976    }
1977    fn lowered_value_runtime_slot_count(&self) -> usize {
1978        self.lowered_runtime().value_program().scratch_slots()
1979    }
1980
1981    #[doc(hidden)]
1982    pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
1983        ExpressionValueProgramSnapshot {
1984            lowered_program: self.lowered_runtime().value_program().clone(),
1985        }
1986    }
1987
1988    #[doc(hidden)]
1989    pub fn expression_value_program_snapshot_for_active_mask(
1990        &self,
1991        active_mask: &[bool],
1992    ) -> LadduResult<ExpressionValueProgramSnapshot> {
1993        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1994        let lowered_program =
1995            lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
1996                LadduError::Custom(format!(
1997                    "Failed to lower value-only active-mask runtime: {error:?}"
1998                ))
1999            })?;
2000        Ok(ExpressionValueProgramSnapshot { lowered_program })
2001    }
2002
2003    #[doc(hidden)]
2004    pub fn expression_value_program_snapshot_slot_count(
2005        &self,
2006        snapshot: &ExpressionValueProgramSnapshot,
2007    ) -> usize {
2008        let _ = self;
2009        snapshot.lowered_program.scratch_slots()
2010    }
2011
2012    /// Returns a tree-like diagnostic snapshot of the compiled expression for the evaluator's
2013    /// current active-amplitude mask.
2014    pub fn compiled_expression(&self) -> CompiledExpression {
2015        let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
2016        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
2017    }
2018
2019    /// Returns the expression represented by this evaluator.
2020    pub fn expression(&self) -> Expression {
2021        Expression {
2022            tree: self.expression.clone(),
2023            registry: self.registry.clone(),
2024        }
2025    }
2026    fn lowered_gradient_runtime_slot_count(&self) -> usize {
2027        self.lowered_runtime().gradient_program().scratch_slots()
2028    }
2029    fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
2030        self.lowered_runtime()
2031            .value_gradient_program()
2032            .scratch_slots()
2033    }
2034
2035    fn expression_value_slot_count(&self) -> usize {
2036        self.lowered_value_runtime_slot_count()
2037    }
2038    fn expression_gradient_slot_count(&self) -> usize {
2039        self.lowered_gradient_runtime_slot_count()
2040    }
2041    fn expression_value_gradient_slot_count(&self) -> usize {
2042        self.lowered_value_gradient_runtime_slot_count()
2043    }
2044
2045    #[doc(hidden)]
2046    pub fn expression_value_gradient_slot_count_public(&self) -> usize {
2047        self.expression_value_gradient_slot_count()
2048    }
2049    #[cfg(test)]
2050    fn specialization_cache_len(&self) -> usize {
2051        self.ir_planning.specialization_cache.read().len()
2052    }
2053    #[cfg(test)]
2054    fn lowered_artifact_cache_len(&self) -> usize {
2055        self.ir_planning.lowered_artifact_cache.read().len()
2056    }
2057    fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
2058        debug_assert!(Self::lowered_artifact_signature_matches(
2059            &specialization.lowered_artifacts,
2060            &specialization.cached_integrals.values,
2061        ));
2062        *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
2063        *self.ir_planning.active_lowered_artifacts.write() =
2064            Some(specialization.lowered_artifacts.clone());
2065        debug_assert_eq!(
2066            self.active_lowered_artifacts()
2067                .as_ref()
2068                .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
2069            Some(true)
2070        );
2071        debug_assert_eq!(
2072            self.lowered_runtime().value_program().scratch_slots(),
2073            specialization
2074                .lowered_artifacts
2075                .lowered_runtime
2076                .value_program()
2077                .scratch_slots()
2078        );
2079    }
2080    fn lower_expression_runtime_artifacts(
2081        expression_ir: &ir::ExpressionIR,
2082        values: &[PrecomputedCachedIntegral],
2083    ) -> LadduResult<LoweredArtifactCacheState> {
2084        let parameter_node_indices = values
2085            .iter()
2086            .map(|value| value.parameter_node_index)
2087            .collect();
2088        let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
2089        let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
2090        let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
2091        let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
2092            expression_ir,
2093        )
2094        .map_err(|error| {
2095            LadduError::Custom(format!(
2096                "Failed to lower expression runtime for specialized IR: {error:?}"
2097            ))
2098        })?;
2099        Ok(LoweredArtifactCacheState {
2100            parameter_node_indices,
2101            mul_node_indices,
2102            lowered_parameter_factors,
2103            residual_runtime,
2104            lowered_runtime,
2105        })
2106    }
2107    fn lowered_artifact_signature_matches(
2108        artifacts: &LoweredArtifactCacheState,
2109        values: &[PrecomputedCachedIntegral],
2110    ) -> bool {
2111        artifacts.parameter_node_indices.len() == values.len()
2112            && artifacts.mul_node_indices.len() == values.len()
2113            && artifacts
2114                .parameter_node_indices
2115                .iter()
2116                .copied()
2117                .eq(values.iter().map(|value| value.parameter_node_index))
2118            && artifacts
2119                .mul_node_indices
2120                .iter()
2121                .copied()
2122                .eq(values.iter().map(|value| value.mul_node_index))
2123    }
2124    fn build_expression_specialization(
2125        &self,
2126        resources: &Resources,
2127        key: CachedIntegralCacheKey,
2128    ) -> LadduResult<ExpressionSpecializationState> {
2129        let ir_compile_start = Instant::now();
2130        let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
2131        let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
2132        let cached_integrals_start = Instant::now();
2133        let values = Self::precompute_cached_integrals_at_load(
2134            &expression_ir,
2135            &self.amplitudes,
2136            &self.amplitude_use_sites,
2137            resources,
2138            &self.dataset,
2139            self.resources.read().n_free_parameters(),
2140        )?;
2141        let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
2142        let execution_sets = expression_ir.normalization_execution_sets().clone();
2143        let active_mask_key = resources.active.clone();
2144        let cached_lowered_artifacts = {
2145            let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
2146            lowered_artifact_cache
2147                .get(&active_mask_key)
2148                .cloned()
2149                .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
2150        };
2151        let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
2152            self.ir_planning
2153                .compile_metrics
2154                .write()
2155                .specialization_lowering_cache_hits += 1;
2156            artifacts
2157        } else {
2158            let lowering_start = Instant::now();
2159            let artifacts = Arc::new(
2160                Self::lower_expression_runtime_artifacts(&expression_ir, &values)
2161                    .expect("specialized lowered runtime should build"),
2162            );
2163            let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
2164            self.ir_planning
2165                .lowered_artifact_cache
2166                .write()
2167                .insert(active_mask_key, artifacts.clone());
2168            let mut compile_metrics = self.ir_planning.compile_metrics.write();
2169            compile_metrics.specialization_lowering_cache_misses += 1;
2170            compile_metrics.specialization_lowering_nanos += lowering_nanos;
2171            artifacts
2172        };
2173        let mut compile_metrics = self.ir_planning.compile_metrics.write();
2174        compile_metrics.specialization_cache_misses += 1;
2175        compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
2176        compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
2177        Ok(ExpressionSpecializationState {
2178            cached_integrals: Arc::new(CachedIntegralCacheState {
2179                key,
2180                expression_ir,
2181                values,
2182                execution_sets,
2183            }),
2184            lowered_artifacts,
2185        })
2186    }
    /// Returns the specialization matching the active mask in `resources` and this
    /// evaluator's dataset, installing it first when it is not already the active one.
    ///
    /// Lookup order:
    /// 1. the currently installed specialization, when its key matches;
    /// 2. the specialization cache (the entry is re-installed and counted as a cache hit);
    /// 3. a fresh build, which is inserted into the cache, installed and counted as a miss.
    ///
    /// # Errors
    /// Propagates any error from `build_expression_specialization`.
    ///
    /// # Panics
    /// Panics if the installed cached-integral state matches the key but no lowered artifacts
    /// are installed.
    fn ensure_expression_specialization(
        &self,
        resources: &Resources,
    ) -> LadduResult<ExpressionSpecializationState> {
        let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
        // Fast path: the installed specialization already corresponds to this key.
        if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
            if state.key == key {
                return Ok(ExpressionSpecializationState {
                    cached_integrals: state.clone(),
                    lowered_artifacts: self
                        .active_lowered_artifacts()
                        .expect("active lowered artifacts should exist for cached specialization"),
                });
            }
        }
        // The read guard is confined to this block so it is released before anything is
        // installed.
        let cached_specialization = {
            let specialization_cache = self.ir_planning.specialization_cache.read();
            specialization_cache.get(&key).cloned()
        };
        if let Some(specialization) = cached_specialization {
            // Cache hit: restore the stored specialization and record how long that took.
            let restore_start = Instant::now();
            self.ir_planning.specialization_metrics.write().cache_hits += 1;
            self.install_expression_specialization(&specialization);
            *self.ir_planning.specialization_status.write() =
                Some(ExpressionSpecializationStatus {
                    origin: ExpressionSpecializationOrigin::CacheHitRestore,
                });
            let restore_nanos = restore_start.elapsed().as_nanos() as u64;
            let mut compile_metrics = self.ir_planning.compile_metrics.write();
            compile_metrics.specialization_cache_hits += 1;
            compile_metrics.specialization_cache_restore_nanos += restore_nanos;
            return Ok(specialization);
        }
        // Cache miss: build, remember and install a new specialization.
        let specialization = self.build_expression_specialization(resources, key.clone())?;
        self.ir_planning.specialization_metrics.write().cache_misses += 1;
        self.ir_planning
            .specialization_cache
            .write()
            .insert(key, specialization.clone());
        self.install_expression_specialization(&specialization);
        // A build that leaves exactly one cache entry is reported as the initial load; any
        // other build is reported as a rebuild after a cache miss.
        let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
            ExpressionSpecializationOrigin::InitialLoad
        } else {
            ExpressionSpecializationOrigin::CacheMissRebuild
        };
        *self.ir_planning.specialization_status.write() =
            Some(ExpressionSpecializationStatus { origin });
        Ok(specialization)
    }
2236    fn rebuild_runtime_specializations(&self, resources: &Resources) {
2237        let _ = self.ensure_expression_specialization(resources);
2238    }
2239    fn refresh_runtime_specializations(&self) {
2240        let resources = self.resources.read();
2241        self.rebuild_runtime_specializations(&resources);
2242    }
2243    fn cached_integral_cache_key(
2244        active_mask: Vec<bool>,
2245        dataset: &Dataset,
2246    ) -> CachedIntegralCacheKey {
2247        let (weights_ptr, weights_local_len) = dataset.local_weight_cache_key();
2248        CachedIntegralCacheKey {
2249            active_mask,
2250            n_events_local: dataset.n_events_local(),
2251            weights_local_len,
2252            weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
2253            weights_ptr,
2254        }
2255    }
2256    fn precompute_cached_integrals_at_load(
2257        expression_ir: &ir::ExpressionIR,
2258        amplitudes: &[Box<dyn Amplitude>],
2259        amplitude_use_sites: &[AmplitudeUseSite],
2260        resources: &Resources,
2261        dataset: &Dataset,
2262        n_free_parameters: usize,
2263    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2264        let descriptors = expression_ir.cached_integral_descriptors();
2265        if descriptors.is_empty() {
2266            return Ok(Vec::new());
2267        }
2268        let execution_sets = expression_ir.normalization_execution_sets();
2269        let seed_parameters = vec![0.0; n_free_parameters];
2270        let parameters = resources.parameter_map.assemble(&seed_parameters)?;
2271        let mut amplitude_values = vec![Complex64::ZERO; amplitude_use_sites.len()];
2272        let mut compute_values = vec![Complex64::ZERO; amplitudes.len()];
2273        let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
2274        let active_set = resources.active_indices();
2275        let cache_active_indices = execution_sets
2276            .cached_cache_amplitudes
2277            .iter()
2278            .copied()
2279            .filter(|index| active_set.binary_search(index).is_ok())
2280            .collect::<Vec<_>>();
2281        let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
2282        for (cache, event) in resources.caches.iter().zip(dataset.weights_local().iter()) {
2283            amplitude_values.fill(Complex64::ZERO);
2284            compute_values.fill(Complex64::ZERO);
2285            let mut computed = vec![false; amplitudes.len()];
2286            for &use_site_idx in &cache_active_indices {
2287                let amp_idx = amplitude_use_sites[use_site_idx].amplitude_index;
2288                if !computed[amp_idx] {
2289                    compute_values[amp_idx] = amplitudes[amp_idx].compute(&parameters, cache);
2290                    computed[amp_idx] = true;
2291                }
2292                amplitude_values[use_site_idx] = compute_values[amp_idx];
2293            }
2294            expression_ir.evaluate_into(&amplitude_values, &mut value_slots);
2295            let weight = *event;
2296            for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
2297                weighted_cache_sums[descriptor_index] +=
2298                    value_slots[descriptor.cache_node_index] * weight;
2299            }
2300        }
2301        Ok(descriptors
2302            .iter()
2303            .zip(weighted_cache_sums)
2304            .map(
2305                |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
2306                    mul_node_index: descriptor.mul_node_index,
2307                    parameter_node_index: descriptor.parameter_node_index,
2308                    cache_node_index: descriptor.cache_node_index,
2309                    coefficient: descriptor.coefficient,
2310                    weighted_cache_sum,
2311                },
2312            )
2313            .collect())
2314    }
2315    fn lower_cached_parameter_factors(
2316        expression_ir: &ir::ExpressionIR,
2317    ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
2318        expression_ir
2319            .cached_integral_descriptors()
2320            .iter()
2321            .map(|descriptor| {
2322                lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
2323                    expression_ir,
2324                    descriptor.parameter_node_index,
2325                )
2326                .ok()
2327            })
2328            .collect()
2329    }
2330    fn lower_residual_runtime(
2331        expression_ir: &ir::ExpressionIR,
2332        descriptors: &[PrecomputedCachedIntegral],
2333    ) -> Option<lowered::LoweredExpressionRuntime> {
2334        let mut zeroed_nodes = vec![false; expression_ir.node_count()];
2335        for descriptor in descriptors {
2336            if descriptor.mul_node_index < zeroed_nodes.len() {
2337                zeroed_nodes[descriptor.mul_node_index] = true;
2338            }
2339        }
2340        lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
2341            expression_ir,
2342            &zeroed_nodes,
2343        )
2344        .ok()
2345    }
2346
2347    #[inline]
2348    fn fill_amplitude_values(
2349        &self,
2350        amplitude_values: &mut [Complex64],
2351        active_indices: &[usize],
2352        parameters: &Parameters,
2353        cache: &Cache,
2354    ) {
2355        amplitude_values.fill(Complex64::ZERO);
2356        let mut compute_values = vec![Complex64::ZERO; self.amplitudes.len()];
2357        let mut computed = vec![false; self.amplitudes.len()];
2358        for &use_site_idx in active_indices {
2359            let amp_idx = self.amplitude_use_sites[use_site_idx].amplitude_index;
2360            if !computed[amp_idx] {
2361                compute_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
2362                computed[amp_idx] = true;
2363            }
2364            amplitude_values[use_site_idx] = compute_values[amp_idx];
2365        }
2366    }
2367
2368    #[inline]
2369    fn fill_amplitude_gradients(
2370        &self,
2371        gradient_values: &mut [DVector<Complex64>],
2372        active_mask: &[bool],
2373        parameters: &Parameters,
2374        cache: &Cache,
2375    ) {
2376        let mut compute_gradients = vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
2377        let mut computed = vec![false; self.amplitudes.len()];
2378        for ((use_site, active), grad) in self
2379            .amplitude_use_sites
2380            .iter()
2381            .zip(active_mask.iter())
2382            .zip(gradient_values.iter_mut())
2383        {
2384            grad.fill(Complex64::ZERO);
2385            if *active {
2386                let amp_idx = use_site.amplitude_index;
2387                if !computed[amp_idx] {
2388                    self.amplitudes[amp_idx].compute_gradient(
2389                        parameters,
2390                        cache,
2391                        &mut compute_gradients[amp_idx],
2392                    );
2393                    computed[amp_idx] = true;
2394                }
2395                grad.copy_from(&compute_gradients[amp_idx]);
2396            }
2397        }
2398    }
2399
2400    #[inline]
2401    fn fill_amplitude_values_and_gradients(
2402        &self,
2403        amplitude_values: &mut [Complex64],
2404        gradient_values: &mut [DVector<Complex64>],
2405        active_indices: &[usize],
2406        active_mask: &[bool],
2407        parameters: &Parameters,
2408        cache: &Cache,
2409    ) {
2410        self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
2411        self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
2412    }
2413
2414    #[doc(hidden)]
2415    pub fn fill_amplitude_values_and_gradients_public(
2416        &self,
2417        amplitude_values: &mut [Complex64],
2418        gradient_values: &mut [DVector<Complex64>],
2419        active_indices: &[usize],
2420        active_mask: &[bool],
2421        parameters: &Parameters,
2422        cache: &Cache,
2423    ) {
2424        self.fill_amplitude_values_and_gradients(
2425            amplitude_values,
2426            gradient_values,
2427            active_indices,
2428            active_mask,
2429            parameters,
2430            cache,
2431        );
2432    }
2433
/// Evaluates the expression gradient for one cache entry using caller-owned scratch.
///
/// Amplitude values and gradients are filled first, then passed together
/// with the value/gradient slot buffers to
/// `evaluate_expression_gradient_with_scratch`.
#[cfg(feature = "execution-context-prototype")]
#[inline]
fn evaluate_cache_gradient_with_scratch(
    &self,
    amplitude_values: &mut [Complex64],
    gradient_values: &mut [DVector<Complex64>],
    value_slots: &mut [Complex64],
    gradient_slots: &mut [DVector<Complex64>],
    active_indices: &[usize],
    active_mask: &[bool],
    parameters: &Parameters,
    cache: &Cache,
) -> DVector<Complex64> {
    self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
    self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
    self.evaluate_expression_gradient_with_scratch(
        amplitude_values,
        gradient_values,
        value_slots,
        gradient_slots,
    )
}
2462
/// Evaluates the expression value and gradient for one cache entry using
/// caller-owned scratch.
///
/// Amplitude values and gradients are filled first, then passed together
/// with the value/gradient slot buffers to
/// `evaluate_expression_value_gradient_with_scratch`.
#[cfg(feature = "execution-context-prototype")]
#[allow(dead_code)]
#[inline]
fn evaluate_cache_value_gradient_with_scratch(
    &self,
    amplitude_values: &mut [Complex64],
    gradient_values: &mut [DVector<Complex64>],
    value_slots: &mut [Complex64],
    gradient_slots: &mut [DVector<Complex64>],
    active_indices: &[usize],
    active_mask: &[bool],
    parameters: &Parameters,
    cache: &Cache,
) -> (Complex64, DVector<Complex64>) {
    self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
    self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
    self.evaluate_expression_value_gradient_with_scratch(
        amplitude_values,
        gradient_values,
        value_slots,
        gradient_slots,
    )
}
2492
2493    pub fn expression_slot_count(&self) -> usize {
2494        self.lowered_runtime_slot_count()
2495    }
2496    fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
2497        let amplitude_dependencies = self
2498            .amplitude_use_sites
2499            .iter()
2500            .map(|use_site| {
2501                ir::DependenceClass::from(
2502                    self.amplitudes[use_site.amplitude_index].dependence_hint(),
2503                )
2504            })
2505            .collect::<Vec<_>>();
2506        let amplitude_realness = self
2507            .amplitude_use_sites
2508            .iter()
2509            .map(|use_site| self.amplitudes[use_site.amplitude_index].real_valued_hint())
2510            .collect::<Vec<_>>();
2511        ir::compile_expression_ir_with_real_hints(
2512            &self.expression,
2513            active_mask,
2514            &amplitude_dependencies,
2515            &amplitude_realness,
2516        )
2517    }
2518    fn lower_expression_runtime_for_active_mask(
2519        &self,
2520        active_mask: &[bool],
2521    ) -> LadduResult<lowered::LoweredExpressionRuntime> {
2522        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
2523        lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
2524            LadduError::Custom(format!(
2525                "Failed to lower active-mask runtime specialization: {error:?}"
2526            ))
2527        })
2528    }
2529    fn ensure_cached_integral_cache_state(
2530        &self,
2531        resources: &Resources,
2532    ) -> LadduResult<Arc<CachedIntegralCacheState>> {
2533        Ok(self
2534            .ensure_expression_specialization(resources)?
2535            .cached_integrals)
2536    }
2537
2538    fn evaluate_expression_runtime_value_with_scratch(
2539        &self,
2540        amplitude_values: &[Complex64],
2541        scratch: &mut [Complex64],
2542    ) -> Complex64 {
2543        let lowered_runtime = self.lowered_runtime();
2544        lowered_runtime
2545            .value_program()
2546            .evaluate_into(amplitude_values, scratch)
2547    }
2548
2549    #[doc(hidden)]
2550    pub fn evaluate_expression_value_with_program_snapshot(
2551        &self,
2552        program_snapshot: &ExpressionValueProgramSnapshot,
2553        amplitude_values: &[Complex64],
2554        scratch: &mut [Complex64],
2555    ) -> Complex64 {
2556        program_snapshot
2557            .lowered_program
2558            .evaluate_into(amplitude_values, scratch)
2559    }
2560
2561    fn evaluate_expression_runtime_gradient_with_scratch(
2562        &self,
2563        amplitude_values: &[Complex64],
2564        gradient_values: &[DVector<Complex64>],
2565        value_scratch: &mut [Complex64],
2566        gradient_scratch: &mut [DVector<Complex64>],
2567    ) -> DVector<Complex64> {
2568        let lowered_runtime = self.lowered_runtime();
2569        lowered_runtime.gradient_program().evaluate_gradient_into(
2570            amplitude_values,
2571            gradient_values,
2572            value_scratch,
2573            gradient_scratch,
2574        )
2575    }
2576
2577    fn evaluate_expression_runtime_value_gradient_with_scratch(
2578        &self,
2579        amplitude_values: &[Complex64],
2580        gradient_values: &[DVector<Complex64>],
2581        value_scratch: &mut [Complex64],
2582        gradient_scratch: &mut [DVector<Complex64>],
2583    ) -> (Complex64, DVector<Complex64>) {
2584        let lowered_runtime = self.lowered_runtime();
2585        lowered_runtime
2586            .value_gradient_program()
2587            .evaluate_value_gradient_into(
2588                amplitude_values,
2589                gradient_values,
2590                value_scratch,
2591                gradient_scratch,
2592            )
2593    }
2594
2595    fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
2596        let lowered_runtime = self.lowered_runtime();
2597        let program = lowered_runtime.value_program();
2598        let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
2599        program.evaluate_into(amplitude_values, &mut scratch)
2600    }
2601
2602    fn evaluate_expression_runtime_gradient(
2603        &self,
2604        amplitude_values: &[Complex64],
2605        gradient_values: &[DVector<Complex64>],
2606    ) -> DVector<Complex64> {
2607        let lowered_runtime = self.lowered_runtime();
2608        let program = lowered_runtime.gradient_program();
2609        let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
2610        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
2611        let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
2612        program.evaluate_gradient_into_flat(
2613            amplitude_values,
2614            gradient_values,
2615            &mut value_scratch,
2616            &mut gradient_scratch,
2617            grad_dim,
2618        )
2619    }
2620    /// Dependence classification for the compiled expression root.
2621    pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
2622        let resources = self.resources.read();
2623        Ok(self
2624            .ensure_cached_integral_cache_state(&resources)?
2625            .expression_ir
2626            .root_dependence()
2627            .into())
2628    }
2629    /// Dependence classification for each compiled expression node.
2630    pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
2631        let resources = self.resources.read();
2632        Ok(self
2633            .ensure_cached_integral_cache_state(&resources)?
2634            .expression_ir
2635            .node_dependence_annotations()
2636            .iter()
2637            .copied()
2638            .map(Into::into)
2639            .collect())
2640    }
2641    /// Warning-level diagnostics for potentially inconsistent dependence hints.
2642    pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
2643        let resources = self.resources.read();
2644        Ok(self
2645            .ensure_cached_integral_cache_state(&resources)?
2646            .expression_ir
2647            .dependence_warnings()
2648            .to_vec())
2649    }
2650    /// Explain/debug view of IR normalization planning decomposition.
2651    pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
2652        let resources = self.resources.read();
2653        Ok(self
2654            .ensure_cached_integral_cache_state(&resources)?
2655            .expression_ir
2656            .normalization_plan_explain()
2657            .into())
2658    }
2659    /// Explain/debug view of amplitude execution sets used by normalization evaluation.
2660    pub fn expression_normalization_execution_sets(
2661        &self,
2662    ) -> LadduResult<NormalizationExecutionSetsExplain> {
2663        let resources = self.resources.read();
2664        Ok(self
2665            .ensure_cached_integral_cache_state(&resources)?
2666            .execution_sets
2667            .clone()
2668            .into())
2669    }
2670    /// Cached integral terms precomputed at evaluator load.
2671    pub fn expression_precomputed_cached_integrals(
2672        &self,
2673    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2674        let resources = self.resources.read();
2675        Ok(self
2676            .ensure_cached_integral_cache_state(&resources)?
2677            .values
2678            .clone())
2679    }
2680    /// Derivative rules for cached separable terms evaluated at the given parameter point.
2681    ///
2682    /// Each returned term corresponds to a cached separable descriptor and contributes
2683    /// `weighted_gradient` to `d(normalization)/dp` prior to residual-term combination.
2684    pub fn expression_precomputed_cached_integral_gradient_terms(
2685        &self,
2686        parameters: &[f64],
2687    ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
2688        let resources = self.resources.read();
2689        let state = self.ensure_cached_integral_cache_state(&resources)?;
2690        if state.values.is_empty() {
2691            return Ok(Vec::new());
2692        }
2693
2694        let Some(cache) = resources.caches.first() else {
2695            return Ok(state
2696                .values
2697                .iter()
2698                .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
2699                    mul_node_index: descriptor.mul_node_index,
2700                    parameter_node_index: descriptor.parameter_node_index,
2701                    cache_node_index: descriptor.cache_node_index,
2702                    coefficient: descriptor.coefficient,
2703                    weighted_gradient: DVector::zeros(parameters.len()),
2704                })
2705                .collect());
2706        };
2707
2708        let parameter_values = resources.parameter_map.assemble(parameters)?;
2709        let mut amplitude_values = vec![Complex64::ZERO; self.amplitude_use_sites.len()];
2710        self.fill_amplitude_values(
2711            &mut amplitude_values,
2712            resources.active_indices(),
2713            &parameter_values,
2714            cache,
2715        );
2716        let mut amplitude_gradients = (0..self.amplitude_use_sites.len())
2717            .map(|_| DVector::zeros(parameters.len()))
2718            .collect::<Vec<_>>();
2719        self.fill_amplitude_gradients(
2720            &mut amplitude_gradients,
2721            &resources.active,
2722            &parameter_values,
2723            cache,
2724        );
2725        let lowered_artifacts = self.active_lowered_artifacts();
2726        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2727        let mut gradient_slots = (0..state.expression_ir.node_count())
2728            .map(|_| DVector::zeros(parameters.len()))
2729            .collect::<Vec<_>>();
2730        let max_lowered_slots = lowered_artifacts
2731            .as_ref()
2732            .map(|artifacts| {
2733                artifacts
2734                    .lowered_parameter_factors
2735                    .iter()
2736                    .filter_map(|runtime| {
2737                        runtime
2738                            .as_ref()
2739                            .and_then(|runtime| runtime.gradient_program())
2740                            .map(|program| program.scratch_slots())
2741                    })
2742                    .max()
2743                    .unwrap_or(0)
2744            })
2745            .unwrap_or(0);
2746        let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
2747        let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
2748        let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
2749            artifacts.lowered_parameter_factors.len() == state.values.len()
2750                && artifacts.lowered_parameter_factors.iter().all(|runtime| {
2751                    runtime
2752                        .as_ref()
2753                        .and_then(|runtime| runtime.gradient_program())
2754                        .is_some()
2755                })
2756        });
2757
2758        if !use_lowered {
2759            let _ = state.expression_ir.evaluate_gradient_into(
2760                &amplitude_values,
2761                &amplitude_gradients,
2762                &mut value_slots,
2763                &mut gradient_slots,
2764            );
2765        }
2766
2767        if use_lowered {
2768            let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
2769            Ok(state
2770                .values
2771                .iter()
2772                .cloned()
2773                .zip(lowered_artifacts.lowered_parameter_factors.iter())
2774                .map(|(descriptor, runtime)| {
2775                    let parameter_gradient = runtime
2776                        .as_ref()
2777                        .and_then(|runtime| runtime.gradient_program())
2778                        .map(|program| {
2779                            program.evaluate_gradient_into(
2780                                &amplitude_values,
2781                                &amplitude_gradients,
2782                                &mut lowered_value_slots[..program.scratch_slots()],
2783                                &mut lowered_gradient_slots[..program.scratch_slots()],
2784                            )
2785                        })
2786                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
2787                    let weighted_gradient = parameter_gradient.map(|value| {
2788                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2789                    });
2790                    PrecomputedCachedIntegralGradientTerm {
2791                        mul_node_index: descriptor.mul_node_index,
2792                        parameter_node_index: descriptor.parameter_node_index,
2793                        cache_node_index: descriptor.cache_node_index,
2794                        coefficient: descriptor.coefficient,
2795                        weighted_gradient,
2796                    }
2797                })
2798                .collect())
2799        } else {
2800            Ok(state
2801                .values
2802                .iter()
2803                .map(|descriptor| {
2804                    let parameter_gradient = gradient_slots
2805                        .get(descriptor.parameter_node_index)
2806                        .cloned()
2807                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
2808                    let weighted_gradient = parameter_gradient.map(|value| {
2809                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2810                    });
2811                    PrecomputedCachedIntegralGradientTerm {
2812                        mul_node_index: descriptor.mul_node_index,
2813                        parameter_node_index: descriptor.parameter_node_index,
2814                        cache_node_index: descriptor.cache_node_index,
2815                        coefficient: descriptor.coefficient,
2816                        weighted_gradient,
2817                    }
2818                })
2819                .collect())
2820        }
2821    }
2822    fn evaluate_cached_weighted_value_sum_ir(
2823        &self,
2824        state: &CachedIntegralCacheState,
2825        amplitude_values: &[Complex64],
2826    ) -> f64 {
2827        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2828        let _ = state
2829            .expression_ir
2830            .evaluate_into(amplitude_values, &mut value_slots);
2831        state
2832            .values
2833            .iter()
2834            .map(|descriptor| {
2835                let parameter_factor = value_slots[descriptor.parameter_node_index];
2836                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2837                    .re
2838            })
2839            .sum()
2840    }
2841    fn evaluate_cached_weighted_value_sum_lowered(
2842        &self,
2843        state: &CachedIntegralCacheState,
2844        lowered_artifacts: &LoweredArtifactCacheState,
2845        amplitude_values: &[Complex64],
2846    ) -> Option<f64> {
2847        let max_slots = lowered_artifacts
2848            .lowered_parameter_factors
2849            .iter()
2850            .filter_map(|runtime| {
2851                runtime
2852                    .as_ref()
2853                    .and_then(|runtime| runtime.value_program())
2854                    .map(|program| program.scratch_slots())
2855            })
2856            .max()
2857            .unwrap_or(0);
2858        let mut value_slots = vec![Complex64::ZERO; max_slots];
2859        let mut total = 0.0;
2860        for (descriptor, runtime) in state
2861            .values
2862            .iter()
2863            .zip(lowered_artifacts.lowered_parameter_factors.iter())
2864        {
2865            let parameter_factor = runtime
2866                .as_ref()
2867                .and_then(|runtime| runtime.value_program())
2868                .map(|program| {
2869                    program.evaluate_into(
2870                        amplitude_values,
2871                        &mut value_slots[..program.scratch_slots()],
2872                    )
2873                })?;
2874            total +=
2875                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2876                    .re;
2877        }
2878        Some(total)
2879    }
2880    fn evaluate_cached_weighted_gradient_sum_ir(
2881        &self,
2882        state: &CachedIntegralCacheState,
2883        amplitude_values: &[Complex64],
2884        amplitude_gradients: &[DVector<Complex64>],
2885        grad_dim: usize,
2886    ) -> DVector<f64> {
2887        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2888        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2889        let _ = state.expression_ir.evaluate_gradient_into(
2890            amplitude_values,
2891            amplitude_gradients,
2892            &mut value_slots,
2893            &mut gradient_slots,
2894        );
2895        state
2896            .values
2897            .iter()
2898            .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
2899                let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
2900                let coefficient = descriptor.coefficient as f64;
2901                for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
2902                    *accum_item +=
2903                        (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2904                }
2905                accum
2906            })
2907    }
2908    fn evaluate_cached_weighted_gradient_sum_lowered(
2909        &self,
2910        state: &CachedIntegralCacheState,
2911        lowered_artifacts: &LoweredArtifactCacheState,
2912        amplitude_values: &[Complex64],
2913        amplitude_gradients: &[DVector<Complex64>],
2914        grad_dim: usize,
2915    ) -> Option<DVector<f64>> {
2916        let max_value_slots = lowered_artifacts
2917            .lowered_parameter_factors
2918            .iter()
2919            .filter_map(|runtime| {
2920                runtime
2921                    .as_ref()
2922                    .and_then(|runtime| runtime.gradient_program())
2923                    .map(|program| program.scratch_slots())
2924            })
2925            .max()
2926            .unwrap_or(0);
2927        let mut value_slots = vec![Complex64::ZERO; max_value_slots];
2928        let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
2929        let mut total = DVector::zeros(grad_dim);
2930        for (descriptor, runtime) in state
2931            .values
2932            .iter()
2933            .zip(lowered_artifacts.lowered_parameter_factors.iter())
2934        {
2935            let parameter_gradient = runtime
2936                .as_ref()
2937                .and_then(|runtime| runtime.gradient_program())
2938                .map(|program| {
2939                    program.evaluate_gradient_into_flat(
2940                        amplitude_values,
2941                        amplitude_gradients,
2942                        &mut value_slots[..program.scratch_slots()],
2943                        &mut gradient_slots[..program.scratch_slots() * grad_dim],
2944                        grad_dim,
2945                    )
2946                })?;
2947            let coefficient = descriptor.coefficient as f64;
2948            for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
2949                *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2950            }
2951        }
2952        Some(total)
2953    }
2954    fn evaluate_residual_value_ir(
2955        &self,
2956        state: &CachedIntegralCacheState,
2957        amplitude_values: &[Complex64],
2958    ) -> Complex64 {
2959        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2960        for descriptor in &state.values {
2961            if descriptor.mul_node_index < zeroed_nodes.len() {
2962                zeroed_nodes[descriptor.mul_node_index] = true;
2963            }
2964        }
2965        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2966        state.expression_ir.evaluate_into_with_zeroed_nodes(
2967            amplitude_values,
2968            &mut value_slots,
2969            &zeroed_nodes,
2970        )
2971    }
2972    fn evaluate_residual_gradient_ir(
2973        &self,
2974        state: &CachedIntegralCacheState,
2975        amplitude_values: &[Complex64],
2976        amplitude_gradients: &[DVector<Complex64>],
2977        grad_dim: usize,
2978    ) -> DVector<Complex64> {
2979        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2980        for descriptor in &state.values {
2981            if descriptor.mul_node_index < zeroed_nodes.len() {
2982                zeroed_nodes[descriptor.mul_node_index] = true;
2983            }
2984        }
2985        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2986        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2987        state
2988            .expression_ir
2989            .evaluate_gradient_into_with_zeroed_nodes(
2990                amplitude_values,
2991                amplitude_gradients,
2992                &mut value_slots,
2993                &mut gradient_slots,
2994                &zeroed_nodes,
2995            )
2996    }
2997
2998    fn evaluate_weighted_value_sum_local_components(
2999        &self,
3000        parameters: &[f64],
3001    ) -> LadduResult<(f64, f64)> {
3002        let resources = self.resources.read();
3003        let parameters = resources.parameter_map.assemble(parameters)?;
3004        let amplitude_len = self.amplitude_use_sites.len();
3005        let state = self.ensure_cached_integral_cache_state(&resources)?;
3006        let lowered_artifacts = self.active_lowered_artifacts();
3007        let residual_value_slot_count = lowered_artifacts
3008            .as_ref()
3009            .and_then(|artifacts| {
3010                artifacts
3011                    .residual_runtime
3012                    .as_ref()
3013                    .map(|runtime| runtime.value_program())
3014                    .map(|program| program.scratch_slots())
3015            })
3016            .unwrap_or_else(|| self.expression_slot_count());
3017        let residual_value_program = lowered_artifacts
3018            .as_ref()
3019            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
3020            .map(|runtime| runtime.value_program());
3021        let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
3022        let residual_active_indices = &state.execution_sets.residual_amplitudes;
3023        debug_assert!(cached_parameter_indices.iter().all(|&index| resources
3024            .active
3025            .get(index)
3026            .copied()
3027            .unwrap_or(false)));
3028        debug_assert!(residual_active_indices.iter().all(|&index| resources
3029            .active
3030            .get(index)
3031            .copied()
3032            .unwrap_or(false)));
3033        let cached_value_sum = {
3034            if let Some(cache) = resources.caches.first() {
3035                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3036                self.fill_amplitude_values(
3037                    &mut amplitude_values,
3038                    cached_parameter_indices,
3039                    &parameters,
3040                    cache,
3041                );
3042                lowered_artifacts
3043                    .as_ref()
3044                    .and_then(|artifacts| {
3045                        self.evaluate_cached_weighted_value_sum_lowered(
3046                            &state,
3047                            artifacts,
3048                            &amplitude_values,
3049                        )
3050                    })
3051                    .unwrap_or_else(|| {
3052                        self.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values)
3053                    })
3054            } else {
3055                0.0
3056            }
3057        };
3058
3059        #[cfg(feature = "rayon")]
3060        let residual_sum: f64 = {
3061            resources
3062                .caches
3063                .par_iter()
3064                .zip(self.dataset.weights_local().par_iter())
3065                .map_init(
3066                    || {
3067                        (
3068                            vec![Complex64::ZERO; amplitude_len],
3069                            vec![Complex64::ZERO; residual_value_slot_count],
3070                        )
3071                    },
3072                    |(amplitude_values, value_slots), (cache, event)| {
3073                        self.fill_amplitude_values(
3074                            amplitude_values,
3075                            residual_active_indices,
3076                            &parameters,
3077                            cache,
3078                        );
3079                        {
3080                            let value = residual_value_program
3081                                .as_ref()
3082                                .map(|program| {
3083                                    program.evaluate_into(
3084                                        amplitude_values,
3085                                        &mut value_slots[..program.scratch_slots()],
3086                                    )
3087                                })
3088                                .unwrap_or_else(|| {
3089                                    self.evaluate_residual_value_ir(&state, amplitude_values)
3090                                });
3091                            *event * value.re
3092                        }
3093                    },
3094                )
3095                .sum()
3096        };
3097
3098        #[cfg(not(feature = "rayon"))]
3099        let residual_sum: f64 = {
3100            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3101            let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
3102            resources
3103                .caches
3104                .iter()
3105                .zip(self.dataset.weights_local().iter())
3106                .map(|(cache, event)| {
3107                    self.fill_amplitude_values(
3108                        &mut amplitude_values,
3109                        &residual_active_indices,
3110                        &parameters,
3111                        cache,
3112                    );
3113                    {
3114                        let value = residual_value_program
3115                            .as_ref()
3116                            .map(|program| {
3117                                program.evaluate_into(
3118                                    &amplitude_values,
3119                                    &mut value_slots[..program.scratch_slots()],
3120                                )
3121                            })
3122                            .unwrap_or_else(|| {
3123                                self.evaluate_residual_value_ir(&state, &amplitude_values)
3124                            });
3125                        *event * value.re
3126                    }
3127                })
3128                .sum()
3129        };
3130        Ok((residual_sum, cached_value_sum))
3131    }
3132
3133    /// Weighted sum over local events of the real expression value.
3134    ///
3135    /// This returns `sum_e(weight_e * Re(L_e))`.
3136    pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
3137        let (residual_sum, cached_value_sum) =
3138            self.evaluate_weighted_value_sum_local_components(parameters)?;
3139        Ok(residual_sum + cached_value_sum)
3140    }
3141
3142    #[cfg(feature = "mpi")]
3143    /// Weighted sum over all ranks of the real expression value.
3144    ///
3145    /// This returns `sum_{r,e}(weight_{r,e} * Re(L_{r,e}))`.
3146    pub fn evaluate_weighted_value_sum_mpi(
3147        &self,
3148        parameters: &[f64],
3149        world: &SimpleCommunicator,
3150    ) -> LadduResult<f64> {
3151        let (residual_sum_local, cached_value_sum_local) =
3152            self.evaluate_weighted_value_sum_local_components(parameters)?;
3153        let mut residual_sum = 0.0;
3154        world.all_reduce_into(
3155            &residual_sum_local,
3156            &mut residual_sum,
3157            mpi::collective::SystemOperation::sum(),
3158        );
3159        let mut cached_value_sum = 0.0;
3160        world.all_reduce_into(
3161            &cached_value_sum_local,
3162            &mut cached_value_sum,
3163            mpi::collective::SystemOperation::sum(),
3164        );
3165        Ok(residual_sum + cached_value_sum)
3166    }
3167
3168    /// Weighted sum over local events of the real gradient of the expression.
3169    ///
3170    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
3171    fn evaluate_weighted_gradient_sum_local_components(
3172        &self,
3173        parameters: &[f64],
3174    ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
3175        let resources = self.resources.read();
3176        let parameters = resources.parameter_map.assemble(parameters)?;
3177        let amplitude_len = self.amplitude_use_sites.len();
3178        let grad_dim = parameters.len();
3179        let state = self.ensure_cached_integral_cache_state(&resources)?;
3180        let lowered_artifacts = self.active_lowered_artifacts();
3181        let active_index_set = resources.active_indices();
3182        let cached_parameter_indices = state
3183            .execution_sets
3184            .cached_parameter_amplitudes
3185            .iter()
3186            .copied()
3187            .filter(|index| active_index_set.binary_search(index).is_ok())
3188            .collect::<Vec<_>>();
3189        let residual_active_indices = state
3190            .execution_sets
3191            .residual_amplitudes
3192            .iter()
3193            .copied()
3194            .filter(|index| active_index_set.binary_search(index).is_ok())
3195            .collect::<Vec<_>>();
3196        let mut cached_parameter_mask = vec![false; amplitude_len];
3197        for &index in &cached_parameter_indices {
3198            cached_parameter_mask[index] = true;
3199        }
3200        let mut residual_active_mask = vec![false; amplitude_len];
3201        for &index in &residual_active_indices {
3202            residual_active_mask[index] = true;
3203        }
3204        let residual_gradient_program = lowered_artifacts
3205            .as_ref()
3206            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
3207            .map(|runtime| runtime.gradient_program());
3208        let residual_gradient_slot_count = residual_gradient_program
3209            .as_ref()
3210            .map(|program| program.scratch_slots())
3211            .unwrap_or_else(|| state.expression_ir.node_count());
3212        let cached_term_sum = {
3213            if let Some(cache) = resources.caches.first() {
3214                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3215                self.fill_amplitude_values(
3216                    &mut amplitude_values,
3217                    &cached_parameter_indices,
3218                    &parameters,
3219                    cache,
3220                );
3221                let mut amplitude_gradients = (0..amplitude_len)
3222                    .map(|_| DVector::zeros(grad_dim))
3223                    .collect::<Vec<_>>();
3224                self.fill_amplitude_gradients(
3225                    &mut amplitude_gradients,
3226                    &cached_parameter_mask,
3227                    &parameters,
3228                    cache,
3229                );
3230                lowered_artifacts
3231                    .as_ref()
3232                    .and_then(|artifacts| {
3233                        self.evaluate_cached_weighted_gradient_sum_lowered(
3234                            &state,
3235                            artifacts,
3236                            &amplitude_values,
3237                            &amplitude_gradients,
3238                            grad_dim,
3239                        )
3240                    })
3241                    .unwrap_or_else(|| {
3242                        self.evaluate_cached_weighted_gradient_sum_ir(
3243                            &state,
3244                            &amplitude_values,
3245                            &amplitude_gradients,
3246                            grad_dim,
3247                        )
3248                    })
3249            } else {
3250                DVector::zeros(grad_dim)
3251            }
3252        };
3253
3254        #[cfg(feature = "rayon")]
3255        let residual_sum = {
3256            resources
3257                .caches
3258                .par_iter()
3259                .zip(self.dataset.weights_local().par_iter())
3260                .map_init(
3261                    || {
3262                        (
3263                            vec![Complex64::ZERO; amplitude_len],
3264                            vec![DVector::zeros(grad_dim); amplitude_len],
3265                            vec![Complex64::ZERO; residual_gradient_slot_count],
3266                            vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
3267                        )
3268                    },
3269                    |(amplitude_values, gradient_values, value_slots, gradient_slots),
3270                     (cache, event)| {
3271                        self.fill_amplitude_values_and_gradients(
3272                            amplitude_values,
3273                            gradient_values,
3274                            &residual_active_indices,
3275                            &residual_active_mask,
3276                            &parameters,
3277                            cache,
3278                        );
3279                        let gradient = residual_gradient_program
3280                            .as_ref()
3281                            .map(|program| {
3282                                program.evaluate_gradient_into_flat(
3283                                    amplitude_values,
3284                                    gradient_values,
3285                                    value_slots,
3286                                    gradient_slots,
3287                                    grad_dim,
3288                                )
3289                            })
3290                            .unwrap_or_else(|| {
3291                                self.evaluate_residual_gradient_ir(
3292                                    &state,
3293                                    amplitude_values,
3294                                    gradient_values,
3295                                    grad_dim,
3296                                )
3297                            });
3298                        gradient.map(|value| value.re).scale(*event)
3299                    },
3300                )
3301                .reduce(
3302                    || DVector::zeros(grad_dim),
3303                    |mut accum, value| {
3304                        accum += value;
3305                        accum
3306                    },
3307                )
3308        };
3309
3310        #[cfg(not(feature = "rayon"))]
3311        let residual_sum = {
3312            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3313            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3314            let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
3315            let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
3316            resources
3317                .caches
3318                .iter()
3319                .zip(self.dataset.weights_local().iter())
3320                .map(|(cache, event)| {
3321                    self.fill_amplitude_values_and_gradients(
3322                        &mut amplitude_values,
3323                        &mut gradient_values,
3324                        &residual_active_indices,
3325                        &residual_active_mask,
3326                        &parameters,
3327                        cache,
3328                    );
3329                    let gradient = residual_gradient_program
3330                        .as_ref()
3331                        .map(|program| {
3332                            program.evaluate_gradient_into_flat(
3333                                &amplitude_values,
3334                                &gradient_values,
3335                                &mut value_slots,
3336                                &mut gradient_slots,
3337                                grad_dim,
3338                            )
3339                        })
3340                        .unwrap_or_else(|| {
3341                            self.evaluate_residual_gradient_ir(
3342                                &state,
3343                                &amplitude_values,
3344                                &gradient_values,
3345                                grad_dim,
3346                            )
3347                        });
3348                    gradient.map(|value| value.re).scale(*event)
3349                })
3350                .sum()
3351        };
3352        Ok((residual_sum, cached_term_sum))
3353    }
3354
3355    /// Weighted sum over local events of the real gradient of the expression.
3356    ///
3357    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
3358    pub fn evaluate_weighted_gradient_sum_local(
3359        &self,
3360        parameters: &[f64],
3361    ) -> LadduResult<DVector<f64>> {
3362        let (residual_sum, cached_term_sum) =
3363            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
3364        Ok(residual_sum + cached_term_sum)
3365    }
3366
3367    #[cfg(feature = "mpi")]
3368    /// Weighted sum over all ranks of the real gradient of the expression.
3369    ///
3370    /// This returns `sum_{r,e}(weight_{r,e} * Re(dL_{r,e}/dp))`.
3371    pub fn evaluate_weighted_gradient_sum_mpi(
3372        &self,
3373        parameters: &[f64],
3374        world: &SimpleCommunicator,
3375    ) -> LadduResult<DVector<f64>> {
3376        let (residual_sum_local, cached_term_sum_local) =
3377            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
3378        let mut residual_sum = vec![0.0; residual_sum_local.len()];
3379        world.all_reduce_into(
3380            residual_sum_local.as_slice(),
3381            &mut residual_sum,
3382            mpi::collective::SystemOperation::sum(),
3383        );
3384        let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
3385        world.all_reduce_into(
3386            cached_term_sum_local.as_slice(),
3387            &mut cached_term_sum,
3388            mpi::collective::SystemOperation::sum(),
3389        );
3390        let mut total = DVector::from_vec(residual_sum);
3391        total += DVector::from_vec(cached_term_sum);
3392        Ok(total)
3393    }
3394
3395    pub fn evaluate_expression_value_with_scratch(
3396        &self,
3397        amplitude_values: &[Complex64],
3398        scratch: &mut [Complex64],
3399    ) -> Complex64 {
3400        self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
3401    }
3402
3403    pub fn evaluate_expression_gradient_with_scratch(
3404        &self,
3405        amplitude_values: &[Complex64],
3406        gradient_values: &[DVector<Complex64>],
3407        value_scratch: &mut [Complex64],
3408        gradient_scratch: &mut [DVector<Complex64>],
3409    ) -> DVector<Complex64> {
3410        self.evaluate_expression_runtime_gradient_with_scratch(
3411            amplitude_values,
3412            gradient_values,
3413            value_scratch,
3414            gradient_scratch,
3415        )
3416    }
3417
3418    pub fn evaluate_expression_value_gradient_with_scratch(
3419        &self,
3420        amplitude_values: &[Complex64],
3421        gradient_values: &[DVector<Complex64>],
3422        value_scratch: &mut [Complex64],
3423        gradient_scratch: &mut [DVector<Complex64>],
3424    ) -> (Complex64, DVector<Complex64>) {
3425        self.evaluate_expression_runtime_value_gradient_with_scratch(
3426            amplitude_values,
3427            gradient_values,
3428            value_scratch,
3429            gradient_scratch,
3430        )
3431    }
3432
3433    pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
3434        self.evaluate_expression_runtime_value(amplitude_values)
3435    }
3436
3437    pub fn evaluate_expression_gradient(
3438        &self,
3439        amplitude_values: &[Complex64],
3440        gradient_values: &[DVector<Complex64>],
3441    ) -> DVector<Complex64> {
3442        self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
3443    }
3444
3445    /// Get the parameters used by this evaluator.
3446    pub fn parameters(&self) -> ParameterMap {
3447        self.resources.read().parameters()
3448    }
3449
3450    /// Number of free parameters.
3451    pub fn n_free(&self) -> usize {
3452        self.resources.read().n_free_parameters()
3453    }
3454
3455    /// Number of fixed parameters.
3456    pub fn n_fixed(&self) -> usize {
3457        self.resources.read().n_fixed_parameters()
3458    }
3459
3460    /// Total number of parameters.
3461    pub fn n_parameters(&self) -> usize {
3462        self.resources.read().n_parameters()
3463    }
3464
3465    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
3466        self.resources.read().fix_parameter(name, value)
3467    }
3468
3469    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
3470        self.resources.read().free_parameter(name)
3471    }
3472
3473    pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
3474        self.resources.write().rename_parameter(old, new)
3475    }
3476
3477    pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
3478        self.resources.write().rename_parameters(mapping)
3479    }
3480
3481    /// Activate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
3482    pub fn activate<T: AsRef<str>>(&self, name: T) {
3483        self.resources.write().activate(name);
3484        self.refresh_runtime_specializations();
3485    }
3486    /// Activate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
3487    pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
3488        self.resources.write().activate_strict(name)?;
3489        self.refresh_runtime_specializations();
3490        Ok(())
3491    }
3492
3493    /// Activate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
3494    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
3495        self.resources.write().activate_many(names);
3496        self.refresh_runtime_specializations();
3497    }
3498    /// Activate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
3499    pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
3500        self.resources.write().activate_many_strict(names)?;
3501        self.refresh_runtime_specializations();
3502        Ok(())
3503    }
3504
3505    /// Activate all registered [`Amplitude`]s.
3506    pub fn activate_all(&self) {
3507        self.resources.write().activate_all();
3508        self.refresh_runtime_specializations();
3509    }
3510
3511    /// Deactivate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
3512    pub fn deactivate<T: AsRef<str>>(&self, name: T) {
3513        self.resources.write().deactivate(name);
3514        self.refresh_runtime_specializations();
3515    }
3516
3517    /// Deactivate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
3518    pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
3519        self.resources.write().deactivate_strict(name)?;
3520        self.refresh_runtime_specializations();
3521        Ok(())
3522    }
3523
3524    /// Deactivate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
3525    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
3526        self.resources.write().deactivate_many(names);
3527        self.refresh_runtime_specializations();
3528    }
3529    /// Deactivate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
3530    pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
3531        self.resources.write().deactivate_many_strict(names)?;
3532        self.refresh_runtime_specializations();
3533        Ok(())
3534    }
3535
3536    /// Deactivate all tagged [`Amplitude`] use-sites.
3537    pub fn deactivate_all(&self) {
3538        self.resources.write().deactivate_all();
3539        self.refresh_runtime_specializations();
3540    }
3541
3542    /// Isolate [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
3543    pub fn isolate<T: AsRef<str>>(&self, name: T) {
3544        self.resources.write().isolate(name);
3545        self.refresh_runtime_specializations();
3546    }
3547
3548    /// Isolate [`Amplitude`] use-sites by tag or glob selector and return an error if no use-site matches.
3549    pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
3550        self.resources.write().isolate_strict(name)?;
3551        self.refresh_runtime_specializations();
3552        Ok(())
3553    }
3554
3555    /// Isolate several [`Amplitude`] use-sites by tag or glob selector, skipping missing selectors.
3556    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
3557        self.resources.write().isolate_many(names);
3558        self.refresh_runtime_specializations();
3559    }
3560
3561    /// Isolate several [`Amplitude`] use-sites by tag or glob selector and return an error if any selector has no matches.
3562    pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
3563        self.resources.write().isolate_many_strict(names)?;
3564        self.refresh_runtime_specializations();
3565        Ok(())
3566    }
3567
3568    /// Return a copy of the current active-amplitude mask.
3569    pub fn active_mask(&self) -> Vec<bool> {
3570        self.resources.read().active.clone()
3571    }
3572
3573    /// Apply a precomputed active-amplitude mask. Untagged use-sites cannot be deactivated.
3574    pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
3575        let resources = {
3576            let mut resources = self.resources.write();
3577            if mask.len() != resources.active.len() {
3578                return Err(LadduError::LengthMismatch {
3579                    context: "active amplitude mask".to_string(),
3580                    expected: resources.active.len(),
3581                    actual: mask.len(),
3582                });
3583            }
3584            resources.apply_active_mask(mask)?;
3585            resources.clone()
3586        };
3587        self.rebuild_runtime_specializations(&resources);
3588        Ok(())
3589    }
3590
3591    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3592    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
3593    ///
3594    /// # Notes
3595    ///
3596    /// This method is not intended to be called in analyses but rather in writing methods
3597    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
3598    pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3599        let resources = self.resources.read();
3600        let parameters = resources.parameter_map.assemble(parameters)?;
3601        let amplitude_len = self.amplitude_use_sites.len();
3602        let active_indices = resources.active_indices().to_vec();
3603        let slot_count = self.expression_value_slot_count();
3604        let program_snapshot = self.expression_value_program_snapshot();
3605        #[cfg(feature = "rayon")]
3606        {
3607            Ok(resources
3608                .caches
3609                .par_iter()
3610                .map_init(
3611                    || {
3612                        (
3613                            vec![Complex64::ZERO; amplitude_len],
3614                            vec![Complex64::ZERO; slot_count],
3615                        )
3616                    },
3617                    |(amplitude_values, expr_slots), cache| {
3618                        self.fill_amplitude_values(
3619                            amplitude_values,
3620                            &active_indices,
3621                            &parameters,
3622                            cache,
3623                        );
3624                        self.evaluate_expression_value_with_program_snapshot(
3625                            &program_snapshot,
3626                            amplitude_values,
3627                            expr_slots,
3628                        )
3629                    },
3630                )
3631                .collect())
3632        }
3633        #[cfg(not(feature = "rayon"))]
3634        {
3635            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3636            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3637            Ok(resources
3638                .caches
3639                .iter()
3640                .map(|cache| {
3641                    self.fill_amplitude_values(
3642                        &mut amplitude_values,
3643                        &active_indices,
3644                        &parameters,
3645                        cache,
3646                    );
3647                    self.evaluate_expression_value_with_program_snapshot(
3648                        &program_snapshot,
3649                        &amplitude_values,
3650                        &mut expr_slots,
3651                    )
3652                })
3653                .collect())
3654        }
3655    }
3656
3657    /// Evaluate local events using an explicit active-amplitude mask without mutating evaluator state.
3658    pub fn evaluate_local_with_active_mask(
3659        &self,
3660        parameters: &[f64],
3661        active_mask: &[bool],
3662    ) -> LadduResult<Vec<Complex64>> {
3663        let resources = self.resources.read();
3664        if active_mask.len() != resources.active.len() {
3665            return Err(LadduError::LengthMismatch {
3666                context: "active amplitude mask".to_string(),
3667                expected: resources.active.len(),
3668                actual: active_mask.len(),
3669            });
3670        }
3671        let parameters = resources.parameter_map.assemble(parameters)?;
3672        let amplitude_len = self.amplitude_use_sites.len();
3673        let active_indices = active_mask
3674            .iter()
3675            .enumerate()
3676            .filter_map(|(index, &active)| if active { Some(index) } else { None })
3677            .collect::<Vec<_>>();
3678        let program_snapshot =
3679            self.expression_value_program_snapshot_for_active_mask(active_mask)?;
3680        let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
3681        #[cfg(feature = "rayon")]
3682        {
3683            Ok(resources
3684                .caches
3685                .par_iter()
3686                .map_init(
3687                    || {
3688                        (
3689                            vec![Complex64::ZERO; amplitude_len],
3690                            vec![Complex64::ZERO; slot_count],
3691                        )
3692                    },
3693                    |(amplitude_values, expr_slots), cache| {
3694                        self.fill_amplitude_values(
3695                            amplitude_values,
3696                            &active_indices,
3697                            &parameters,
3698                            cache,
3699                        );
3700                        self.evaluate_expression_value_with_program_snapshot(
3701                            &program_snapshot,
3702                            amplitude_values,
3703                            expr_slots,
3704                        )
3705                    },
3706                )
3707                .collect())
3708        }
3709        #[cfg(not(feature = "rayon"))]
3710        {
3711            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3712            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3713            Ok(resources
3714                .caches
3715                .iter()
3716                .map(|cache| {
3717                    self.fill_amplitude_values(
3718                        &mut amplitude_values,
3719                        &active_indices,
3720                        &parameters,
3721                        cache,
3722                    );
3723                    self.evaluate_expression_value_with_program_snapshot(
3724                        &program_snapshot,
3725                        &amplitude_values,
3726                        &mut expr_slots,
3727                    )
3728                })
3729                .collect())
3730        }
3731    }
3732
3733    /// Evaluate the stored expression over local events using a reusable execution context.
3734    #[cfg(feature = "execution-context-prototype")]
3735    pub fn evaluate_local_with_ctx(
3736        &self,
3737        parameters: &[f64],
3738        execution_context: &ExecutionContext,
3739    ) -> Vec<Complex64> {
3740        let resources = self.resources.read();
3741        let parameters = resources
3742            .parameter_map
3743            .assemble(parameters)
3744            .expect("parameter slice must match evaluator resources");
3745        let amplitude_len = self.amplitude_use_sites.len();
3746        let active_indices = resources.active_indices().to_vec();
3747        let slot_count = self.expression_value_slot_count();
3748        let program_snapshot = self.expression_value_program_snapshot();
3749        #[cfg(feature = "rayon")]
3750        {
3751            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3752                return execution_context.install(|| {
3753                    resources
3754                        .caches
3755                        .par_iter()
3756                        .map_init(
3757                            || {
3758                                (
3759                                    vec![Complex64::ZERO; amplitude_len],
3760                                    vec![Complex64::ZERO; slot_count],
3761                                )
3762                            },
3763                            |(amplitude_values, expr_slots), cache| {
3764                                self.fill_amplitude_values(
3765                                    amplitude_values,
3766                                    &active_indices,
3767                                    &parameters,
3768                                    cache,
3769                                );
3770                                self.evaluate_expression_value_with_program_snapshot(
3771                                    &program_snapshot,
3772                                    amplitude_values,
3773                                    expr_slots,
3774                                )
3775                            },
3776                        )
3777                        .collect()
3778                });
3779            }
3780        }
3781        execution_context.with_scratch(|scratch| {
3782            let (amplitude_values, expr_slots) =
3783                scratch.reserve_value_workspaces(amplitude_len, slot_count);
3784            resources
3785                .caches
3786                .iter()
3787                .map(|cache| {
3788                    self.fill_amplitude_values(
3789                        amplitude_values,
3790                        &active_indices,
3791                        &parameters,
3792                        cache,
3793                    );
3794                    self.evaluate_expression_value_with_program_snapshot(
3795                        &program_snapshot,
3796                        amplitude_values,
3797                        expr_slots,
3798                    )
3799                })
3800                .collect()
3801        })
3802    }
3803
3804    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3805    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
3806    ///
3807    /// # Notes
3808    ///
3809    /// This method is not intended to be called in analyses but rather in writing methods
3810    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate`] instead.
3811    #[cfg(feature = "mpi")]
3812    fn evaluate_mpi(
3813        &self,
3814        parameters: &[f64],
3815        world: &SimpleCommunicator,
3816    ) -> LadduResult<Vec<Complex64>> {
3817        let local_evaluation = self.evaluate_local(parameters)?;
3818        let n_events = self.dataset.n_events();
3819        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3820        let (counts, displs) = world.get_counts_displs(n_events);
3821        {
3822            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
3823            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
3824            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3825            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3826        }
3827        Ok(buffer)
3828    }
3829
3830    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3831    fn evaluate_mpi_with_ctx(
3832        &self,
3833        parameters: &[f64],
3834        world: &SimpleCommunicator,
3835        execution_context: &ExecutionContext,
3836    ) -> Vec<Complex64> {
3837        let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
3838        let n_events = self.dataset.n_events();
3839        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3840        let (counts, displs) = world.get_counts_displs(n_events);
3841        {
3842            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
3843            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
3844            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3845            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3846        }
3847        buffer
3848    }
3849
3850    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
3851    /// [`Evaluator`] with the given values for free parameters.
3852    pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3853        #[cfg(feature = "mpi")]
3854        {
3855            if let Some(world) = crate::mpi::get_world() {
3856                return self.evaluate_mpi(parameters, &world);
3857            }
3858        }
3859        self.evaluate_local(parameters)
3860    }
3861
3862    /// Evaluate the stored expression with a reusable execution context.
3863    ///
3864    /// This is intended for repeated calls with the same context instance.
3865    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
3866    /// [`ExecutionContext`](crate::ExecutionContext).
3867    #[cfg(feature = "execution-context-prototype")]
3868    pub fn evaluate_with_ctx(
3869        &self,
3870        parameters: &[f64],
3871        execution_context: &ExecutionContext,
3872    ) -> Vec<Complex64> {
3873        #[cfg(feature = "mpi")]
3874        {
3875            if let Some(world) = crate::mpi::get_world() {
3876                return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
3877            }
3878        }
3879        self.evaluate_local_with_ctx(parameters, execution_context)
3880    }
3881
3882    /// See [`Evaluator::evaluate_local`]. This method evaluates over a subset of events rather
3883    /// than all events in the total dataset.
3884    pub fn evaluate_batch_local(
3885        &self,
3886        parameters: &[f64],
3887        indices: &[usize],
3888    ) -> LadduResult<Vec<Complex64>> {
3889        let resources = self.resources.read();
3890        let parameters = resources.parameter_map.assemble(parameters)?;
3891        let amplitude_len = self.amplitude_use_sites.len();
3892        let active_indices = resources.active_indices().to_vec();
3893        let slot_count = self.expression_value_slot_count();
3894        let program_snapshot = self.expression_value_program_snapshot();
3895        #[cfg(feature = "rayon")]
3896        {
3897            Ok(indices
3898                .par_iter()
3899                .map_init(
3900                    || {
3901                        (
3902                            vec![Complex64::ZERO; amplitude_len],
3903                            vec![Complex64::ZERO; slot_count],
3904                        )
3905                    },
3906                    |(amplitude_values, expr_slots), &idx| {
3907                        let cache = &resources.caches[idx];
3908                        self.fill_amplitude_values(
3909                            amplitude_values,
3910                            &active_indices,
3911                            &parameters,
3912                            cache,
3913                        );
3914                        self.evaluate_expression_value_with_program_snapshot(
3915                            &program_snapshot,
3916                            amplitude_values,
3917                            expr_slots,
3918                        )
3919                    },
3920                )
3921                .collect())
3922        }
3923        #[cfg(not(feature = "rayon"))]
3924        {
3925            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3926            let mut expr_slots = vec![Complex64::ZERO; slot_count];
3927            Ok(indices
3928                .iter()
3929                .map(|&idx| {
3930                    let cache = &resources.caches[idx];
3931                    self.fill_amplitude_values(
3932                        &mut amplitude_values,
3933                        &active_indices,
3934                        &parameters,
3935                        cache,
3936                    );
3937                    self.evaluate_expression_value_with_program_snapshot(
3938                        &program_snapshot,
3939                        &amplitude_values,
3940                        &mut expr_slots,
3941                    )
3942                })
3943                .collect())
3944        }
3945    }
3946
3947    /// See [`Evaluator::evaluate_mpi`]. This method evaluates over a subset of events rather
3948    /// than all events in the total dataset.
3949    #[cfg(feature = "mpi")]
3950    fn evaluate_batch_mpi(
3951        &self,
3952        parameters: &[f64],
3953        indices: &[usize],
3954        world: &SimpleCommunicator,
3955    ) -> LadduResult<Vec<Complex64>> {
3956        let total = self.dataset.n_events();
3957        let locals = world.locals_from_globals(indices, total);
3958        let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
3959        Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
3960    }
3961
3962    /// Evaluate the stored [`Expression`] over a subset of events in the [`Dataset`] stored by the
3963    /// [`Evaluator`] with the given values for free parameters. See also [`Evaluator::evaluate`].
3964    pub fn evaluate_batch(
3965        &self,
3966        parameters: &[f64],
3967        indices: &[usize],
3968    ) -> LadduResult<Vec<Complex64>> {
3969        #[cfg(feature = "mpi")]
3970        {
3971            if let Some(world) = crate::mpi::get_world() {
3972                return self.evaluate_batch_mpi(parameters, indices, &world);
3973            }
3974        }
3975        self.evaluate_batch_local(parameters, indices)
3976    }
3977
3978    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
3979    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
3980    ///
3981    /// # Notes
3982    ///
3983    /// This method is not intended to be called in analyses but rather in writing methods
3984    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
3985    pub fn evaluate_gradient_local(
3986        &self,
3987        parameters: &[f64],
3988    ) -> LadduResult<Vec<DVector<Complex64>>> {
3989        let resources = self.resources.read();
3990        let parameters = resources.parameter_map.assemble(parameters)?;
3991        let amplitude_len = self.amplitude_use_sites.len();
3992        let grad_dim = parameters.len();
3993        let active_indices = resources.active_indices().to_vec();
3994        let lowered_runtime = self.lowered_runtime();
3995        let gradient_program = lowered_runtime.gradient_program();
3996        let slot_count = self.expression_gradient_slot_count();
3997        #[cfg(feature = "rayon")]
3998        {
3999            Ok(resources
4000                .caches
4001                .par_iter()
4002                .map_init(
4003                    || {
4004                        (
4005                            vec![Complex64::ZERO; amplitude_len],
4006                            vec![DVector::zeros(grad_dim); amplitude_len],
4007                            vec![Complex64::ZERO; slot_count],
4008                            vec![Complex64::ZERO; slot_count * grad_dim],
4009                        )
4010                    },
4011                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4012                        self.fill_amplitude_values_and_gradients(
4013                            amplitude_values,
4014                            gradient_values,
4015                            &active_indices,
4016                            &resources.active,
4017                            &parameters,
4018                            cache,
4019                        );
4020                        gradient_program.evaluate_gradient_into_flat(
4021                            amplitude_values,
4022                            gradient_values,
4023                            value_slots,
4024                            gradient_slots,
4025                            grad_dim,
4026                        )
4027                    },
4028                )
4029                .collect())
4030        }
4031        #[cfg(not(feature = "rayon"))]
4032        {
4033            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4034            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4035            let mut value_slots = vec![Complex64::ZERO; slot_count];
4036            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4037            Ok(resources
4038                .caches
4039                .iter()
4040                .map(|cache| {
4041                    self.fill_amplitude_values_and_gradients(
4042                        &mut amplitude_values,
4043                        &mut gradient_values,
4044                        &active_indices,
4045                        &resources.active,
4046                        &parameters,
4047                        cache,
4048                    );
4049                    gradient_program.evaluate_gradient_into_flat(
4050                        &amplitude_values,
4051                        &gradient_values,
4052                        &mut value_slots,
4053                        &mut gradient_slots,
4054                        grad_dim,
4055                    )
4056                })
4057                .collect())
4058        }
4059    }
4060
4061    /// Evaluate the gradient over local events using a reusable execution context.
4062    #[cfg(feature = "execution-context-prototype")]
4063    pub fn evaluate_gradient_local_with_ctx(
4064        &self,
4065        parameters: &[f64],
4066        execution_context: &ExecutionContext,
4067    ) -> Vec<DVector<Complex64>> {
4068        let resources = self.resources.read();
4069        let parameters = resources
4070            .parameter_map
4071            .assemble(parameters)
4072            .expect("parameter slice must match evaluator resources");
4073        let amplitude_len = self.amplitude_use_sites.len();
4074        let grad_dim = parameters.len();
4075        let active_indices = resources.active_indices().to_vec();
4076        let slot_count = self.expression_slot_count();
4077        #[cfg(feature = "rayon")]
4078        {
4079            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
4080                return execution_context.install(|| {
4081                    resources
4082                        .caches
4083                        .par_iter()
4084                        .map_init(
4085                            || {
4086                                (
4087                                    vec![Complex64::ZERO; amplitude_len],
4088                                    vec![DVector::zeros(grad_dim); amplitude_len],
4089                                    vec![Complex64::ZERO; slot_count],
4090                                    vec![DVector::zeros(grad_dim); slot_count],
4091                                )
4092                            },
4093                            |(amplitude_values, gradient_values, value_slots, gradient_slots),
4094                             cache| {
4095                                self.evaluate_cache_gradient_with_scratch(
4096                                    amplitude_values,
4097                                    gradient_values,
4098                                    value_slots,
4099                                    gradient_slots,
4100                                    &active_indices,
4101                                    &resources.active,
4102                                    &parameters,
4103                                    cache,
4104                                )
4105                            },
4106                        )
4107                        .collect()
4108                });
4109            }
4110        }
4111        execution_context.with_scratch(|scratch| {
4112            let (amplitude_values, value_slots, gradient_values, gradient_slots) =
4113                scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
4114            resources
4115                .caches
4116                .iter()
4117                .map(|cache| {
4118                    self.evaluate_cache_gradient_with_scratch(
4119                        amplitude_values,
4120                        gradient_values,
4121                        value_slots,
4122                        gradient_slots,
4123                        &active_indices,
4124                        &resources.active,
4125                        &parameters,
4126                        cache,
4127                    )
4128                })
4129                .collect()
4130        })
4131    }
4132
4133    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
4134    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
4135    ///
4136    /// # Notes
4137    ///
4138    /// This method is not intended to be called in analyses but rather in writing methods
4139    /// that have `mpi`-feature-gated versions. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
4140    #[cfg(feature = "mpi")]
4141    fn evaluate_gradient_mpi(
4142        &self,
4143        parameters: &[f64],
4144        world: &SimpleCommunicator,
4145    ) -> LadduResult<Vec<DVector<Complex64>>> {
4146        let local_evaluation = self.evaluate_gradient_local(parameters)?;
4147        let n_events = self.dataset.n_events();
4148        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
4149        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
4150        {
4151            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
4152            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
4153            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4154            world.all_gather_varcount_into(
4155                &local_evaluation
4156                    .iter()
4157                    .flat_map(|v| v.data.as_vec())
4158                    .copied()
4159                    .collect::<Vec<_>>(),
4160                &mut partitioned_buffer,
4161            );
4162        }
4163        Ok(buffer
4164            .chunks(parameters.len())
4165            .map(DVector::from_row_slice)
4166            .collect())
4167    }
4168
4169    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
4170    fn evaluate_gradient_mpi_with_ctx(
4171        &self,
4172        parameters: &[f64],
4173        world: &SimpleCommunicator,
4174        execution_context: &ExecutionContext,
4175    ) -> Vec<DVector<Complex64>> {
4176        let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
4177        let n_events = self.dataset.n_events();
4178        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
4179        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
4180        {
4181            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
4182            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
4183            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4184            world.all_gather_varcount_into(
4185                &local_evaluation
4186                    .iter()
4187                    .flat_map(|v| v.data.as_vec())
4188                    .copied()
4189                    .collect::<Vec<_>>(),
4190                &mut partitioned_buffer,
4191            );
4192        }
4193        buffer
4194            .chunks(parameters.len())
4195            .map(DVector::from_row_slice)
4196            .collect()
4197    }
4198
4199    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
4200    /// [`Evaluator`] with the given values for free parameters.
4201    pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
4202        #[cfg(feature = "mpi")]
4203        {
4204            if let Some(world) = crate::mpi::get_world() {
4205                return self.evaluate_gradient_mpi(parameters, &world);
4206            }
4207        }
4208        self.evaluate_gradient_local(parameters)
4209    }
4210
4211    /// Evaluate the gradient with a reusable execution context.
4212    ///
4213    /// This is intended for repeated calls with the same context instance.
4214    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
4215    /// [`ExecutionContext`](crate::ExecutionContext).
4216    #[cfg(feature = "execution-context-prototype")]
4217    pub fn evaluate_gradient_with_ctx(
4218        &self,
4219        parameters: &[f64],
4220        execution_context: &ExecutionContext,
4221    ) -> Vec<DVector<Complex64>> {
4222        #[cfg(feature = "mpi")]
4223        {
4224            if let Some(world) = crate::mpi::get_world() {
4225                return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
4226            }
4227        }
4228        self.evaluate_gradient_local_with_ctx(parameters, execution_context)
4229    }
4230
4231    /// See [`Evaluator::evaluate_gradient_local`]. This method evaluates over a subset
4232    /// of events rather than all events in the total dataset.
4233    pub fn evaluate_gradient_batch_local(
4234        &self,
4235        parameters: &[f64],
4236        indices: &[usize],
4237    ) -> LadduResult<Vec<DVector<Complex64>>> {
4238        let resources = self.resources.read();
4239        let parameters = resources.parameter_map.assemble(parameters)?;
4240        let amplitude_len = self.amplitude_use_sites.len();
4241        let grad_dim = parameters.len();
4242        let active_indices = resources.active_indices().to_vec();
4243        let lowered_runtime = self.lowered_runtime();
4244        let gradient_program = lowered_runtime.gradient_program();
4245        let slot_count = self.expression_gradient_slot_count();
4246        #[cfg(feature = "rayon")]
4247        {
4248            Ok(indices
4249                .par_iter()
4250                .map_init(
4251                    || {
4252                        (
4253                            vec![Complex64::ZERO; amplitude_len],
4254                            vec![DVector::zeros(grad_dim); amplitude_len],
4255                            vec![Complex64::ZERO; slot_count],
4256                            vec![Complex64::ZERO; slot_count * grad_dim],
4257                        )
4258                    },
4259                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
4260                        let cache = &resources.caches[idx];
4261                        self.fill_amplitude_values_and_gradients(
4262                            amplitude_values,
4263                            gradient_values,
4264                            &active_indices,
4265                            &resources.active,
4266                            &parameters,
4267                            cache,
4268                        );
4269                        gradient_program.evaluate_gradient_into_flat(
4270                            amplitude_values,
4271                            gradient_values,
4272                            value_slots,
4273                            gradient_slots,
4274                            grad_dim,
4275                        )
4276                    },
4277                )
4278                .collect())
4279        }
4280        #[cfg(not(feature = "rayon"))]
4281        {
4282            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4283            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4284            let mut value_slots = vec![Complex64::ZERO; slot_count];
4285            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4286            Ok(indices
4287                .iter()
4288                .map(|&idx| {
4289                    let cache = &resources.caches[idx];
4290                    self.fill_amplitude_values_and_gradients(
4291                        &mut amplitude_values,
4292                        &mut gradient_values,
4293                        &active_indices,
4294                        &resources.active,
4295                        &parameters,
4296                        cache,
4297                    );
4298                    gradient_program.evaluate_gradient_into_flat(
4299                        &amplitude_values,
4300                        &gradient_values,
4301                        &mut value_slots,
4302                        &mut gradient_slots,
4303                        grad_dim,
4304                    )
4305                })
4306                .collect())
4307        }
4308    }
4309
4310    /// See [`Evaluator::evaluate_gradient_mpi`]. This method evaluates over a subset
4311    /// of events rather than all events in the total dataset.
4312    #[cfg(feature = "mpi")]
4313    fn evaluate_gradient_batch_mpi(
4314        &self,
4315        parameters: &[f64],
4316        indices: &[usize],
4317        world: &SimpleCommunicator,
4318    ) -> LadduResult<Vec<DVector<Complex64>>> {
4319        let total = self.dataset.n_events();
4320        let locals = world.locals_from_globals(indices, total);
4321        let flattened_local_evaluation = self
4322            .evaluate_gradient_batch_local(parameters, &locals)?
4323            .iter()
4324            .flat_map(|g| g.data.as_vec().to_vec())
4325            .collect::<Vec<Complex64>>();
4326        Ok(world
4327            .all_gather_batched_partitioned(
4328                &flattened_local_evaluation,
4329                indices,
4330                total,
4331                Some(parameters.len()),
4332            )
4333            .chunks(parameters.len())
4334            .map(DVector::from_row_slice)
4335            .collect())
4336    }
4337
4338    /// Evaluate the gradient of the stored [`Expression`] over a subset of the
4339    /// events in the [`Dataset`] stored by the [`Evaluator`] with the given values
4340    /// for free parameters. See also [`Evaluator::evaluate_gradient`].
4341    pub fn evaluate_gradient_batch(
4342        &self,
4343        parameters: &[f64],
4344        indices: &[usize],
4345    ) -> LadduResult<Vec<DVector<Complex64>>> {
4346        #[cfg(feature = "mpi")]
4347        {
4348            if let Some(world) = crate::mpi::get_world() {
4349                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
4350            }
4351        }
4352        self.evaluate_gradient_batch_local(parameters, indices)
4353    }
4354
4355    /// Evaluate the stored expression and its gradient over local events in one fused pass.
4356    pub fn evaluate_with_gradient_local(
4357        &self,
4358        parameters: &[f64],
4359    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
4360        let resources = self.resources.read();
4361        let parameters = resources.parameter_map.assemble(parameters)?;
4362        let amplitude_len = self.amplitude_use_sites.len();
4363        let grad_dim = parameters.len();
4364        let active_indices = resources.active_indices().to_vec();
4365        let lowered_runtime = self.lowered_runtime();
4366        let value_gradient_program = lowered_runtime.value_gradient_program();
4367        let slot_count = self.expression_value_gradient_slot_count();
4368        #[cfg(feature = "rayon")]
4369        {
4370            Ok(resources
4371                .caches
4372                .par_iter()
4373                .map_init(
4374                    || {
4375                        (
4376                            vec![Complex64::ZERO; amplitude_len],
4377                            vec![DVector::zeros(grad_dim); amplitude_len],
4378                            vec![Complex64::ZERO; slot_count],
4379                            vec![Complex64::ZERO; slot_count * grad_dim],
4380                        )
4381                    },
4382                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4383                        self.fill_amplitude_values_and_gradients(
4384                            amplitude_values,
4385                            gradient_values,
4386                            &active_indices,
4387                            &resources.active,
4388                            &parameters,
4389                            cache,
4390                        );
4391                        value_gradient_program.evaluate_value_gradient_into_flat(
4392                            amplitude_values,
4393                            gradient_values,
4394                            value_slots,
4395                            gradient_slots,
4396                            grad_dim,
4397                        )
4398                    },
4399                )
4400                .collect())
4401        }
4402        #[cfg(not(feature = "rayon"))]
4403        {
4404            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4405            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4406            let mut value_slots = vec![Complex64::ZERO; slot_count];
4407            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4408            Ok(resources
4409                .caches
4410                .iter()
4411                .map(|cache| {
4412                    self.fill_amplitude_values_and_gradients(
4413                        &mut amplitude_values,
4414                        &mut gradient_values,
4415                        &active_indices,
4416                        &resources.active,
4417                        &parameters,
4418                        cache,
4419                    );
4420                    value_gradient_program.evaluate_value_gradient_into_flat(
4421                        &amplitude_values,
4422                        &gradient_values,
4423                        &mut value_slots,
4424                        &mut gradient_slots,
4425                        grad_dim,
4426                    )
4427                })
4428                .collect())
4429        }
4430    }
4431
4432    /// Evaluate local events and gradients with an explicit active-amplitude mask without mutating evaluator state.
4433    pub fn evaluate_with_gradient_local_with_active_mask(
4434        &self,
4435        parameters: &[f64],
4436        active_mask: &[bool],
4437    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
4438        let resources = self.resources.read();
4439        if active_mask.len() != resources.active.len() {
4440            return Err(LadduError::LengthMismatch {
4441                context: "active amplitude mask".to_string(),
4442                expected: resources.active.len(),
4443                actual: active_mask.len(),
4444            });
4445        }
4446        let parameters = resources.parameter_map.assemble(parameters)?;
4447        let amplitude_len = self.amplitude_use_sites.len();
4448        let grad_dim = parameters.len();
4449        let active_indices = active_mask
4450            .iter()
4451            .enumerate()
4452            .filter_map(|(index, &active)| if active { Some(index) } else { None })
4453            .collect::<Vec<_>>();
4454        let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
4455        let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
4456        #[cfg(feature = "rayon")]
4457        {
4458            Ok(resources
4459                .caches
4460                .par_iter()
4461                .map_init(
4462                    || {
4463                        (
4464                            vec![Complex64::ZERO; amplitude_len],
4465                            vec![DVector::zeros(grad_dim); amplitude_len],
4466                            vec![Complex64::ZERO; slot_count],
4467                            vec![Complex64::ZERO; slot_count * grad_dim],
4468                        )
4469                    },
4470                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4471                        self.fill_amplitude_values_and_gradients(
4472                            amplitude_values,
4473                            gradient_values,
4474                            &active_indices,
4475                            active_mask,
4476                            &parameters,
4477                            cache,
4478                        );
4479                        lowered_runtime
4480                            .value_gradient_program()
4481                            .evaluate_value_gradient_into_flat(
4482                                amplitude_values,
4483                                gradient_values,
4484                                value_slots,
4485                                gradient_slots,
4486                                grad_dim,
4487                            )
4488                    },
4489                )
4490                .collect())
4491        }
4492        #[cfg(not(feature = "rayon"))]
4493        {
4494            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4495            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4496            let mut value_slots = vec![Complex64::ZERO; slot_count];
4497            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4498            Ok(resources
4499                .caches
4500                .iter()
4501                .map(|cache| {
4502                    self.fill_amplitude_values_and_gradients(
4503                        &mut amplitude_values,
4504                        &mut gradient_values,
4505                        &active_indices,
4506                        active_mask,
4507                        &parameters,
4508                        cache,
4509                    );
4510                    lowered_runtime
4511                        .value_gradient_program()
4512                        .evaluate_value_gradient_into_flat(
4513                            &amplitude_values,
4514                            &gradient_values,
4515                            &mut value_slots,
4516                            &mut gradient_slots,
4517                            grad_dim,
4518                        )
4519                })
4520                .collect())
4521        }
4522    }
4523
4524    /// Evaluate the stored expression and its gradient over a local subset of events in one fused pass.
4525    pub fn evaluate_with_gradient_batch_local(
4526        &self,
4527        parameters: &[f64],
4528        indices: &[usize],
4529    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
4530        let resources = self.resources.read();
4531        let parameters = resources.parameter_map.assemble(parameters)?;
4532        let amplitude_len = self.amplitude_use_sites.len();
4533        let grad_dim = parameters.len();
4534        let active_indices = resources.active_indices().to_vec();
4535        let lowered_runtime = self.lowered_runtime();
4536        let value_gradient_program = lowered_runtime.value_gradient_program();
4537        let slot_count = self.expression_value_gradient_slot_count();
4538        #[cfg(feature = "rayon")]
4539        {
4540            Ok(indices
4541                .par_iter()
4542                .map_init(
4543                    || {
4544                        (
4545                            vec![Complex64::ZERO; amplitude_len],
4546                            vec![DVector::zeros(grad_dim); amplitude_len],
4547                            vec![Complex64::ZERO; slot_count],
4548                            vec![Complex64::ZERO; slot_count * grad_dim],
4549                        )
4550                    },
4551                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
4552                        let cache = &resources.caches[idx];
4553                        self.fill_amplitude_values_and_gradients(
4554                            amplitude_values,
4555                            gradient_values,
4556                            &active_indices,
4557                            &resources.active,
4558                            &parameters,
4559                            cache,
4560                        );
4561                        value_gradient_program.evaluate_value_gradient_into_flat(
4562                            amplitude_values,
4563                            gradient_values,
4564                            value_slots,
4565                            gradient_slots,
4566                            grad_dim,
4567                        )
4568                    },
4569                )
4570                .collect())
4571        }
4572        #[cfg(not(feature = "rayon"))]
4573        {
4574            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4575            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4576            let mut value_slots = vec![Complex64::ZERO; slot_count];
4577            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4578            Ok(indices
4579                .iter()
4580                .map(|&idx| {
4581                    let cache = &resources.caches[idx];
4582                    self.fill_amplitude_values_and_gradients(
4583                        &mut amplitude_values,
4584                        &mut gradient_values,
4585                        &active_indices,
4586                        &resources.active,
4587                        &parameters,
4588                        cache,
4589                    );
4590                    value_gradient_program.evaluate_value_gradient_into_flat(
4591                        &amplitude_values,
4592                        &gradient_values,
4593                        &mut value_slots,
4594                        &mut gradient_slots,
4595                        grad_dim,
4596                    )
4597                })
4598                .collect())
4599        }
4600    }
4601}
4602
4603#[cfg(test)]
4604mod tests {
4605    use approx::assert_relative_eq;
4606    #[cfg(feature = "mpi")]
4607    use mpi_test::mpi_test;
4608    use serde::{Deserialize, Serialize};
4609
4610    use super::*;
4611    use crate::{
4612        amplitude::{AmplitudeID, Tags, TestAmplitude},
4613        data::{test_dataset, test_event, DatasetMetadata, Event, EventData},
4614        parameter,
4615        parameters::Parameter,
4616        resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
4617        vectors::Vec4,
4618    };
4619
    /// Test amplitude whose value is a complex number built from two parameters.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        /// Name under which the amplitude is registered.
        name: String,
        /// Parameter supplying the real part.
        re: Parameter,
        /// Identifier obtained when `re` is registered; default-initialised until then.
        pid_re: ParameterID,
        /// Parameter supplying the imaginary part.
        im: Parameter,
        /// Identifier obtained when `im` is registered; default-initialised until then.
        pid_im: ParameterID,
    }
4628
4629    impl ComplexScalar {
4630        #[allow(clippy::new_ret_no_self)]
4631        pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
4632            Self {
4633                name: name.to_string(),
4634                re,
4635                pid_re: Default::default(),
4636                im,
4637                pid_im: Default::default(),
4638            }
4639            .into_expression()
4640        }
4641    }
4642
4643    #[typetag::serde]
4644    impl Amplitude for ComplexScalar {
4645        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4646            self.pid_re = resources.register_parameter(&self.re)?;
4647            self.pid_im = resources.register_parameter(&self.im)?;
4648            resources.register_amplitude(&self.name)
4649        }
4650
4651        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4652            Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
4653        }
4654
4655        fn compute_gradient(
4656            &self,
4657            parameters: &Parameters,
4658            _cache: &Cache,
4659            gradient: &mut DVector<Complex64>,
4660        ) {
4661            if let Some(ind) = parameters.free_index(self.pid_re) {
4662                gradient[ind] = Complex64::ONE;
4663            }
4664            if let Some(ind) = parameters.free_index(self.pid_im) {
4665                gradient[ind] = Complex64::I;
4666            }
4667        }
4668    }
4669
    /// Test amplitude whose value is a single parameter with zero imaginary part.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ParameterOnlyScalar {
        /// Name under which the amplitude is registered.
        name: String,
        /// Parameter providing the amplitude's value.
        value: Parameter,
        /// Identifier obtained when `value` is registered; default-initialised until then.
        pid: ParameterID,
    }
4676
4677    impl ParameterOnlyScalar {
4678        #[allow(clippy::new_ret_no_self)]
4679        pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
4680            Self {
4681                name: name.to_string(),
4682                value,
4683                pid: Default::default(),
4684            }
4685            .into_expression()
4686        }
4687    }
4688
4689    #[typetag::serde]
4690    impl Amplitude for ParameterOnlyScalar {
4691        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4692            self.pid = resources.register_parameter(&self.value)?;
4693            resources.register_amplitude(&self.name)
4694        }
4695
4696        fn dependence_hint(&self) -> ExpressionDependence {
4697            ExpressionDependence::ParameterOnly
4698        }
4699
4700        fn real_valued_hint(&self) -> bool {
4701            true
4702        }
4703
4704        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4705            Complex64::new(parameters.get(self.pid), 0.0)
4706        }
4707    }
4708
    /// Test amplitude whose value is read from a scalar stored in the event cache.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct CacheOnlyScalar {
        /// Name under which the amplitude is registered.
        name: String,
        /// Cache slot holding `event.p4_at(0).e()`; default-initialised until registration.
        beam_energy: ScalarID,
    }
4714
4715    impl CacheOnlyScalar {
4716        #[allow(clippy::new_ret_no_self)]
4717        pub fn new(name: &str) -> LadduResult<Expression> {
4718            Self {
4719                name: name.to_string(),
4720                beam_energy: Default::default(),
4721            }
4722            .into_expression()
4723        }
4724    }
4725
4726    #[typetag::serde]
4727    impl Amplitude for CacheOnlyScalar {
4728        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4729            self.beam_energy =
4730                resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
4731            resources.register_amplitude(&self.name)
4732        }
4733
4734        fn dependence_hint(&self) -> ExpressionDependence {
4735            ExpressionDependence::CacheOnly
4736        }
4737
4738        fn real_valued_hint(&self) -> bool {
4739            true
4740        }
4741
4742        fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
4743            cache.store_scalar(self.beam_energy, event.p4_at(0).e());
4744        }
4745
4746        fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
4747            Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
4748        }
4749    }
4750
    /// Selects which expression shape `make_deterministic_fixture` builds.
    #[derive(Clone, Copy)]
    enum DeterministicFixtureKind {
        /// Sum of two products, each a parameter-only scalar times a cache-only scalar.
        Separable,
        /// One parameter-only times cache-only product plus a `TestAmplitude`.
        Partial,
        /// Product of two `TestAmplitude`s.
        NonSeparable,
    }
4757
    /// Expression, dataset and parameter values used together by the deterministic tests.
    struct DeterministicFixture {
        /// Expression that the tests load against `dataset`.
        expression: Expression,
        /// Dataset shared by every fixture kind.
        dataset: Arc<Dataset>,
        /// Parameter values passed to the evaluator's `evaluate_*` methods.
        parameters: Vec<f64>,
    }
4763
    /// Absolute tolerance passed as `epsilon` to `assert_relative_eq!` in the deterministic tests.
    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
    /// Relative tolerance passed as `max_relative` to `assert_relative_eq!` in the deterministic tests.
    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
4766
4767    fn deterministic_fixture_dataset() -> Arc<Dataset> {
4768        let metadata = Arc::new(DatasetMetadata::default());
4769        let events = vec![
4770            Arc::new(EventData {
4771                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
4772                aux: vec![],
4773                weight: 0.5,
4774            }),
4775            Arc::new(EventData {
4776                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
4777                aux: vec![],
4778                weight: -1.25,
4779            }),
4780            Arc::new(EventData {
4781                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
4782                aux: vec![],
4783                weight: 2.0,
4784            }),
4785            Arc::new(EventData {
4786                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
4787                aux: vec![],
4788                weight: 0.75,
4789            }),
4790        ];
4791        Arc::new(Dataset::new_with_metadata(events, metadata))
4792    }
4793
4794    fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
4795        let dataset = deterministic_fixture_dataset();
4796        match kind {
4797            DeterministicFixtureKind::Separable => {
4798                let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
4799                    .expect("separable p1 should build");
4800                let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
4801                    .expect("separable p2 should build");
4802                let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
4803                let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
4804                DeterministicFixture {
4805                    expression: (&p1 * &c1) + &(&p2 * &c2),
4806                    dataset,
4807                    parameters: vec![0.4, -0.3],
4808                }
4809            }
4810            DeterministicFixtureKind::Partial => {
4811                let p =
4812                    ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
4813                let c = CacheOnlyScalar::new("c").expect("partial c should build");
4814                let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
4815                    .expect("partial m should build");
4816                DeterministicFixture {
4817                    expression: (&p * &c) + &m,
4818                    dataset,
4819                    parameters: vec![0.55, 0.2, -0.15],
4820                }
4821            }
4822            DeterministicFixtureKind::NonSeparable => {
4823                let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
4824                    .expect("non-separable m1 should build");
4825                let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
4826                    .expect("non-separable m2 should build");
4827                DeterministicFixture {
4828                    expression: &m1 * &m2,
4829                    dataset,
4830                    parameters: vec![0.25, -0.4, 0.6, 0.1],
4831                }
4832            }
4833        }
4834    }
4835
4836    fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
4837        let evaluator = fixture
4838            .expression
4839            .load(&fixture.dataset)
4840            .expect("fixture evaluator should load");
4841        let expected_value = evaluator
4842            .evaluate_local(&fixture.parameters)
4843            .expect("evaluation should succeed")
4844            .iter()
4845            .zip(fixture.dataset.weights_local().iter())
4846            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
4847        let expected_gradient = evaluator
4848            .evaluate_gradient_local(&fixture.parameters)
4849            .expect("evaluation should succeed")
4850            .iter()
4851            .zip(fixture.dataset.weights_local().iter())
4852            .fold(
4853                DVector::zeros(fixture.parameters.len()),
4854                |mut accum, (gradient, event)| {
4855                    accum += gradient.map(|value| value.re).scale(*event);
4856                    accum
4857                },
4858            );
4859        let actual_value = evaluator
4860            .evaluate_weighted_value_sum_local(&fixture.parameters)
4861            .expect("evaluation should succeed");
4862        let actual_gradient = evaluator
4863            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4864            .expect("evaluation should succeed");
4865        assert_relative_eq!(
4866            actual_value,
4867            expected_value,
4868            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4869            max_relative = DETERMINISTIC_STRICT_REL_TOL
4870        );
4871        for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
4872            assert_relative_eq!(
4873                *actual_item,
4874                *expected_item,
4875                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4876                max_relative = DETERMINISTIC_STRICT_REL_TOL
4877            );
4878        }
4879    }
4880    fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
4881        let evaluator = fixture
4882            .expression
4883            .load(&fixture.dataset)
4884            .expect("fixture evaluator should load");
4885        let state = {
4886            let resources = evaluator.resources.read();
4887            evaluator.ensure_cached_integral_cache_state(&resources)
4888        }
4889        .expect("state should be available");
4890        assert!(
4891            !state.values.is_empty(),
4892            "fixture should exercise cached normalization terms"
4893        );
4894        assert!(
4895            !state.execution_sets.residual_amplitudes.is_empty(),
4896            "fixture should exercise residual normalization amplitudes"
4897        );
4898
4899        let (residual_value_sum, cached_value_sum) = evaluator
4900            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4901            .expect("evaluation should succeed");
4902        assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4903        assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4904        let combined_value = evaluator
4905            .evaluate_weighted_value_sum_local(&fixture.parameters)
4906            .expect("evaluation should succeed");
4907        assert_relative_eq!(
4908            residual_value_sum + cached_value_sum,
4909            combined_value,
4910            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4911            max_relative = DETERMINISTIC_STRICT_REL_TOL
4912        );
4913
4914        let (residual_gradient_sum, cached_gradient_sum) = evaluator
4915            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4916            .expect("evaluation should succeed");
4917        let combined_gradient = evaluator
4918            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4919            .expect("evaluation should succeed");
4920        assert!(residual_gradient_sum
4921            .iter()
4922            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4923        assert!(cached_gradient_sum
4924            .iter()
4925            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4926        for ((residual_item, cached_item), combined_item) in residual_gradient_sum
4927            .iter()
4928            .zip(cached_gradient_sum.iter())
4929            .zip(combined_gradient.iter())
4930        {
4931            assert_relative_eq!(
4932                residual_item + cached_item,
4933                *combined_item,
4934                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4935                max_relative = DETERMINISTIC_STRICT_REL_TOL
4936            );
4937        }
4938    }
4939
4940    #[test]
4941    fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
4942        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4943        let evaluator = fixture
4944            .expression
4945            .load(&fixture.dataset)
4946            .expect("fixture evaluator should load");
4947        let original_mask = evaluator.active_mask();
4948
4949        let original_value = evaluator
4950            .evaluate_weighted_value_sum_local(&fixture.parameters)
4951            .expect("evaluation should succeed");
4952
4953        evaluator.isolate_many(&["p", "c"]);
4954        assert_ne!(evaluator.active_mask(), original_mask);
4955
4956        evaluator
4957            .set_active_mask(&original_mask)
4958            .expect("original fixture active mask should restore");
4959        assert_eq!(evaluator.active_mask(), original_mask);
4960        let actual_value = evaluator
4961            .evaluate_weighted_value_sum_local(&fixture.parameters)
4962            .expect("evaluation should succeed");
4963        assert_relative_eq!(
4964            actual_value,
4965            original_value,
4966            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4967            max_relative = DETERMINISTIC_STRICT_REL_TOL
4968        );
4969    }
4970
4971    #[test]
4972    fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
4973        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4974        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4975        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4976
4977        assert_weighted_sum_matches_eventwise_baseline(&separable);
4978        assert_weighted_sum_matches_eventwise_baseline(&partial);
4979        assert_weighted_sum_matches_eventwise_baseline(&non_separable);
4980    }
4981    #[test]
4982    fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
4983        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4984        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4985        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4986
4987        let separable_evaluator = separable
4988            .expression
4989            .load(&separable.dataset)
4990            .expect("separable evaluator should load");
4991        let partial_evaluator = partial
4992            .expression
4993            .load(&partial.dataset)
4994            .expect("partial evaluator should load");
4995        let non_separable_evaluator = non_separable
4996            .expression
4997            .load(&non_separable.dataset)
4998            .expect("non-separable evaluator should load");
4999
5000        assert_eq!(
5001            separable_evaluator
5002                .expression_precomputed_cached_integrals()
5003                .expect("integrals should be computed")
5004                .len(),
5005            2
5006        );
5007        assert_eq!(
5008            partial_evaluator
5009                .expression_precomputed_cached_integrals()
5010                .expect("integrals should be computed")
5011                .len(),
5012            1
5013        );
5014        assert!(non_separable_evaluator
5015            .expression_precomputed_cached_integrals()
5016            .expect("integrals should be computed")
5017            .is_empty());
5018    }
5019    #[test]
5020    fn test_partial_fixture_combined_normalization_components_match_total() {
5021        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5022        assert_mixed_normalization_components_match_combined_path(&partial);
5023    }
5024    #[test]
5025    fn test_non_separable_fixture_normalization_components_stay_residual_only() {
5026        let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
5027        let evaluator = fixture
5028            .expression
5029            .load(&fixture.dataset)
5030            .expect("fixture evaluator should load");
5031        let resources = evaluator.resources.read();
5032        let state = evaluator
5033            .ensure_cached_integral_cache_state(&resources)
5034            .expect("state should be available");
5035        assert!(state.values.is_empty());
5036
5037        let (residual_value_sum, cached_value_sum) = evaluator
5038            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
5039            .expect("evaluation should succeed");
5040        assert_relative_eq!(
5041            cached_value_sum,
5042            0.0,
5043            epsilon = DETERMINISTIC_STRICT_ABS_TOL
5044        );
5045        assert_relative_eq!(
5046            residual_value_sum,
5047            evaluator
5048                .evaluate_weighted_value_sum_local(&fixture.parameters)
5049                .expect("evaluation should succeed"),
5050            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5051            max_relative = DETERMINISTIC_STRICT_REL_TOL
5052        );
5053
5054        let (residual_gradient_sum, cached_gradient_sum) = evaluator
5055            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
5056            .expect("evaluation should succeed");
5057        assert!(cached_gradient_sum
5058            .iter()
5059            .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
5060        let combined_gradient = evaluator
5061            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
5062            .expect("evaluation should succeed");
5063        for (residual_item, combined_item) in
5064            residual_gradient_sum.iter().zip(combined_gradient.iter())
5065        {
5066            assert_relative_eq!(
5067                *residual_item,
5068                *combined_item,
5069                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5070                max_relative = DETERMINISTIC_STRICT_REL_TOL
5071            );
5072        }
5073    }
5074
5075    #[test]
5076    fn test_batch_evaluation() {
5077        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
5078        let mut event1 = test_event();
5079        event1.p4s[0].t = 10.0;
5080        let mut event2 = test_event();
5081        event2.p4s[0].t = 11.0;
5082        let mut event3 = test_event();
5083        event3.p4s[0].t = 12.0;
5084        let dataset = Arc::new(Dataset::new_with_metadata(
5085            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5086            Arc::new(DatasetMetadata::default()),
5087        ));
5088        let evaluator = expr.load(&dataset).unwrap();
5089        let result = evaluator
5090            .evaluate_batch(&[1.1, 2.2], &[0, 2])
5091            .expect("evaluation should succeed");
5092        assert_eq!(result.len(), 2);
5093        assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
5094        assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
5095        let result_grad = evaluator
5096            .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
5097            .expect("evaluation should succeed");
5098        assert_eq!(result_grad.len(), 2);
5099        assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
5100        assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
5101        assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
5102        assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
5103    }
5104
5105    #[test]
5106    fn test_load_compiles_expression_ir_once() {
5107        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5108            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5109        .norm_sqr();
5110        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5111        let evaluator = expr.load(&dataset).unwrap();
5112        assert!(evaluator.expression_slot_count() > 0);
5113    }
5114    #[test]
5115    fn test_expression_ir_value_matches_lowered_runtime() {
5116        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5117            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5118            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
5119        .conj()
5120        .norm_sqr();
5121        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5122        let evaluator = expr.load(&dataset).unwrap();
5123        let resources = evaluator.resources.read();
5124        let parameters = resources
5125            .parameter_map
5126            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
5127            .expect("parameters should assemble");
5128        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5129        evaluator.fill_amplitude_values(
5130            &mut amplitude_values,
5131            resources.active_indices(),
5132            &parameters,
5133            &resources.caches[0],
5134        );
5135        let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
5136        let lowered_runtime = evaluator.lowered_runtime();
5137        let lowered_program = lowered_runtime.value_program();
5138        let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
5139        let lowered_value =
5140            evaluator.evaluate_expression_value_with_scratch(&amplitude_values, &mut ir_slots);
5141        let direct_lowered_value =
5142            lowered_program.evaluate_into(&amplitude_values, &mut lowered_slots);
5143        let ir_value = evaluator
5144            .expression_ir()
5145            .evaluate_into(&amplitude_values, &mut ir_slots);
5146        assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
5147        assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
5148        assert_relative_eq!(lowered_value.re, ir_value.re);
5149        assert_relative_eq!(lowered_value.im, ir_value.im);
5150    }
5151    #[test]
5152    fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
5153        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
5154            .unwrap()
5155            .norm_sqr();
5156        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5157        let evaluator = expr.load(&dataset).unwrap();
5158        let lowered_runtime = evaluator.lowered_runtime();
5159        assert_eq!(
5160            lowered_runtime.value_program().kind(),
5161            lowered::LoweredProgramKind::Value
5162        );
5163        assert_eq!(
5164            lowered_runtime.gradient_program().kind(),
5165            lowered::LoweredProgramKind::Gradient
5166        );
5167        assert_eq!(
5168            lowered_runtime.value_gradient_program().kind(),
5169            lowered::LoweredProgramKind::ValueGradient
5170        );
5171    }
5172    #[test]
5173    fn test_expression_ir_gradient_matches_lowered_runtime() {
5174        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5175            * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5176        .norm_sqr();
5177        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5178        let evaluator = expr.load(&dataset).unwrap();
5179        let resources = evaluator.resources.read();
5180        let parameters = resources
5181            .parameter_map
5182            .assemble(&[1.0, 0.25, -0.8, 0.5])
5183            .expect("parameters should assemble");
5184        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5185        evaluator.fill_amplitude_values(
5186            &mut amplitude_values,
5187            resources.active_indices(),
5188            &parameters,
5189            &resources.caches[0],
5190        );
5191        let mut active_mask = vec![false; evaluator.amplitudes.len()];
5192        for &index in resources.active_indices() {
5193            active_mask[index] = true;
5194        }
5195        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5196            .map(|_| DVector::zeros(parameters.len()))
5197            .collect::<Vec<_>>();
5198        evaluator.fill_amplitude_gradients(
5199            &mut amplitude_gradients,
5200            &active_mask,
5201            &parameters,
5202            &resources.caches[0],
5203        );
5204        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
5205        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
5206            (0..evaluator.expression_ir().node_count())
5207                .map(|_| DVector::zeros(parameters.len()))
5208                .collect();
5209        let lowered_runtime = evaluator.lowered_runtime();
5210        let lowered_program = lowered_runtime.gradient_program();
5211        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
5212        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
5213            .scratch_slots())
5214            .map(|_| DVector::zeros(parameters.len()))
5215            .collect();
5216        let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
5217            &amplitude_values,
5218            &amplitude_gradients,
5219            &mut ir_value_slots,
5220            &mut ir_gradient_slots,
5221        );
5222        let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
5223            &amplitude_values,
5224            &amplitude_gradients,
5225            &mut ir_value_slots,
5226            &mut ir_gradient_slots,
5227        );
5228        let lowered_gradient = lowered_program.evaluate_gradient_into(
5229            &amplitude_values,
5230            &amplitude_gradients,
5231            &mut lowered_value_slots,
5232            &mut lowered_gradient_slots,
5233        );
5234        for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
5235            assert_relative_eq!(active.re, lowered.re);
5236            assert_relative_eq!(active.im, lowered.im);
5237        }
5238        for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
5239            assert_relative_eq!(lowered.re, ir.re);
5240            assert_relative_eq!(lowered.im, ir.im);
5241        }
5242    }
5243    #[test]
5244    fn test_expression_ir_value_gradient_matches_lowered_runtime() {
5245        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5246            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5247            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
5248        .norm_sqr();
5249        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5250        let evaluator = expr.load(&dataset).unwrap();
5251        let resources = evaluator.resources.read();
5252        let parameters = resources
5253            .parameter_map
5254            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
5255            .expect("parameters should assemble");
5256        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5257        evaluator.fill_amplitude_values(
5258            &mut amplitude_values,
5259            resources.active_indices(),
5260            &parameters,
5261            &resources.caches[0],
5262        );
5263        let mut active_mask = vec![false; evaluator.amplitudes.len()];
5264        for &index in resources.active_indices() {
5265            active_mask[index] = true;
5266        }
5267        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5268            .map(|_| DVector::zeros(parameters.len()))
5269            .collect::<Vec<_>>();
5270        evaluator.fill_amplitude_gradients(
5271            &mut amplitude_gradients,
5272            &active_mask,
5273            &parameters,
5274            &resources.caches[0],
5275        );
5276        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
5277        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
5278            (0..evaluator.expression_ir().node_count())
5279                .map(|_| DVector::zeros(parameters.len()))
5280                .collect();
5281        let lowered_runtime = evaluator.lowered_runtime();
5282        let lowered_program = lowered_runtime.value_gradient_program();
5283        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
5284        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
5285            .scratch_slots())
5286            .map(|_| DVector::zeros(parameters.len()))
5287            .collect();
5288
5289        let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
5290            &amplitude_values,
5291            &amplitude_gradients,
5292            &mut ir_value_slots,
5293            &mut ir_gradient_slots,
5294        );
5295        let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
5296            &amplitude_values,
5297            &amplitude_gradients,
5298            &mut ir_value_slots,
5299            &mut ir_gradient_slots,
5300        );
5301        let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
5302            &amplitude_values,
5303            &amplitude_gradients,
5304            &mut lowered_value_slots,
5305            &mut lowered_gradient_slots,
5306        );
5307
5308        assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
5309        assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
5310        for (active, lowered) in active_value_gradient
5311            .1
5312            .iter()
5313            .zip(lowered_value_gradient.1.iter())
5314        {
5315            assert_relative_eq!(active.re, lowered.re);
5316            assert_relative_eq!(active.im, lowered.im);
5317        }
5318        assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
5319        assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
5320        for (lowered, ir) in lowered_value_gradient
5321            .1
5322            .iter()
5323            .zip(ir_value_gradient.1.iter())
5324        {
5325            assert_relative_eq!(lowered.re, ir.re);
5326            assert_relative_eq!(lowered.im, ir.im);
5327        }
5328    }
5329    #[test]
5330    fn test_expression_runtime_diagnostics_reports_lowered_programs() {
5331        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5332        let evaluator = fixture
5333            .expression
5334            .load(&fixture.dataset)
5335            .expect("fixture evaluator should load");
5336
5337        let diagnostics = evaluator.expression_runtime_diagnostics();
5338        assert!(diagnostics.ir_planning_enabled);
5339        assert!(diagnostics.lowered_value_program_present);
5340        assert!(diagnostics.lowered_gradient_program_present);
5341        assert!(diagnostics.lowered_value_gradient_program_present);
5342        assert!(diagnostics.residual_runtime_present);
5343        assert_eq!(
5344            diagnostics.specialization_status,
5345            Some(ExpressionSpecializationStatus {
5346                origin: ExpressionSpecializationOrigin::InitialLoad,
5347            })
5348        );
5349    }
5350    #[test]
5351    fn test_expression_runtime_diagnostics_reports_specialization_origin() {
5352        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5353        let evaluator = fixture
5354            .expression
5355            .load(&fixture.dataset)
5356            .expect("fixture evaluator should load");
5357
5358        assert_eq!(
5359            evaluator
5360                .expression_runtime_diagnostics()
5361                .specialization_status,
5362            Some(ExpressionSpecializationStatus {
5363                origin: ExpressionSpecializationOrigin::InitialLoad,
5364            })
5365        );
5366
5367        evaluator.isolate_many(&["p"]);
5368        assert_eq!(
5369            evaluator
5370                .expression_runtime_diagnostics()
5371                .specialization_status,
5372            Some(ExpressionSpecializationStatus {
5373                origin: ExpressionSpecializationOrigin::CacheMissRebuild,
5374            })
5375        );
5376    }
5377    #[test]
5378    fn test_compiled_expression_display_reports_dag_refs() {
5379        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5380        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
5381        let term = &a * &b;
5382        let expr = &term + &term;
5383        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5384        let evaluator = expr.load(&dataset).unwrap();
5385
5386        let compiled = evaluator.compiled_expression();
5387        let display = compiled.to_string();
5388
5389        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
5390        assert!(display.contains("#"));
5391        assert!(display.contains("+"));
5392        assert!(display.contains("×"));
5393        assert!(display.contains("a(id=0)"));
5394        assert!(display.contains("b(id=1)"));
5395        assert!(display.contains("(ref)"));
5396    }
5397
5398    #[test]
5399    fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
5400        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5401        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
5402        let term = &a * &b;
5403        let expr = &term + &term;
5404
5405        let compiled = expr.compiled_expression();
5406        let display = compiled.to_string();
5407
5408        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
5409        assert!(display.contains("#"));
5410        assert!(display.contains("+"));
5411        assert!(display.contains("×"));
5412        assert!(display.contains("a(id=0)"));
5413        assert!(display.contains("b(id=1)"));
5414        assert!(display.contains("(ref)"));
5415    }
5416
5417    #[test]
5418    fn test_compiled_expression_display_uses_current_active_mask() {
5419        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5420            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
5421        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5422        let evaluator = expr.load(&dataset).unwrap();
5423        evaluator.deactivate("b");
5424
5425        let compiled = evaluator.compiled_expression().to_string();
5426
5427        assert!(compiled.contains("a(id=0)"));
5428        assert!(!compiled.contains("b(id=1)"));
5429        assert!(!compiled.contains("const 0"));
5430        assert!(!compiled.contains("+"));
5431    }
5432
5433    fn assert_compiled_single_amplitude(expr: &Expression, expected_label: &str) {
5434        let compiled = expr.compiled_expression();
5435        assert_eq!(compiled.nodes().len(), 1);
5436        assert_eq!(compiled.root(), 0);
5437        match &compiled.nodes()[0] {
5438            CompiledExpressionNode::Amplitude { index, name } => {
5439                assert_eq!(*index, 0);
5440                assert_eq!(name, expected_label);
5441            }
5442            node => panic!("expected one amplitude node, got {node:?}"),
5443        }
5444    }
5445
5446    fn assert_compiled_constant(expr: &Expression, expected: Complex64) {
5447        let compiled = expr.compiled_expression();
5448        assert_eq!(compiled.nodes().len(), 1);
5449        assert_eq!(compiled.root(), 0);
5450        match compiled.nodes()[0] {
5451            CompiledExpressionNode::Constant(value) => {
5452                assert_relative_eq!(value.re, expected.re);
5453                assert_relative_eq!(value.im, expected.im);
5454            }
5455            ref node => panic!("expected one constant node, got {node:?}"),
5456        }
5457    }
5458
5459    #[test]
5460    fn test_compiled_expression_simplifies_arithmetic_identities() {
5461        let amp = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5462        let zero = Expression::zero();
5463        let one = Expression::one();
5464
5465        assert_compiled_single_amplitude(&(&amp + &zero), "a");
5466        assert_compiled_single_amplitude(&(&zero + &amp), "a");
5467        assert_compiled_single_amplitude(&(&amp - &zero), "a");
5468        assert_compiled_single_amplitude(&(&amp * &one), "a");
5469        assert_compiled_single_amplitude(&(&one * &amp), "a");
5470        assert_compiled_single_amplitude(&(&amp / &one), "a");
5471        assert_compiled_single_amplitude(&amp.pow(&one), "a");
5472        assert_compiled_single_amplitude(&amp.powi(1), "a");
5473        assert_compiled_single_amplitude(&amp.powf(1.0), "a");
5474
5475        let times_zero = &amp * &zero;
5476        assert_compiled_constant(&times_zero, Complex64::ZERO);
5477        assert!(times_zero.parameters().contains_key("ar"));
5478        assert!(times_zero.parameters().contains_key("ai"));
5479
5480        assert_compiled_constant(&(&zero * &amp), Complex64::ZERO);
5481        assert_compiled_constant(&(&zero / &Expression::from(2.0)), Complex64::ZERO);
5482        assert_compiled_constant(&amp.powi(0), Complex64::ONE);
5483        assert_compiled_constant(
5484            &Expression::from(2.0).pow(&Expression::zero()),
5485            Complex64::ONE,
5486        );
5487        assert_compiled_constant(&Expression::from(2.0).powf(0.0), Complex64::ONE);
5488
5489        let unsafe_zero_division = (&zero / &amp).compiled_expression().to_string();
5490        assert!(unsafe_zero_division.contains("÷"));
5491        assert!(unsafe_zero_division.contains("a(id=0)"));
5492    }
5493
5494    #[test]
5495    fn test_compiled_expression_folds_unary_constant_functions() {
5496        assert_compiled_constant(&Expression::from(0.0).exp(), Complex64::ONE);
5497        assert_compiled_constant(&Expression::from(0.0).sin(), Complex64::ZERO);
5498        assert_compiled_constant(&Expression::from(0.0).cos(), Complex64::ONE);
5499        assert_compiled_constant(&Expression::from(1.0).log(), Complex64::ZERO);
5500        assert_compiled_constant(&Expression::from(4.0).sqrt(), Complex64::new(2.0, 0.0));
5501        assert_compiled_constant(&Expression::from(0.0).cis(), Complex64::ONE);
5502    }
5503
5504    #[test]
5505    fn test_evaluator_expression_reconstructs_expression() {
5506        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5507        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5508        let evaluator = expr.load(&dataset).unwrap();
5509
5510        assert_eq!(
5511            evaluator.expression().compiled_expression(),
5512            expr.compiled_expression()
5513        );
5514    }
5515
5516    #[test]
5517    fn test_active_mask_override_ignores_current_ir_specialization() {
5518        let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
5519            .unwrap()
5520            .norm_sqr();
5521        let dataset = Arc::new(test_dataset());
5522        let evaluator = expr.load(&dataset).unwrap();
5523        let params = vec![2.0];
5524
5525        evaluator.deactivate("amp");
5526        assert_eq!(
5527            evaluator
5528                .evaluate(&params)
5529                .expect("evaluation should succeed")[0],
5530            Complex64::new(0.0, 0.0)
5531        );
5532
5533        let overridden = evaluator
5534            .evaluate_local_with_active_mask(&params, &[true])
5535            .unwrap();
5536        assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
5537
5538        let overridden_fused = evaluator
5539            .evaluate_with_gradient_local_with_active_mask(&params, &[true])
5540            .unwrap();
5541        assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
5542        assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
5543    }
5544    #[test]
5545    fn test_expression_ir_dependence_diagnostics_surface() {
5546        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5547            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5548        .norm_sqr();
5549        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5550        let evaluator = expr.load(&dataset).unwrap();
5551        let annotations = evaluator
5552            .expression_node_dependence_annotations()
5553            .expect("annotations should exist");
5554        assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
5555        assert!(annotations
5556            .iter()
5557            .all(|dependence| *dependence == ExpressionDependence::Mixed));
5558        assert_eq!(
5559            evaluator
5560                .expression_root_dependence()
5561                .expect("root dependence should exist"),
5562            ExpressionDependence::Mixed
5563        );
5564    }
5565    #[test]
5566    fn test_expression_ir_default_dependence_hint_is_mixed() {
5567        let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
5568        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5569        let evaluator = expr.load(&dataset).unwrap();
5570        assert_eq!(
5571            evaluator
5572                .expression_root_dependence()
5573                .expect("root dependence should exist"),
5574            ExpressionDependence::Mixed
5575        );
5576    }
5577    #[test]
5578    fn test_expression_ir_parameter_only_dependence_hint_propagates() {
5579        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
5580        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5581        let evaluator = expr.load(&dataset).unwrap();
5582        assert_eq!(
5583            evaluator
5584                .expression_root_dependence()
5585                .expect("root dependence should exist"),
5586            ExpressionDependence::ParameterOnly
5587        );
5588    }
5589    #[test]
5590    fn test_expression_ir_cache_only_dependence_hint_propagates() {
5591        let expr = CacheOnlyScalar::new("k").unwrap();
5592        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5593        let evaluator = expr.load(&dataset).unwrap();
5594        assert_eq!(
5595            evaluator
5596                .expression_root_dependence()
5597                .expect("root dependence should exist"),
5598            ExpressionDependence::CacheOnly
5599        );
5600    }
5601    #[test]
5602    fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
5603        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
5604            .unwrap()
5605            .imag();
5606        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5607        let evaluator = expr.load(&dataset).unwrap();
5608        let ir = evaluator.expression_ir();
5609
5610        assert!(matches!(
5611            ir.nodes()[ir.root()],
5612            ir::IrNode::Constant(value) if value == Complex64::ZERO
5613        ));
5614        assert_eq!(
5615            evaluator
5616                .evaluate(&[2.5])
5617                .expect("evaluation should succeed")[0],
5618            Complex64::ZERO
5619        );
5620    }
5621    #[test]
5622    fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
5623        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
5624            .unwrap()
5625            .conj();
5626        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5627        let evaluator = expr.load(&dataset).unwrap();
5628        let ir = evaluator.expression_ir();
5629
5630        assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
5631        assert_eq!(
5632            evaluator
5633                .evaluate(&[2.5])
5634                .expect("evaluation should succeed")[0],
5635            Complex64::new(2.5, 0.0)
5636        );
5637    }
5638    #[test]
5639    fn test_expression_ir_dependence_warnings_surface() {
5640        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5641            + &CacheOnlyScalar::new("k").unwrap();
5642        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5643        let evaluator = expr.load(&dataset).unwrap();
5644        assert!(evaluator
5645            .expression_dependence_warnings()
5646            .expect("warnings should exist")
5647            .iter()
5648            .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
5649    }
5650    #[test]
5651    fn test_expression_ir_normalization_plan_explain_surface() {
5652        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5653            * &CacheOnlyScalar::new("k").unwrap();
5654        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5655        let evaluator = expr.load(&dataset).unwrap();
5656        let explain = evaluator
5657            .expression_normalization_plan_explain()
5658            .expect("plan should exist");
5659        assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
5660        assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
5661        assert_eq!(
5662            explain.cached_separable_nodes,
5663            explain.separable_mul_candidate_nodes
5664        );
5665        assert!(explain.residual_terms.iter().all(|index| {
5666            !explain
5667                .separable_mul_candidate_nodes
5668                .iter()
5669                .any(|candidate| candidate == index)
5670        }));
5671    }
5672    #[test]
5673    fn test_expression_ir_normalization_execution_sets_surface() {
5674        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5675            * &CacheOnlyScalar::new("k").unwrap();
5676        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5677        let evaluator = expr.load(&dataset).unwrap();
5678        let sets = evaluator
5679            .expression_normalization_execution_sets()
5680            .expect("sets should exist");
5681        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5682        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5683        assert!(sets.residual_amplitudes.is_empty());
5684    }
5685    #[test]
5686    fn test_expression_ir_normalization_execution_sets_partial_surface() {
5687        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5688            * &CacheOnlyScalar::new("k").unwrap())
5689            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5690        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5691        let evaluator = expr.load(&dataset).unwrap();
5692        let sets = evaluator
5693            .expression_normalization_execution_sets()
5694            .expect("sets should exist");
5695        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5696        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5697        assert_eq!(sets.residual_amplitudes, vec![2]);
5698    }
5699    #[test]
5700    fn test_expression_ir_precomputed_cached_integrals_at_load() {
5701        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5702            * &CacheOnlyScalar::new("k").unwrap();
5703        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5704        let evaluator = expr.load(&dataset).unwrap();
5705        let precomputed = evaluator
5706            .expression_precomputed_cached_integrals()
5707            .expect("integrals should exist");
5708        assert_eq!(precomputed.len(), 1);
5709        let cache_reference = CacheOnlyScalar::new("k_ref")
5710            .unwrap()
5711            .load(&dataset)
5712            .unwrap();
5713        let cache_values = cache_reference
5714            .evaluate_local(&[])
5715            .expect("evaluation should succeed");
5716        let expected_weighted_sum = cache_values
5717            .iter()
5718            .zip(dataset.weights_local().iter())
5719            .fold(Complex64::ZERO, |acc, (value, event)| {
5720                acc + (*value * *event)
5721            });
5722        assert_relative_eq!(
5723            precomputed[0].weighted_cache_sum.re,
5724            expected_weighted_sum.re
5725        );
5726        assert_relative_eq!(
5727            precomputed[0].weighted_cache_sum.im,
5728            expected_weighted_sum.im
5729        );
5730    }
5731    #[test]
5732    fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
5733        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5734            * &CacheOnlyScalar::new("k").unwrap();
5735        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5736        let evaluator = expr.load(&dataset).unwrap();
5737        assert!(evaluator
5738            .expression_precomputed_cached_integrals()
5739            .expect("integrals should exist")
5740            .is_empty());
5741    }
5742    #[test]
5743    fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
5744        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5745            * &CacheOnlyScalar::new("k").unwrap();
5746        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5747        let evaluator = expr.load(&dataset).unwrap();
5748        assert_eq!(
5749            evaluator
5750                .expression_precomputed_cached_integrals()
5751                .expect("integrals should exist")
5752                .len(),
5753            1
5754        );
5755
5756        evaluator.isolate_many(&["p"]);
5757        assert!(evaluator
5758            .expression_precomputed_cached_integrals()
5759            .expect("integrals should exist")
5760            .is_empty());
5761    }
5762    #[test]
5763    fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
5764        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5765            * &CacheOnlyScalar::new("k").unwrap();
5766        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5767        let mut evaluator = expr.load(&dataset).unwrap();
5768        drop(dataset);
5769        let before = evaluator
5770            .expression_precomputed_cached_integrals()
5771            .expect("integrals should exist");
5772        assert_eq!(before.len(), 1);
5773
5774        Arc::get_mut(&mut evaluator.dataset)
5775            .expect("evaluator should own dataset Arc in this test")
5776            .clear_events_local();
5777        let after = evaluator
5778            .expression_precomputed_cached_integrals()
5779            .expect("integrals should exist");
5780        assert_eq!(after.len(), 1);
5781        assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
5782        assert!(before[0].weighted_cache_sum != after[0].weighted_cache_sum);
5783    }
5784    #[test]
5785    fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
5786        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5787            * &CacheOnlyScalar::new("k").unwrap();
5788        let dataset = Arc::new(Dataset::new(vec![
5789            Arc::new(test_event()),
5790            Arc::new(test_event()),
5791        ]));
5792        let evaluator = expr.load(&dataset).unwrap();
5793        let cached_integrals = evaluator
5794            .expression_precomputed_cached_integrals()
5795            .expect("integrals should exist");
5796        assert_eq!(cached_integrals.len(), 1);
5797        let gradient_terms = evaluator
5798            .expression_precomputed_cached_integral_gradient_terms(&[1.25])
5799            .expect("evaluation should succeed");
5800        assert_eq!(gradient_terms.len(), 1);
5801        assert_eq!(gradient_terms[0].weighted_gradient.len(), 1);
5802        assert_relative_eq!(
5803            gradient_terms[0].weighted_gradient[0].re,
5804            cached_integrals[0].weighted_cache_sum.re,
5805            epsilon = 1e-6
5806        );
5807        assert_relative_eq!(
5808            gradient_terms[0].weighted_gradient[0].im,
5809            cached_integrals[0].weighted_cache_sum.im,
5810            epsilon = 1e-6
5811        );
5812    }
5813    #[test]
5814    fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
5815        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5816            * &CacheOnlyScalar::new("k").unwrap();
5817        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5818        let evaluator = expr.load(&dataset).unwrap();
5819        assert!(evaluator
5820            .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
5821            .expect("evaluation should succeed")
5822            .is_empty());
5823    }
5824    #[test]
5825    fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
5826        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5827            * &CacheOnlyScalar::new("k").unwrap())
5828            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5829        let dataset = Arc::new(test_dataset());
5830        let evaluator = expr.load(&dataset).unwrap();
5831        let resources = evaluator.resources.read();
5832        let state = evaluator
5833            .ensure_cached_integral_cache_state(&resources)
5834            .expect("state should be available");
5835        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5836        let parameters = resources
5837            .parameter_map
5838            .assemble(&[0.55, 0.2, -0.15])
5839            .expect("parameters should assemble");
5840
5841        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5842        evaluator.fill_amplitude_values(
5843            &mut amplitude_values,
5844            &state.execution_sets.cached_parameter_amplitudes,
5845            &parameters,
5846            &resources.caches[0],
5847        );
5848        let cached_value_ir =
5849            evaluator.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values);
5850        let cached_value_lowered = evaluator
5851            .evaluate_cached_weighted_value_sum_lowered(
5852                &state,
5853                lowered_artifacts.as_ref(),
5854                &amplitude_values,
5855            )
5856            .expect("cached value lowering should succeed");
5857        assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);
5858
5859        let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
5860        for &index in &state.execution_sets.cached_parameter_amplitudes {
5861            cached_parameter_mask[index] = true;
5862        }
5863        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5864            .map(|_| DVector::zeros(parameters.len()))
5865            .collect::<Vec<_>>();
5866        evaluator.fill_amplitude_gradients(
5867            &mut amplitude_gradients,
5868            &cached_parameter_mask,
5869            &parameters,
5870            &resources.caches[0],
5871        );
5872        let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
5873            &state,
5874            &amplitude_values,
5875            &amplitude_gradients,
5876            parameters.len(),
5877        );
5878        let cached_gradient_lowered = evaluator
5879            .evaluate_cached_weighted_gradient_sum_lowered(
5880                &state,
5881                lowered_artifacts.as_ref(),
5882                &amplitude_values,
5883                &amplitude_gradients,
5884                parameters.len(),
5885            )
5886            .expect("cached gradient lowering should succeed");
5887        for (lowered, ir) in cached_gradient_lowered
5888            .iter()
5889            .zip(cached_gradient_ir.iter())
5890        {
5891            assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
5892        }
5893    }
5894    #[test]
5895    fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
5896        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5897            * &CacheOnlyScalar::new("k").unwrap())
5898            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5899        let dataset = Arc::new(test_dataset());
5900        let evaluator = expr.load(&dataset).unwrap();
5901        let resources = evaluator.resources.read();
5902        let state = evaluator
5903            .ensure_cached_integral_cache_state(&resources)
5904            .expect("state should be available");
5905        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5906        let parameters = resources
5907            .parameter_map
5908            .assemble(&[0.55, 0.2, -0.15])
5909            .expect("parameters should assemble");
5910
5911        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5912        evaluator.fill_amplitude_values(
5913            &mut amplitude_values,
5914            &state.execution_sets.residual_amplitudes,
5915            &parameters,
5916            &resources.caches[0],
5917        );
5918        let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &amplitude_values);
5919        let residual_program = lowered_artifacts
5920            .residual_runtime
5921            .as_ref()
5922            .map(|runtime| runtime.value_program())
5923            .expect("residual value lowering should succeed");
5924        let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
5925        let residual_value_lowered =
5926            residual_program.evaluate_into(&amplitude_values, &mut value_slots);
5927        assert_relative_eq!(
5928            residual_value_lowered.re,
5929            residual_value_ir.re,
5930            epsilon = 1e-12
5931        );
5932        assert_relative_eq!(
5933            residual_value_lowered.im,
5934            residual_value_ir.im,
5935            epsilon = 1e-12
5936        );
5937
5938        let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
5939        for &index in &state.execution_sets.residual_amplitudes {
5940            residual_active_mask[index] = true;
5941        }
5942        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5943            .map(|_| DVector::zeros(parameters.len()))
5944            .collect::<Vec<_>>();
5945        evaluator.fill_amplitude_gradients(
5946            &mut amplitude_gradients,
5947            &residual_active_mask,
5948            &parameters,
5949            &resources.caches[0],
5950        );
5951        let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
5952            &state,
5953            &amplitude_values,
5954            &amplitude_gradients,
5955            parameters.len(),
5956        );
5957
5958        let program = lowered_artifacts
5959            .residual_runtime
5960            .as_ref()
5961            .map(|runtime| runtime.gradient_program())
5962            .expect("gradient lowering should succeed");
5963        let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
5964        let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
5965        let residual_gradient_lowered = program.evaluate_gradient_into_flat(
5966            &amplitude_values,
5967            &amplitude_gradients,
5968            &mut value_slots,
5969            &mut gradient_slots,
5970            parameters.len(),
5971        );
5972
5973        for (lowered, ir) in residual_gradient_lowered
5974            .iter()
5975            .zip(residual_gradient_ir.iter())
5976        {
5977            assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
5978            assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
5979        }
5980    }
5981    #[test]
5982    fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
5983        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5984            * &CacheOnlyScalar::new("k").unwrap())
5985            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5986        let dataset = Arc::new(test_dataset());
5987        let mut evaluator = expr.load(&dataset).unwrap();
5988        drop(dataset);
5989
5990        assert_eq!(evaluator.specialization_cache_len(), 1);
5991        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
5992
5993        evaluator.reset_expression_compile_metrics();
5994        evaluator.reset_expression_specialization_metrics();
5995
5996        Arc::get_mut(&mut evaluator.dataset)
5997            .expect("evaluator should own dataset Arc in this test")
5998            .clear_events_local();
5999
6000        let cached_integrals = evaluator
6001            .expression_precomputed_cached_integrals()
6002            .expect("integrals should exist");
6003        assert_eq!(cached_integrals.len(), 1);
6004        assert_eq!(cached_integrals[0].weighted_cache_sum, Complex64::ZERO);
6005
6006        assert_eq!(evaluator.specialization_cache_len(), 2);
6007        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
6008        assert_eq!(
6009            evaluator.expression_specialization_metrics(),
6010            ExpressionSpecializationMetrics {
6011                cache_hits: 0,
6012                cache_misses: 1,
6013            }
6014        );
6015
6016        let compile_metrics = evaluator.expression_compile_metrics();
6017        assert_eq!(compile_metrics.specialization_cache_hits, 0);
6018        assert_eq!(compile_metrics.specialization_cache_misses, 1);
6019        assert_eq!(compile_metrics.specialization_lowering_cache_hits, 1);
6020        assert_eq!(compile_metrics.specialization_lowering_cache_misses, 0);
6021        assert!(compile_metrics.specialization_ir_compile_nanos > 0);
6022        assert!(compile_metrics.specialization_cached_integrals_nanos > 0);
6023        assert_eq!(compile_metrics.specialization_lowering_nanos, 0);
6024    }
6025
6026    #[test]
6027    fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
6028        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6029        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6030        let c1 = CacheOnlyScalar::new("c1").unwrap();
6031        let c2 = CacheOnlyScalar::new("c2").unwrap();
6032        let c3 = CacheOnlyScalar::new("c3").unwrap();
6033        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6034        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6035        let dataset = Arc::new(test_dataset());
6036        let evaluator = expr.load(&dataset).unwrap();
6037        assert_eq!(
6038            evaluator
6039                .expression_precomputed_cached_integrals()
6040                .expect("integrals should exist")
6041                .len(),
6042            2
6043        );
6044        let params = vec![0.2, -0.3, 1.1, -0.7];
6045        let expected = evaluator
6046            .evaluate_gradient_local(&params)
6047            .expect("evaluation should succeed")
6048            .iter()
6049            .zip(dataset.weights_local().iter())
6050            .fold(
6051                DVector::zeros(params.len()),
6052                |mut accum, (gradient, event)| {
6053                    accum += gradient.map(|value| value.re).scale(*event);
6054                    accum
6055                },
6056            );
6057        let actual = evaluator
6058            .evaluate_weighted_gradient_sum_local(&params)
6059            .expect("evaluation should succeed");
6060        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6061            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6062        }
6063    }
6064
6065    #[test]
6066    fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
6067        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6068        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6069        let c1 = CacheOnlyScalar::new("c1").unwrap();
6070        let c2 = CacheOnlyScalar::new("c2").unwrap();
6071        let c3 = CacheOnlyScalar::new("c3").unwrap();
6072        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6073        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6074        let dataset = Arc::new(test_dataset());
6075        let evaluator = expr.load(&dataset).unwrap();
6076        assert_eq!(
6077            evaluator
6078                .expression_precomputed_cached_integrals()
6079                .expect("integrals should exist")
6080                .len(),
6081            2
6082        );
6083        let params = vec![0.2, -0.3, 1.1, -0.7];
6084        let expected = evaluator
6085            .evaluate_local(&params)
6086            .expect("evaluation should succeed")
6087            .iter()
6088            .zip(dataset.weights_local().iter())
6089            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6090        let actual = evaluator
6091            .evaluate_weighted_value_sum_local(&params)
6092            .expect("evaluation should succeed");
6093        assert_relative_eq!(actual, expected, epsilon = 1e-10);
6094    }
6095
6096    #[test]
6097    fn test_weighted_sums_match_hardcoded_reference_values() {
6098        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6099        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6100        let c1 = CacheOnlyScalar::new("c1").unwrap();
6101        let c2 = CacheOnlyScalar::new("c2").unwrap();
6102        let c3 = CacheOnlyScalar::new("c3").unwrap();
6103        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6104        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6105
6106        let metadata = Arc::new(DatasetMetadata::default());
6107        let dataset = Arc::new(Dataset::new_with_metadata(
6108            vec![
6109                Arc::new(EventData {
6110                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6111                    aux: vec![],
6112                    weight: 0.5,
6113                }),
6114                Arc::new(EventData {
6115                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6116                    aux: vec![],
6117                    weight: -1.25,
6118                }),
6119                Arc::new(EventData {
6120                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6121                    aux: vec![],
6122                    weight: 2.0,
6123                }),
6124            ],
6125            metadata,
6126        ));
6127        let evaluator = expr.load(&dataset).unwrap();
6128        let params = vec![0.7, -1.1, 0.9, -0.4];
6129
6130        let weighted_value_sum = evaluator
6131            .evaluate_weighted_value_sum_local(&params)
6132            .expect("evaluation should succeed");
6133        assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);
6134
6135        let weighted_gradient_sum = evaluator
6136            .evaluate_weighted_gradient_sum_local(&params)
6137            .expect("evaluation should succeed");
6138        let free_parameters = evaluator
6139            .parameters()
6140            .free()
6141            .names()
6142            .into_iter()
6143            .map(|name| name.to_string())
6144            .collect::<Vec<_>>();
6145        assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
6146        let expected_gradient = [43.925, 7.25, 28.525, 0.0];
6147        assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
6148        for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
6149            assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
6150        }
6151    }
6152    #[test]
6153    fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
6154        let expr = Expression::one()
6155            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6156                * &CacheOnlyScalar::new("k").unwrap());
6157        let dataset = Arc::new(test_dataset());
6158        let evaluator = expr.load(&dataset).unwrap();
6159        assert_eq!(
6160            evaluator
6161                .expression_precomputed_cached_integrals()
6162                .expect("integrals should exist")
6163                .len(),
6164            1
6165        );
6166        assert_eq!(
6167            evaluator
6168                .expression_precomputed_cached_integrals()
6169                .expect("integrals should exist")[0]
6170                .coefficient,
6171            -1
6172        );
6173        let params = vec![0.75];
6174        let expected = evaluator
6175            .evaluate_gradient_local(&params)
6176            .expect("evaluation should succeed")
6177            .iter()
6178            .zip(dataset.weights_local().iter())
6179            .fold(
6180                DVector::zeros(params.len()),
6181                |mut accum, (gradient, event)| {
6182                    accum += gradient.map(|value| value.re).scale(*event);
6183                    accum
6184                },
6185            );
6186        let actual = evaluator
6187            .evaluate_weighted_gradient_sum_local(&params)
6188            .expect("evaluation should succeed");
6189        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6190            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6191        }
6192    }
6193    #[test]
6194    fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
6195        let expr = Expression::one()
6196            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6197                * &CacheOnlyScalar::new("k").unwrap());
6198        let dataset = Arc::new(test_dataset());
6199        let evaluator = expr.load(&dataset).unwrap();
6200        assert_eq!(
6201            evaluator
6202                .expression_precomputed_cached_integrals()
6203                .expect("integrals should exist")
6204                .len(),
6205            1
6206        );
6207        assert_eq!(
6208            evaluator
6209                .expression_precomputed_cached_integrals()
6210                .expect("integrals should exist")[0]
6211                .coefficient,
6212            -1
6213        );
6214        let params = vec![0.75];
6215        let expected = evaluator
6216            .evaluate_local(&params)
6217            .expect("evaluation should succeed")
6218            .iter()
6219            .zip(dataset.weights_local().iter())
6220            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6221        let actual = evaluator
6222            .evaluate_weighted_value_sum_local(&params)
6223            .expect("evaluation should succeed");
6224        assert_relative_eq!(actual, expected, epsilon = 1e-10);
6225    }
6226    #[test]
6227    fn test_expression_ir_diagnostics_follow_activation_changes() {
6228        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6229            * &CacheOnlyScalar::new("k").unwrap())
6230            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6231        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6232        let evaluator = expr.load(&dataset).unwrap();
6233
6234        let all_active = evaluator
6235            .expression_normalization_plan_explain()
6236            .expect("plan should exist");
6237        assert_eq!(all_active.cached_separable_nodes.len(), 1);
6238        assert_eq!(
6239            evaluator
6240                .expression_root_dependence()
6241                .expect("root dependence should exist"),
6242            ExpressionDependence::Mixed
6243        );
6244
6245        evaluator.isolate_many(&["p"]);
6246        let param_only = evaluator
6247            .expression_normalization_plan_explain()
6248            .expect("plan should exist");
6249        assert!(param_only.cached_separable_nodes.is_empty());
6250        assert_eq!(
6251            evaluator
6252                .expression_root_dependence()
6253                .expect("root dependence should exist"),
6254            ExpressionDependence::ParameterOnly
6255        );
6256    }
6257    #[test]
6258    fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
6259        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6260            * &CacheOnlyScalar::new("k").unwrap())
6261            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6262        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6263        let evaluator = expr.load(&dataset).unwrap();
6264
6265        let initial_compile_metrics = evaluator.expression_compile_metrics();
6266        assert!(initial_compile_metrics.initial_ir_compile_nanos > 0);
6267        assert!(initial_compile_metrics.initial_cached_integrals_nanos > 0);
6268        assert!(initial_compile_metrics.initial_lowering_nanos > 0);
6269        assert_eq!(initial_compile_metrics.specialization_cache_hits, 0);
6270        assert_eq!(initial_compile_metrics.specialization_cache_misses, 0);
6271        assert_eq!(
6272            initial_compile_metrics.specialization_lowering_cache_hits,
6273            0
6274        );
6275        assert_eq!(
6276            initial_compile_metrics.specialization_lowering_cache_misses,
6277            1
6278        );
6279
6280        assert_eq!(evaluator.specialization_cache_len(), 1);
6281        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
6282        assert_eq!(
6283            evaluator.expression_specialization_metrics(),
6284            ExpressionSpecializationMetrics {
6285                cache_hits: 0,
6286                cache_misses: 1,
6287            }
6288        );
6289        let all_active_cached_integrals = evaluator
6290            .expression_precomputed_cached_integrals()
6291            .expect("integrals should exist");
6292
6293        evaluator.isolate_many(&["p"]);
6294        assert_eq!(evaluator.specialization_cache_len(), 2);
6295        assert_eq!(
6296            evaluator.expression_specialization_metrics(),
6297            ExpressionSpecializationMetrics {
6298                cache_hits: 0,
6299                cache_misses: 2,
6300            }
6301        );
6302        let after_cache_miss_metrics = evaluator.expression_compile_metrics();
6303        assert_eq!(after_cache_miss_metrics.specialization_cache_hits, 0);
6304        assert_eq!(after_cache_miss_metrics.specialization_cache_misses, 1);
6305        assert_eq!(
6306            after_cache_miss_metrics.specialization_lowering_cache_hits,
6307            0
6308        );
6309        assert_eq!(
6310            after_cache_miss_metrics.specialization_lowering_cache_misses,
6311            2
6312        );
6313        assert!(after_cache_miss_metrics.specialization_ir_compile_nanos > 0);
6314        assert!(after_cache_miss_metrics.specialization_cached_integrals_nanos > 0);
6315        assert!(after_cache_miss_metrics.specialization_lowering_nanos > 0);
6316        assert!(evaluator
6317            .expression_precomputed_cached_integrals()
6318            .expect("integrals should exist")
6319            .is_empty());
6320
6321        evaluator.activate_many(&["k", "m"]);
6322        assert_eq!(evaluator.specialization_cache_len(), 2);
6323        assert_eq!(
6324            evaluator.expression_specialization_metrics(),
6325            ExpressionSpecializationMetrics {
6326                cache_hits: 1,
6327                cache_misses: 2,
6328            }
6329        );
6330        assert_eq!(
6331            evaluator
6332                .expression_precomputed_cached_integrals()
6333                .expect("integrals should exist"),
6334            all_active_cached_integrals
6335        );
6336        let after_cache_hit_metrics = evaluator.expression_compile_metrics();
6337        assert_eq!(after_cache_hit_metrics.specialization_cache_hits, 1);
6338        assert_eq!(after_cache_hit_metrics.specialization_cache_misses, 1);
6339        assert_eq!(
6340            after_cache_hit_metrics.specialization_lowering_cache_hits,
6341            0
6342        );
6343        assert_eq!(
6344            after_cache_hit_metrics.specialization_lowering_cache_misses,
6345            2
6346        );
6347        assert!(after_cache_hit_metrics.specialization_cache_restore_nanos > 0);
6348    }
6349
6350    #[test]
6351    fn test_weighted_sums_match_baseline_after_activation_changes() {
6352        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6353        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6354        let c1 = CacheOnlyScalar::new("c1").unwrap();
6355        let c2 = CacheOnlyScalar::new("c2").unwrap();
6356        let c3 = CacheOnlyScalar::new("c3").unwrap();
6357        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6358        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6359        let dataset = Arc::new(test_dataset());
6360        let evaluator = expr.load(&dataset).unwrap();
6361        let params = vec![0.2, -0.3, 1.1, -0.7];
6362
6363        evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);
6364
6365        let expected_value = evaluator
6366            .evaluate_local(&params)
6367            .expect("evaluation should succeed")
6368            .iter()
6369            .zip(dataset.weights_local().iter())
6370            .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6371        assert_relative_eq!(
6372            evaluator
6373                .evaluate_weighted_value_sum_local(&params)
6374                .expect("evaluation should succeed"),
6375            expected_value,
6376            epsilon = 1e-10
6377        );
6378    }
6379
6380    #[test]
6381    fn test_evaluate_local_does_not_depend_on_dataset_rows() {
6382        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6383            .unwrap()
6384            .norm_sqr();
6385        let mut event1 = test_event();
6386        event1.p4s[0].t = 7.5;
6387        let mut event2 = test_event();
6388        event2.p4s[0].t = 8.25;
6389        let mut event3 = test_event();
6390        event3.p4s[0].t = 9.0;
6391        let dataset = Arc::new(Dataset::new_with_metadata(
6392            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
6393            Arc::new(DatasetMetadata::default()),
6394        ));
6395        let mut evaluator = expr.load(&dataset).unwrap();
6396        drop(dataset);
6397        let expected_len = evaluator.resources.read().caches.len();
6398        Arc::get_mut(&mut evaluator.dataset)
6399            .expect("evaluator should own dataset Arc in this test")
6400            .clear_events_local();
6401        let cached = evaluator
6402            .evaluate_local(&[1.25, -0.75])
6403            .expect("evaluation should succeed");
6404        assert_eq!(cached.len(), expected_len);
6405    }
6406
6407    #[test]
6408    fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
6409        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6410            .unwrap()
6411            .norm_sqr();
6412        let mut event1 = test_event();
6413        event1.p4s[0].t = 7.5;
6414        let mut event2 = test_event();
6415        event2.p4s[0].t = 8.25;
6416        let mut event3 = test_event();
6417        event3.p4s[0].t = 9.0;
6418        let dataset = Arc::new(Dataset::new_with_metadata(
6419            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
6420            Arc::new(DatasetMetadata::default()),
6421        ));
6422        let mut evaluator = expr.load(&dataset).unwrap();
6423        drop(dataset);
6424        let expected_len = evaluator.resources.read().caches.len();
6425        Arc::get_mut(&mut evaluator.dataset)
6426            .expect("evaluator should own dataset Arc in this test")
6427            .clear_events_local();
6428        let cached = evaluator
6429            .evaluate_gradient_local(&[1.25, -0.75])
6430            .expect("evaluation should succeed");
6431        assert_eq!(cached.len(), expected_len);
6432    }
6433
6434    #[test]
6435    fn test_evaluate_with_gradient_local_matches_separate_paths() {
6436        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6437            .unwrap()
6438            .norm_sqr();
6439        let dataset = Arc::new(Dataset::new(vec![
6440            Arc::new(test_event()),
6441            Arc::new(test_event()),
6442            Arc::new(test_event()),
6443        ]));
6444        let evaluator = expr.load(&dataset).unwrap();
6445        let params = [1.25, -0.75];
6446        let values = evaluator
6447            .evaluate_local(&params)
6448            .expect("evaluation should succeed");
6449        let gradients = evaluator
6450            .evaluate_gradient_local(&params)
6451            .expect("evaluation should succeed");
6452        let fused = evaluator
6453            .evaluate_with_gradient_local(&params)
6454            .expect("evaluation should succeed");
6455        assert_eq!(fused.len(), values.len());
6456        assert_eq!(fused.len(), gradients.len());
6457        for ((value_gradient, value), gradient) in
6458            fused.iter().zip(values.iter()).zip(gradients.iter())
6459        {
6460            let (fused_value, fused_gradient) = value_gradient;
6461            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
6462            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
6463            assert_eq!(fused_gradient.len(), gradient.len());
6464            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
6465                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
6466                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
6467            }
6468        }
6469    }
6470
6471    #[test]
6472    fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
6473        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6474            .unwrap()
6475            .norm_sqr();
6476        let dataset = Arc::new(Dataset::new(vec![
6477            Arc::new(test_event()),
6478            Arc::new(test_event()),
6479            Arc::new(test_event()),
6480            Arc::new(test_event()),
6481        ]));
6482        let evaluator = expr.load(&dataset).unwrap();
6483        let params = [0.5, -1.25];
6484        let indices = vec![0, 2, 3];
6485        let values = evaluator
6486            .evaluate_batch_local(&params, &indices)
6487            .expect("evaluation should succeed");
6488        let gradients = evaluator
6489            .evaluate_gradient_batch_local(&params, &indices)
6490            .expect("evaluation should succeed");
6491        let fused = evaluator
6492            .evaluate_with_gradient_batch_local(&params, &indices)
6493            .expect("evaluation should succeed");
6494        assert_eq!(fused.len(), values.len());
6495        assert_eq!(fused.len(), gradients.len());
6496        for ((value_gradient, value), gradient) in
6497            fused.iter().zip(values.iter()).zip(gradients.iter())
6498        {
6499            let (fused_value, fused_gradient) = value_gradient;
6500            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
6501            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
6502            assert_eq!(fused_gradient.len(), gradient.len());
6503            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
6504                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
6505                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
6506            }
6507        }
6508    }
6509
/// Runs `precompute_all` on a hand-registered `TestAmplitude` and checks that
/// every reserved cache holds a positive scalar in the amplitude's
/// `beam_energy` slot afterwards.
#[test]
fn test_precompute_all_columnar_populates_cache() {
    // Each event gets a distinct value in the `t` field of its first `p4s` entry.
    let event_with_t = |t| {
        let mut event = test_event();
        event.p4s[0].t = t;
        Arc::new(event)
    };
    let dataset = Dataset::new_with_metadata(
        vec![event_with_t(7.5), event_with_t(8.25), event_with_t(9.0)],
        Arc::new(DatasetMetadata::default()),
    );
    let mut amplitude = TestAmplitude {
        tags: Tags::new(["test"]),
        re: parameter!("real"),
        pid_re: ParameterID::default(),
        im: parameter!("imag"),
        pid_im: ParameterID::default(),
        beam_energy: Default::default(),
    };
    let mut resources = Resources::default();
    amplitude
        .register(&mut resources)
        .expect("test amplitude should register");
    resources.reserve_cache(dataset.n_events());
    amplitude.precompute_all(&dataset, &mut resources);

    // Guard against a vacuous pass: with no caches the loop below would
    // check nothing and the test would still succeed.
    assert_eq!(
        resources.caches.len(),
        dataset.n_events(),
        "one cache should be reserved per event"
    );
    for cache in &resources.caches {
        assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
    }
}
6540
/// With MPI enabled, loading an evaluator must leave it with exactly as many
/// caches as the dataset reports local events.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_load_reserves_local_cache_size_in_mpi() {
    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    assert!(get_world().is_some(), "MPI world should be initialized");

    let expr = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .expect("constant amplitude should construct");

    // Four events in total, to be shared between the two ranks.
    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![
            Arc::new(test_event()),
            Arc::new(test_event()),
            Arc::new(test_event()),
            Arc::new(test_event()),
        ],
        metadata,
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");

    let expected_len = dataset.n_events_local();
    let actual_len = evaluator.resources.read().caches.len();
    assert_eq!(
        actual_len, expected_len,
        "cache length must match local event count under MPI"
    );
    finalize_mpi();
}
6575
/// The cached integral precomputed at load time must equal the weighted sum
/// over this rank's local events, and must differ from the MPI-reduced global
/// sum whenever more than one rank is present.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
    use mpi::{collective::SystemOperation, topology::Communicator, traits::*};

    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();

    // (fourth `Vec4` component, event weight) for each of the eight events.
    let event_table = [
        (1.0, 0.5),
        (2.0, 1.0),
        (3.0, 1.5),
        (4.0, 2.0),
        (5.0, 2.5),
        (6.0, 3.0),
        (7.0, 3.5),
        (8.0, 4.0),
    ];
    let events = event_table
        .iter()
        .map(|&(e, weight)| {
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, e)],
                aux: vec![],
                weight,
            })
        })
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");
    let cached_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(cached_integrals.len(), 1);

    // Reference value: sum of weight * e() over the events local to this rank.
    let mut local_expected = 0.0;
    for (index, weight) in dataset.weights_local().iter().enumerate() {
        let event = dataset.event_local(index).expect("event should exist");
        local_expected += *weight * event.p4_at(0).e();
    }
    let cached_local = cached_integrals[0].weighted_cache_sum;
    assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
    assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);

    // Evaluated with the single parameter set to 2.0.
    let weighted_value_sum = evaluator
        .evaluate_weighted_value_sum_local(&[2.0])
        .expect("evaluate should succeed");
    assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);

    let mut global_expected = 0.0;
    world.all_reduce_into(
        &local_expected,
        &mut global_expected,
        SystemOperation::sum(),
    );
    if world.size() > 1 {
        let distance_to_global = (cached_local.re - global_expected).abs();
        assert!(
            distance_to_global > 1e-12,
            "cached integral should remain rank-local before MPI reduction"
        );
    }
    finalize_mpi();
}
6672
/// Compares the MPI weighted value sum and weighted gradient sum against a
/// baseline built from per-event local results: each rank weights and sums
/// its own values and gradients, and the partial sums are combined with an
/// all-reduce.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
    use mpi::{collective::SystemOperation, traits::*};

    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);

    // (fourth `Vec4` component, event weight); some weights are negative.
    let event_table = [
        (1.0, 0.5),
        (2.0, -1.25),
        (3.0, 0.75),
        (4.0, 1.5),
        (5.0, 2.25),
        (6.0, -0.5),
        (7.0, 3.5),
        (8.0, 1.25),
    ];
    let events = event_table
        .iter()
        .map(|&(e, weight)| {
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, e)],
                aux: vec![],
                weight,
            })
        })
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");
    let params = vec![0.2, -0.3, 1.1, -0.7];

    // Value baseline: weight each local value's real part, then all-reduce.
    let local_expected_value = evaluator
        .evaluate_local(&params)
        .expect("evaluate should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    let mut global_expected_value = 0.0;
    world.all_reduce_into(
        &local_expected_value,
        &mut global_expected_value,
        SystemOperation::sum(),
    );
    let mpi_value = evaluator
        .evaluate_weighted_value_sum_mpi(&params, &world)
        .expect("evaluate should succeed");
    assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);

    // Gradient baseline: same procedure, applied component-wise.
    let local_expected_gradient = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluate should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
    world.all_reduce_into(
        local_expected_gradient.as_slice(),
        &mut global_expected_gradient,
        SystemOperation::sum(),
    );
    let mpi_gradient = evaluator
        .evaluate_weighted_gradient_sum_mpi(&params, &world)
        .expect("evaluate should succeed");
    // `zip` stops at the shorter sequence, so a truncated or empty MPI
    // gradient would otherwise slip through the component comparison.
    assert_eq!(
        mpi_gradient.len(),
        global_expected_gradient.len(),
        "MPI gradient should have one component per parameter"
    );
    for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
        assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
    }

    finalize_mpi();
}
6783
/// Local value and gradient evaluation of a fully fixed `ComplexScalar` must
/// succeed with an empty parameter slice and yield one entry per event.
#[test]
fn test_evaluate_local_succeeds_for_constant_amplitude() {
    let re = parameter!("const_re", 2.0);
    let im = parameter!("const_im", 3.0);
    let expr = ComplexScalar::new("constant", re, im).unwrap();

    // A single-event dataset, so both outputs are expected to have length 1.
    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![Arc::new(test_event())],
        metadata,
    ));
    let evaluator = expr.load(&dataset).unwrap();

    let values = evaluator
        .evaluate_local(&[])
        .expect("evaluation should succeed");
    assert_eq!(values.len(), 1);

    let gradients = evaluator
        .evaluate_gradient_local(&[])
        .expect("evaluation should succeed");
    assert_eq!(gradients.len(), 1);
}
6806
/// A `ComplexScalar` built from two fixed parameters evaluates to exactly
/// that complex constant.
#[test]
fn test_constant_amplitude() {
    let re = parameter!("const_re", 2.0);
    let im = parameter!("const_im", 3.0);
    let expr = ComplexScalar::new("constant", re, im).unwrap();

    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![Arc::new(test_event())],
        metadata,
    ));
    let evaluator = expr.load(&dataset).unwrap();

    // Both parameters are fixed, so the parameter slice is empty.
    let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
    let expected = Complex64::new(2.0, 3.0);
    assert_eq!(result[0], expected);
}
6823
/// A `ComplexScalar` built from two free parameters evaluates to the complex
/// number formed from the supplied parameter values.
#[test]
fn test_parametric_amplitude() {
    let re = parameter!("test_param_re");
    let im = parameter!("test_param_im");
    let expr = ComplexScalar::new("parametric", re, im).unwrap();

    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();

    // Parameter values are passed as [real part, imaginary part].
    let result = evaluator
        .evaluate(&[2.0, 3.0])
        .expect("evaluation should succeed");
    let expected = Complex64::new(2.0, 3.0);
    assert_eq!(result[0], expected);
}
6839
/// Exercises arithmetic (`+`, `-`, `*`, `/`, unary `-`) and the `real`, `imag`,
/// `conj` and `norm_sqr` operations, both on the constant amplitudes directly
/// and on expressions built from them, comparing the first event's value with
/// hand-computed complex numbers.
#[test]
fn test_expression_operations() {
    // 2 + 0i
    let two = ComplexScalar::new(
        "const1",
        parameter!("const1_re", 2.0),
        parameter!("const1_im", 0.0),
    )
    .unwrap();
    // 0 + 1i
    let imag_unit = ComplexScalar::new(
        "const2",
        parameter!("const2_re", 0.0),
        parameter!("const2_im", 1.0),
    )
    .unwrap();
    // 3 + 4i
    let three_four = ComplexScalar::new(
        "const3",
        parameter!("const3_re", 3.0),
        parameter!("const3_im", 4.0),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());

    // Loads the expression against `dataset`, evaluates it with no free
    // parameters and yields the value for the first event.
    macro_rules! first_value {
        ($expression:expr) => {
            $expression
                .load(&dataset)
                .unwrap()
                .evaluate(&[])
                .expect("evaluation should succeed")[0]
        };
    }

    // Test (amp) addition
    let sum = &two + &imag_unit;
    assert_eq!(first_value!(sum), Complex64::new(2.0, 1.0));

    // Test (amp) subtraction
    let difference = &two - &imag_unit;
    assert_eq!(first_value!(difference), Complex64::new(2.0, -1.0));

    // Test (amp) multiplication
    let product = &two * &imag_unit;
    assert_eq!(first_value!(product), Complex64::new(0.0, 2.0));

    // Test (amp) division
    let quotient = &two / &three_four;
    assert_eq!(
        first_value!(quotient),
        Complex64::new(6.0 / 25.0, -8.0 / 25.0)
    );

    // Test (amp) neg
    let negated = -&three_four;
    assert_eq!(first_value!(negated), Complex64::new(-3.0, -4.0));

    // Test (expr) addition: (2 + i) + 2i
    let sum_of_exprs = &sum + &product;
    assert_eq!(first_value!(sum_of_exprs), Complex64::new(2.0, 3.0));

    // Test (expr) subtraction: (2 + i) - 2i
    let difference_of_exprs = &sum - &product;
    assert_eq!(first_value!(difference_of_exprs), Complex64::new(2.0, -1.0));

    // Test (expr) multiplication: (2 + i) * 2i
    let product_of_exprs = &sum * &product;
    assert_eq!(first_value!(product_of_exprs), Complex64::new(-2.0, 4.0));

    // Test (expr) division: (2 + i) / (2 + 3i)
    let quotient_of_exprs = &sum / &sum_of_exprs;
    assert_eq!(
        first_value!(quotient_of_exprs),
        Complex64::new(7.0 / 13.0, -4.0 / 13.0)
    );

    // Test (expr) neg
    let negated_expr = -&product_of_exprs;
    assert_eq!(first_value!(negated_expr), Complex64::new(2.0, -4.0));

    // Test (amp) real
    let real_part = three_four.real();
    assert_eq!(first_value!(real_part), Complex64::new(3.0, 0.0));

    // Test (expr) real
    let real_part_of_expr = product_of_exprs.real();
    assert_eq!(first_value!(real_part_of_expr), Complex64::new(-2.0, 0.0));

    // Test (amp) imag
    let imag_part = three_four.imag();
    assert_eq!(first_value!(imag_part), Complex64::new(4.0, 0.0));

    // Test (expr) imag
    let imag_part_of_expr = product_of_exprs.imag();
    assert_eq!(first_value!(imag_part_of_expr), Complex64::new(4.0, 0.0));

    // Test (amp) conj
    let conjugate = three_four.conj();
    assert_eq!(first_value!(conjugate), Complex64::new(3.0, -4.0));

    // Test (expr) conj
    let conjugate_of_expr = product_of_exprs.conj();
    assert_eq!(first_value!(conjugate_of_expr), Complex64::new(-2.0, -4.0));

    // Test (amp) norm_sqr
    let norm = two.norm_sqr();
    assert_eq!(first_value!(norm), Complex64::new(4.0, 0.0));

    // Test (expr) norm_sqr: |-2 + 4i|^2 = 20
    let norm_of_expr = product_of_exprs.norm_sqr();
    assert_eq!(first_value!(norm_of_expr), Complex64::new(20.0, 0.0));
}
7025
/// Walks an evaluator for `const1 + const2` (1 and 2) through deactivation,
/// isolation and full reactivation, checking the evaluated sum at each step.
#[test]
fn test_amplitude_activation() {
    let one = ComplexScalar::new(
        "const1",
        parameter!("const1_re_act", 1.0),
        parameter!("const1_im_act", 0.0),
    )
    .unwrap();
    let two = ComplexScalar::new(
        "const2",
        parameter!("const2_re_act", 2.0),
        parameter!("const2_im_act", 0.0),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());
    let expr = &one + &two;
    let evaluator = expr.load(&dataset).unwrap();

    // Value of the first event under the evaluator's current activation state.
    let first_value = || evaluator.evaluate(&[]).expect("evaluation should succeed")[0];

    // Test initial state (all active)
    assert_eq!(first_value(), Complex64::new(3.0, 0.0));

    // Test deactivation
    evaluator.deactivate_strict("const1").unwrap();
    assert_eq!(first_value(), Complex64::new(2.0, 0.0));

    // Test isolation
    evaluator.isolate_strict("const1").unwrap();
    assert_eq!(first_value(), Complex64::new(1.0, 0.0));

    // Test reactivation
    evaluator.activate_all();
    assert_eq!(first_value(), Complex64::new(3.0, 0.0));
}
7064
/// Checks first-event gradients of expressions built from two parametric
/// complex scalars `a = 2 + 3i` and `b = 4 + 5i` against hand-computed values,
/// covering `+`, `-`, `*`, `/`, unary `-`, `real`, `imag`, `conj` and
/// `norm_sqr`.
#[test]
fn test_gradient() {
    let expr1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());
    let params = vec![2.0, 3.0, 4.0, 5.0];

    // Loads the expression, evaluates its gradient at `params`, and compares
    // the first event's four components with the expected `(re, im)` pairs,
    // listed in parameter order.
    macro_rules! assert_first_event_gradient {
        ($expression:expr, $expected:expr) => {{
            let expression = $expression;
            let evaluator = expression.load(&dataset).unwrap();
            let gradient = evaluator
                .evaluate_gradient(&params)
                .expect("evaluation should succeed");
            let expected: [(f64, f64); 4] = $expected;
            for (index, (re, im)) in expected.iter().enumerate() {
                assert_relative_eq!(gradient[0][index].re, *re);
                assert_relative_eq!(gradient[0][index].im, *im);
            }
        }};
    }

    // a + b
    assert_first_event_gradient!(
        &expr1 + &expr2,
        [(1.0, 0.0), (0.0, 1.0), (1.0, 0.0), (0.0, 1.0)]
    );

    // a - b
    assert_first_event_gradient!(
        &expr1 - &expr2,
        [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0), (0.0, -1.0)]
    );

    // a * b
    assert_first_event_gradient!(
        &expr1 * &expr2,
        [(4.0, 5.0), (-5.0, 4.0), (2.0, 3.0), (-3.0, 2.0)]
    );

    // a / b
    assert_first_event_gradient!(
        &expr1 / &expr2,
        [
            (4.0 / 41.0, -5.0 / 41.0),
            (5.0 / 41.0, 4.0 / 41.0),
            (-102.0 / 1681.0, 107.0 / 1681.0),
            (-107.0 / 1681.0, -102.0 / 1681.0),
        ]
    );

    // -(a * b)
    assert_first_event_gradient!(
        -(&expr1 * &expr2),
        [(-4.0, -5.0), (5.0, -4.0), (-2.0, -3.0), (3.0, -2.0)]
    );

    // Re(a * b)
    assert_first_event_gradient!(
        (&expr1 * &expr2).real(),
        [(4.0, 0.0), (-5.0, 0.0), (2.0, 0.0), (-3.0, 0.0)]
    );

    // Im(a * b)
    assert_first_event_gradient!(
        (&expr1 * &expr2).imag(),
        [(5.0, 0.0), (4.0, 0.0), (3.0, 0.0), (2.0, 0.0)]
    );

    // conj(a * b)
    assert_first_event_gradient!(
        (&expr1 * &expr2).conj(),
        [(4.0, -5.0), (-5.0, -4.0), (2.0, -3.0), (-3.0, -2.0)]
    );

    // |a * b|^2
    assert_first_event_gradient!(
        (&expr1 * &expr2).norm_sqr(),
        [(164.0, 0.0), (246.0, 0.0), (104.0, 0.0), (130.0, 0.0)]
    );
}
7227
/// Validates the analytic gradient of a sum of elementary functions (`sqrt`,
/// `exp`, `powi`, `powf`, `sin * cos`, `log`, `cis`, `pow`) against central
/// finite differences, one parameter at a time.
#[test]
fn test_expression_function_gradients() {
    let base = ComplexScalar::new(
        "function_parametric_1",
        parameter!("function_test_param_re_1"),
        parameter!("function_test_param_im_1"),
    )
    .unwrap();
    let exponent = ComplexScalar::new(
        "function_parametric_2",
        parameter!("function_test_param_re_2"),
        parameter!("function_test_param_im_2"),
    )
    .unwrap();

    let sin = base.sin();
    let cos = base.cos();
    let trig = &sin * &cos;
    let pow = base.pow(&exponent);

    // Accumulate the terms one at a time, in a fixed order.
    let expr = base.sqrt();
    let expr = &expr + &base.exp();
    let expr = &expr + &base.powi(2);
    let expr = &expr + &base.powf(1.7);
    let expr = &expr + &trig;
    let expr = &expr + &base.log();
    let expr = &expr + &base.cis();
    let expr = &expr + &pow;

    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![2.0, 0.5, 1.2, -0.3];
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");
    let step = 1e-6;

    for param_index in 0..params.len() {
        // First-event value with only `param_index` displaced by `shift`.
        let value_at = |shift: f64| {
            let mut shifted = params.clone();
            shifted[param_index] += shift;
            evaluator
                .evaluate(&shifted)
                .expect("evaluation should succeed")[0]
        };
        // Central difference: (f(x + h) - f(x - h)) / 2h.
        let finite_diff = (value_at(step) - value_at(-step)) / Complex64::new(2.0 * step, 0.0);
        let analytic = gradient[0][param_index];

        assert_relative_eq!(
            analytic.re,
            finite_diff.re,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
        assert_relative_eq!(
            analytic.im,
            finite_diff.im,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
    }
}
7291
/// Multiplying by `Expression::one()` and adding `Expression::zero()` must
/// leave both the value and the gradient of `|f|^2` unchanged, where
/// `f(x) = x + 2i` has one free real part and a fixed imaginary part.
#[test]
fn test_zeros_and_ones() {
    let amp = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();
    let dataset = Arc::new(test_dataset());
    let padded = amp * Expression::one() + Expression::zero();
    let expr = padded.norm_sqr();
    let evaluator = expr.load(&dataset).unwrap();

    // Only the real part is free; evaluate at x = 2.
    let params = vec![2.0];
    let value = evaluator
        .evaluate(&params)
        .expect("evaluation should succeed");
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    // |x + 2i|^2 = x^2 + 4, which is 8 at x = 2.
    assert_relative_eq!(value[0].re, 8.0);
    assert_relative_eq!(value[0].im, 0.0);

    // d/dx (x^2 + 4) = 2x, which is 4 at x = 2.
    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
}
/// Loads the norm-squared of a fully fixed `ComplexScalar` and asserts that the
/// runtime diagnostics report IR planning plus all three lowered programs, and
/// that evaluating with no free parameters yields `4 + 0i`.
#[test]
fn test_default_build_uses_lowered_expression_runtime() {
    let amp = ComplexScalar::new(
        "opt_in_gate",
        parameter!("opt_in_gate_re", 2.0),
        parameter!("opt_in_gate_im", 0.0),
    )
    .unwrap();
    let expr = amp.norm_sqr();
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();

    let report = evaluator.expression_runtime_diagnostics();
    assert!(report.ir_planning_enabled);
    assert!(report.lowered_value_program_present);
    assert!(report.lowered_gradient_program_present);
    assert!(report.lowered_value_gradient_program_present);

    // Both components are fixed, so the parameter slice is empty.
    let values = evaluator.evaluate(&[]).expect("evaluation should succeed");
    assert_eq!(values[0], Complex64::new(4.0, 0.0));
}
7342
/// `parameter!` with only a name: the result reports as free and carries no
/// fixed value, initial value, bounds, unit, LaTeX string, or description.
#[test]
fn parameter_name_only_creates_free_parameter() {
    let mass = parameter!("mass");

    assert_eq!(mass.name(), "mass");

    // No numeric state is attached.
    assert!(mass.fixed().is_none());
    assert!(mass.initial().is_none());
    let (lower, upper) = mass.bounds();
    assert!(lower.is_none());
    assert!(upper.is_none());

    // No metadata is attached.
    assert!(mass.unit().is_none());
    assert!(mass.latex().is_none());
    assert!(mass.description().is_none());

    assert!(mass.is_free());
    assert!(!mass.is_fixed());
}
7357
/// `parameter!` with a name and a positional value: the value is reported as both
/// the fixed and the initial value, and the parameter reports as fixed.
#[test]
fn parameter_name_and_value_creates_fixed_parameter() {
    let width = parameter!("width", 0.15);
    let expected = Some(0.15);

    assert_eq!(width.name(), "width");
    assert_eq!(width.fixed(), expected);
    assert_eq!(width.initial(), expected);
    assert!(width.is_fixed());
    assert!(!width.is_free());
}
7368
/// `initial:` keyword: only the initial value is set; the parameter stays free,
/// has no fixed value, and has no bounds.
#[test]
fn keyword_initial_sets_initial_only() {
    let alpha = parameter!("alpha", initial: 1.25);

    assert_eq!(alpha.name(), "alpha");
    assert!(alpha.fixed().is_none());
    assert_eq!(alpha.initial(), Some(1.25));

    let (lower, upper) = alpha.bounds();
    assert!(lower.is_none());
    assert!(upper.is_none());

    assert!(alpha.is_free());
}
7379
/// `fixed:` keyword: the given value is reported as both the fixed and the
/// initial value, and the parameter reports as fixed.
#[test]
fn keyword_fixed_sets_fixed_and_initial() {
    let beta = parameter!("beta", fixed: 2.5);
    let expected = Some(2.5);

    assert_eq!(beta.name(), "beta");
    assert_eq!(beta.fixed(), expected);
    assert_eq!(beta.initial(), expected);
    assert!(beta.is_fixed());
}
7389
/// `bounds:` given two plain numbers: both ends come back wrapped in `Some`.
#[test]
fn bounds_accept_plain_numbers() {
    let bounded = parameter!("x", bounds: (0.0, 10.0));

    let (lower, upper) = bounded.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));
}
7396
/// `bounds:` given `None` for the lower end and a number for the upper end.
#[test]
fn bounds_accept_none_and_number() {
    let bounded = parameter!("x", bounds: (None, 10.0));

    let (lower, upper) = bounded.bounds();
    assert!(lower.is_none());
    assert_eq!(upper, Some(10.0));
}
7403
/// `bounds:` given a number for the lower end and `None` for the upper end.
#[test]
fn bounds_accept_number_and_none() {
    let bounded = parameter!("x", bounds: (-1.0, None));

    let (lower, upper) = bounded.bounds();
    assert_eq!(lower, Some(-1.0));
    assert!(upper.is_none());
}
7410
/// `bounds:` given `None` for both ends: neither end is set.
#[test]
fn bounds_accept_both_none() {
    let unbounded = parameter!("x", bounds: (None, None));

    let (lower, upper) = unbounded.bounds();
    assert!(lower.is_none());
    assert!(upper.is_none());
}
7417
/// `bounds:` given non-literal expressions (a local minus a constant, and a local
/// holding a product): the evaluated results are stored.
#[test]
fn bounds_accept_arbitrary_expressions() {
    let base = 1.0;
    let ceiling = 2.0 * 3.0;
    let bounded = parameter!("x", bounds: (base - 0.5, ceiling));

    let (lower, upper) = bounded.bounds();
    assert_eq!(lower, Some(0.5));
    assert_eq!(upper, Some(6.0));
}
7426
/// Combines `initial`, `bounds`, `unit`, `latex`, and `description` in one
/// `parameter!` call and checks each accessor returns what was passed in.
#[test]
fn multiple_keyword_arguments_work_together() {
    let gamma = parameter!(
        "gamma",
        initial: 1.0,
        bounds: (0.0, 5.0),
        unit: "GeV",
        latex: r"\gamma",
        description: "test parameter",
    );

    // Numeric state.
    assert_eq!(gamma.name(), "gamma");
    assert!(gamma.fixed().is_none());
    assert_eq!(gamma.initial(), Some(1.0));
    let (lower, upper) = gamma.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(5.0));

    // Metadata strings.
    let unit = gamma.unit();
    let latex = gamma.latex();
    let description = gamma.description();
    assert_eq!(unit.as_deref(), Some("GeV"));
    assert_eq!(latex.as_deref(), Some(r"\gamma"));
    assert_eq!(description.as_deref(), Some("test parameter"));
}
7446
/// Combines `fixed` with `bounds` and `unit`: the fixed value also shows up as
/// the initial value, and the other fields are kept.
#[test]
fn fixed_can_be_combined_with_other_fields() {
    let delta = parameter!(
        "delta",
        fixed: 3.0,
        bounds: (0.0, 10.0),
        unit: "rad",
    );
    let pinned = Some(3.0);

    assert_eq!(delta.name(), "delta");
    assert_eq!(delta.fixed(), pinned);
    assert_eq!(delta.initial(), pinned);

    let (lower, upper) = delta.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));

    let unit = delta.unit();
    assert_eq!(unit.as_deref(), Some("rad"));
}
7462
/// A `parameter!` invocation whose keyword list ends with a trailing comma
/// still yields the requested initial value, bounds, and unit.
#[test]
fn trailing_comma_is_accepted() {
    let eps = parameter!(
        "eps",
        initial: 0.5,
        bounds: (None, 1.0),
        unit: "arb",
    );

    assert_eq!(eps.initial(), Some(0.5));

    let (lower, upper) = eps.bounds();
    assert!(lower.is_none());
    assert_eq!(upper, Some(1.0));

    let unit = eps.unit();
    assert_eq!(unit.as_deref(), Some("arb"));
}
7476
/// A `ComplexScalar` built from one free and one fixed parameter lists exactly
/// the free one, `test_param_re`, among its free parameter names.
#[test]
fn test_parameter_registration() {
    let amp = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();

    let free_names = amp.parameters().free().names();
    // One comparison covers both the count and the single entry.
    assert_eq!(free_names, vec!["test_param_re"]);
}
7489
/// Two amplitudes sharing the tag `same_name` can be summed, and the sum reports
/// the fixed parameters of both, in the order asserted below.
#[test]
fn test_duplicate_amplitude_tag_registration_is_allowed() {
    let first = ComplexScalar::new(
        "same_name",
        parameter!("dup_re1", 1.0),
        parameter!("dup_im1", 0.0),
    )
    .unwrap();
    let second = ComplexScalar::new(
        "same_name",
        parameter!("dup_re2", 2.0),
        parameter!("dup_im2", 0.0),
    )
    .unwrap();

    let sum = first + second;

    let fixed_names = sum.parameters().fixed().names();
    let expected = vec!["dup_re1", "dup_im1", "dup_re2", "dup_im2"];
    assert_eq!(fixed_names, expected);
}
7510
/// Renders a composite expression with `to_string()` and compares it against
/// the expected box-drawing tree, line by line.
#[test]
fn test_tree_printing() {
    let amp1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let amp2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();

    // Kept as one expression so operand evaluation order and the resulting tree shape are untouched.
    let expr =
        &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
            - Expression::zero() / 1.0
            + (&amp1 * &amp2).norm_sqr();

    // One entry per rendered line; each carries its own trailing newline.
    let expected_lines = [
        "+\n",
        "├─ -\n",
        "│  ├─ +\n",
        "│  │  ├─ +\n",
        "│  │  │  ├─ Re\n",
        "│  │  │  │  └─ parametric_1(id=0)\n",
        "│  │  │  └─ Im\n",
        "│  │  │     └─ *\n",
        "│  │  │        └─ parametric_2(id=1)\n",
        "│  │  └─ ×\n",
        "│  │     ├─ 1 (exact)\n",
        "│  │     └─ -1.4+2i\n",
        "│  └─ ÷\n",
        "│     ├─ 0 (exact)\n",
        "│     └─ 1 (exact)\n",
        "└─ NormSqr\n",
        "   └─ ×\n",
        "      ├─ parametric_1(id=0)\n",
        "      └─ parametric_2(id=1)\n",
    ];
    let expected: String = expected_lines.concat();

    let rendered = expr.to_string();
    assert_eq!(rendered, expected);
}
7554}