
// laddu_core/amplitudes.rs

use std::collections::HashMap;
use std::hash::Hash;
use std::ops::Index;
use std::time::Instant;
use std::{
    fmt::{Debug, Display},
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};

use auto_ops::*;
use dyn_clone::DynClone;
use nalgebra::DVector;
use num::complex::Complex64;

use parking_lot::{Mutex, RwLock};
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

static AMPLITUDE_INSTANCE_COUNTER: AtomicU64 = AtomicU64::new(0);

fn next_amplitude_id() -> u64 {
    AMPLITUDE_INSTANCE_COUNTER.fetch_add(1, Ordering::Relaxed)
}
#[allow(dead_code)]
mod ir;
#[allow(dead_code)]
mod lowered;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Dependence classification used by expression-IR diagnostics.
pub enum ExpressionDependence {
    /// Depends only on fixed/free parameter values.
    ParameterOnly,
    /// Depends only on event-local cached values.
    CacheOnly,
    /// Depends on both parameter values and cached event values.
    Mixed,
}
impl From<ir::DependenceClass> for ExpressionDependence {
    fn from(value: ir::DependenceClass) -> Self {
        match value {
            ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
            ir::DependenceClass::CacheOnly => Self::CacheOnly,
            ir::DependenceClass::Mixed => Self::Mixed,
        }
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
/// Explain/debug view of the IR normalization planning decomposition.
pub struct NormalizationPlanExplain {
    /// Dependence classification at the expression root.
    pub root_dependence: ExpressionDependence,
    /// Warning-level diagnostics collected during planning.
    pub warnings: Vec<String>,
    /// Candidate multiply node indices identified as separable.
    pub separable_mul_candidate_nodes: Vec<usize>,
    /// Candidate separable node indices selected for caching.
    pub cached_separable_nodes: Vec<usize>,
    /// Node indices planned for residual per-event evaluation.
    pub residual_terms: Vec<usize>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
/// Explain/debug view of amplitude execution sets used by normalization evaluation.
pub struct NormalizationExecutionSetsExplain {
    /// Amplitudes required to evaluate parameter factors for cached separable terms.
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitudes required to evaluate cache factors for cached separable terms.
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitudes required for residual (non-cached) normalization evaluation.
    pub residual_amplitudes: Vec<usize>,
}
#[derive(Clone, Debug, PartialEq)]
/// Load-time precomputed integral metadata for a separable cached term.
pub struct PrecomputedCachedIntegral {
    /// Node index of the separable multiplication term.
    pub mul_node_index: usize,
    /// Node index of the parameter-dependent factor.
    pub parameter_node_index: usize,
    /// Node index of the cache-dependent factor.
    pub cache_node_index: usize,
    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
    pub coefficient: i32,
    /// Weighted sum over local events of the cache-dependent factor.
    pub weighted_cache_sum: Complex64,
}
#[derive(Clone, Debug, PartialEq)]
/// Parameter-gradient contribution for a load-time precomputed cached integral term.
pub struct PrecomputedCachedIntegralGradientTerm {
    /// Node index of the separable multiplication term.
    pub mul_node_index: usize,
    /// Node index of the parameter-dependent factor.
    pub parameter_node_index: usize,
    /// Node index of the cache-dependent factor.
    pub cache_node_index: usize,
    /// Signed extraction coefficient induced by Add/Sub/Neg ancestry to the root.
    pub coefficient: i32,
    /// Gradient contribution `(d/dp parameter_factor) * weighted_cache_sum`.
    pub weighted_gradient: DVector<Complex64>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CachedIntegralCacheKey {
    active_mask: Vec<bool>,
    n_events_local: usize,
    events_local_len: usize,
    weighted_sum_bits: u64,
    events_ptr: usize,
}
#[derive(Clone, Debug)]
struct CachedIntegralCacheState {
    key: CachedIntegralCacheKey,
    expression_ir: ir::ExpressionIR,
    values: Vec<PrecomputedCachedIntegral>,
    execution_sets: ir::NormalizationExecutionSets,
}
#[derive(Clone, Debug)]
struct LoweredArtifactCacheState {
    parameter_node_indices: Vec<usize>,
    mul_node_indices: Vec<usize>,
    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
    lowered_runtime: lowered::LoweredExpressionRuntime,
}
#[derive(Clone)]
struct ExpressionSpecializationState {
    cached_integrals: Arc<CachedIntegralCacheState>,
    lowered_artifacts: Arc<LoweredArtifactCacheState>,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
/// Debug/benchmark counters for active-mask specialization reuse under `expression-ir`.
pub struct ExpressionSpecializationMetrics {
    /// Number of specialization cache hits served without recompilation.
    pub cache_hits: usize,
    /// Number of specialization cache misses that required a fresh compile/lower pass.
    pub cache_misses: usize,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
/// Staged compile/lowering metrics for expression-IR construction and specialization refreshes.
pub struct ExpressionCompileMetrics {
    /// Nanoseconds spent compiling the semantic expression tree into IR during initial load.
    pub initial_ir_compile_nanos: u64,
    /// Nanoseconds spent precomputing cached-integral planning artifacts during initial load.
    pub initial_cached_integrals_nanos: u64,
    /// Nanoseconds spent lowering IR-derived runtimes during initial load.
    pub initial_lowering_nanos: u64,
    /// Number of specialization cache hits restored without recompilation.
    pub specialization_cache_hits: usize,
    /// Number of specialization cache misses that required recompilation.
    pub specialization_cache_misses: usize,
    /// Accumulated nanoseconds spent compiling active-mask-specialized IR after initial load.
    pub specialization_ir_compile_nanos: u64,
    /// Accumulated nanoseconds spent recomputing cached-integral planning artifacts after load.
    pub specialization_cached_integrals_nanos: u64,
    /// Accumulated nanoseconds spent lowering specialized runtimes after load.
    pub specialization_lowering_nanos: u64,
    /// Number of specialization rebuilds that reused cached lowered artifacts.
    pub specialization_lowering_cache_hits: usize,
    /// Number of specialization rebuilds that had to lower fresh artifacts.
    pub specialization_lowering_cache_misses: usize,
    /// Accumulated nanoseconds spent restoring specializations from cache.
    pub specialization_cache_restore_nanos: u64,
}
impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
    fn from(value: ir::NormalizationPlanExplain) -> Self {
        Self {
            root_dependence: value.root_dependence.into(),
            warnings: value.warnings,
            separable_mul_candidate_nodes: value
                .separable_mul_candidates
                .into_iter()
                .map(|candidate| candidate.node_index)
                .collect(),
            cached_separable_nodes: value.cached_separable_nodes,
            residual_terms: value.residual_terms,
        }
    }
}
impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
    fn from(value: ir::NormalizationExecutionSets) -> Self {
        Self {
            cached_parameter_amplitudes: value.cached_parameter_amplitudes,
            cached_cache_amplitudes: value.cached_cache_amplitudes,
            residual_amplitudes: value.residual_amplitudes,
        }
    }
}
impl From<ExpressionDependence> for ir::DependenceClass {
    fn from(value: ExpressionDependence) -> Self {
        match value {
            ExpressionDependence::ParameterOnly => Self::ParameterOnly,
            ExpressionDependence::CacheOnly => Self::CacheOnly,
            ExpressionDependence::Mixed => Self::Mixed,
        }
    }
}

#[cfg(feature = "execution-context-prototype")]
use crate::ExecutionContext;
#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
use crate::ThreadPolicy;
use crate::{
    data::{Dataset, DatasetMetadata, NamedEventView},
    resources::{Cache, Parameters, Resources},
    LadduError, LadduResult, ParameterID, ReadWrite,
};

#[cfg(feature = "mpi")]
use crate::mpi::LadduMPI;

#[cfg(feature = "mpi")]
use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};

#[derive(Clone, Default, Serialize, Deserialize, Debug)]
struct ParameterMetadata {
    /// The name of the parameter.
    name: String,
    /// If `Some`, this parameter is fixed to the given value. If `None`, it is free.
    fixed: Option<f64>,
    /// If `Some`, this is used for the initial value of the parameter in fits. If `None`, the user
    /// must provide the initial value on their own.
    initial: Option<f64>,
    /// Optional bounds which may be automatically used by optimizers. `None` represents no bound
    /// in the given direction.
    bounds: (Option<f64>, Option<f64>),
    /// An optional unit string which may be used to annotate the parameter.
    unit: Option<String>,
    /// Optional LaTeX representation of the parameter.
    latex: Option<String>,
    /// Optional description of the parameter.
    description: Option<String>,
}

/// A named fit parameter with an optional fixed value, initial value, bounds, and descriptive
/// metadata (unit, LaTeX label, description).
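///
/// # Example
///
/// A minimal usage sketch (marked `ignore` since the `laddu_core` import path shown is an
/// assumption):
///
/// ```ignore
/// use laddu_core::amplitudes::Parameter;
///
/// // A free parameter with an initial guess and bounds:
/// let mass = Parameter::new("mass");
/// mass.set_initial(1.2);
/// mass.set_bounds(0.5, 2.5);
/// assert!(mass.is_free());
///
/// // A parameter fixed to a constant value:
/// let width = Parameter::new_fixed("width", 0.1);
/// assert!(width.is_fixed());
/// assert_eq!(width.fixed(), Some(0.1));
/// ```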
#[derive(Clone, Default, Serialize, Deserialize, Debug)]
pub struct Parameter(Arc<Mutex<ParameterMetadata>>);

// NOTE: hash and equality only depend on name
impl PartialEq for Parameter {
    fn eq(&self, other: &Self) -> bool {
        self.0.lock().name == other.0.lock().name
    }
}
impl Eq for Parameter {}
impl Hash for Parameter {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.lock().name.hash(state);
    }
}

/// Helper trait to convert values to bounds-like [`Option<f64>`]
pub trait IntoBound {
    /// Convert to a bound
    fn into_bound(self) -> Option<f64>;
}
impl IntoBound for f64 {
    fn into_bound(self) -> Option<f64> {
        Some(self)
    }
}
impl IntoBound for Option<f64> {
    fn into_bound(self) -> Option<f64> {
        self
    }
}

impl Parameter {
    /// Create a free (floating) parameter with the given name.
    pub fn new(name: impl Into<String>) -> Self {
        Self(Arc::new(Mutex::new(ParameterMetadata {
            name: name.into(),
            ..Default::default()
        })))
    }

    /// Create a fixed parameter with the given name and value.
    pub fn new_fixed(name: impl Into<String>, value: f64) -> Self {
        Self(Arc::new(Mutex::new(ParameterMetadata {
            name: name.into(),
            fixed: Some(value),
            ..Default::default()
        })))
    }

    /// Return the current parameter name.
    pub fn name(&self) -> String {
        self.0.lock().name.clone()
    }

    /// Return the fixed value when the parameter is fixed.
    pub fn fixed(&self) -> Option<f64> {
        self.0.lock().fixed
    }

    /// Return the current initial value, if one is set.
    pub fn initial(&self) -> Option<f64> {
        self.0.lock().initial
    }

    /// Return the current lower and upper bounds.
    pub fn bounds(&self) -> (Option<f64>, Option<f64>) {
        self.0.lock().bounds
    }

    /// Return the optional unit label.
    pub fn unit(&self) -> Option<String> {
        self.0.lock().unit.clone()
    }

    /// Return the optional LaTeX label.
    pub fn latex(&self) -> Option<String> {
        self.0.lock().latex.clone()
    }

    /// Return the optional human-readable description.
    pub fn description(&self) -> Option<String> {
        self.0.lock().description.clone()
    }

    /// Helper method to set the name of a parameter.
    fn set_name(&self, name: impl Into<String>) {
        self.0.lock().name = name.into();
    }

    /// Helper method to set the fixed value of a parameter.
    pub fn set_fixed_value(&self, value: Option<f64>) {
        {
            let mut guard = self.0.lock();
            if let Some(value) = value {
                guard.fixed = Some(value);
                guard.initial = Some(value);
            } else {
                guard.fixed = None;
                // NOTE: freeing keeps the initial value as the previous fixed value
            }
        }
    }

    /// Helper method to set the initial value of a parameter.
    ///
    /// # Panics
    ///
    /// This method panics if the parameter is fixed.
    pub fn set_initial(&self, value: f64) {
        assert!(
            self.is_free(),
            "cannot manually set `initial` on a fixed parameter"
        );
        self.0.lock().initial = Some(value);
    }

    /// Helper method to set the bounds of a parameter.
    pub fn set_bounds<L, U>(&self, min: L, max: U)
    where
        L: IntoBound,
        U: IntoBound,
    {
        self.0.lock().bounds = (IntoBound::into_bound(min), IntoBound::into_bound(max));
    }

    /// Helper method to set the unit of a parameter.
    pub fn set_unit(&self, unit: impl Into<String>) {
        self.0.lock().unit = Some(unit.into());
    }

    /// Helper method to set the LaTeX representation of a parameter.
    pub fn set_latex(&self, latex: impl Into<String>) {
        self.0.lock().latex = Some(latex.into());
    }

    /// Helper method to set the description of a parameter.
    pub fn set_description(&self, description: impl Into<String>) {
        self.0.lock().description = Some(description.into());
    }

    /// Is this parameter free?
    pub fn is_free(&self) -> bool {
        self.0.lock().fixed.is_none()
    }

    /// Is this parameter fixed?
    pub fn is_fixed(&self) -> bool {
        self.0.lock().fixed.is_some()
    }
}

/// Convenience macro for creating parameters. Usage:
/// `parameter!("name")` for a free parameter, or `parameter!("name", 1.0)` for a fixed one.
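///
/// Keyword arguments (`fixed`, `initial`, `bounds`, `unit`, `latex`, `description`) are also
/// accepted, as the rules below show. An illustrative sketch (marked `ignore` because the
/// `laddu_core` crate path is an assumption here):
///
/// ```ignore
/// use laddu_core::parameter;
///
/// let width = parameter!("width", 1.0); // fixed at 1.0
/// let mass = parameter!("mass", initial: 1.2, bounds: (0.5, 2.5), unit: "GeV");
/// assert!(mass.is_free());
/// ```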
#[macro_export]
macro_rules! parameter {
    ($name:expr) => {{
        $crate::amplitudes::Parameter::new($name)
    }};

    ($name:expr, $value:expr) => {{
        let p = $crate::amplitudes::Parameter::new($name);
        p.set_fixed_value(Some($value));
        p
    }};

    ($name:expr, $($rest:tt)+) => {{
        let p = $crate::amplitudes::Parameter::new($name);
        $crate::parameter!(@parse p, [fixed = false, initial = false]; $($rest)+);
        p
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; ) => {};

    (@parse $p:ident, [fixed = false, initial = false]; fixed : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_fixed_value(Some($value));
        $crate::parameter!(@parse $p, [fixed = true, initial = false]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = false, initial = false]; initial : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_initial($value);
        $crate::parameter!(@parse $p, [fixed = false, initial = true]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = true, initial = false]; initial : $value:expr $(, $($rest:tt)*)?) => {
        compile_error!("parameter!: cannot specify both `fixed` and `initial`");
    };

    (@parse $p:ident, [fixed = false, initial = true]; fixed : $value:expr $(, $($rest:tt)*)?) => {
        compile_error!("parameter!: cannot specify both `fixed` and `initial`");
    };

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; bounds : ($min:expr, $max:expr) $(, $($rest:tt)*)?) => {{
        $p.set_bounds($min, $max);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; unit : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_unit($value);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; latex : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_latex($value);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; description : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_description($value);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};
}

/// An ordered set of [`Parameter`]s.
#[derive(Default, Debug, Clone)]
pub struct ParameterMap {
    parameters: Vec<Parameter>,
    name_to_index: HashMap<String, usize>,
}

#[derive(Serialize, Deserialize)]
struct ParameterMapSerde {
    parameters: Vec<Parameter>,
}

impl Index<usize> for ParameterMap {
    type Output = Parameter;

    fn index(&self, index: usize) -> &Self::Output {
        &self.parameters[index]
    }
}

impl Index<&str> for ParameterMap {
    type Output = Parameter;

    fn index(&self, key: &str) -> &Self::Output {
        self.get(key)
            .unwrap_or_else(|| panic!("parameter '{key}' not found"))
    }
}

impl IntoIterator for ParameterMap {
    type Item = Parameter;

    type IntoIter = std::vec::IntoIter<Parameter>;

    fn into_iter(self) -> Self::IntoIter {
        self.parameters.into_iter()
    }
}

impl<'a> IntoIterator for &'a ParameterMap {
    type Item = &'a Parameter;

    type IntoIter = std::slice::Iter<'a, Parameter>;

    fn into_iter(self) -> Self::IntoIter {
        self.parameters.iter()
    }
}

impl Serialize for ParameterMap {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        ParameterMapSerde {
            parameters: self.parameters.clone(),
        }
        .serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for ParameterMap {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let serde = ParameterMapSerde::deserialize(deserializer)?;
        Ok(Self::from_parameters(serde.parameters))
    }
}

impl ParameterMap {
    fn from_parameters(parameters: Vec<Parameter>) -> Self {
        let name_to_index = parameters
            .iter()
            .enumerate()
            .map(|(index, parameter)| (parameter.name(), index))
            .collect();
        Self {
            parameters,
            name_to_index,
        }
    }

    /// Register a parameter into the ordered map and return its assembled [`ParameterID`].
    pub fn register_parameter(&mut self, p: &Parameter) -> LadduResult<ParameterID> {
        let name = p.name();
        if name.is_empty() {
            return Err(LadduError::UnregisteredParameter {
                name: "<unnamed>".to_string(),
                reason: "Parameter was not initialized with a name".to_string(),
            });
        }

        if let Some((index, existing)) = self.get_indexed(&name) {
            match (existing.fixed(), p.fixed()) {
                (Some(a), Some(b)) if (a - b).abs() > f64::EPSILON => {
                    return Err(LadduError::ParameterConflict {
                        name,
                        reason: "conflicting fixed values for the same parameter name".to_string(),
                    });
                }
                (Some(_), None) => {
                    return Err(LadduError::ParameterConflict {
                        name,
                        reason: "attempted to use a fixed parameter name as free".to_string(),
                    });
                }
                (None, Some(_)) => {
                    return Err(LadduError::ParameterConflict {
                        name,
                        reason: "attempted to use a free parameter name as fixed".to_string(),
                    });
                }
                (Some(_), Some(_)) | (None, None) => return Ok(self.parameter_id(index)),
            }
        }

        let index = self.parameters.len();
        self.insert(p.clone());
        Ok(self.parameter_id(index))
    }
    /// Return the assembled indices of all free parameters.
    pub fn free_parameter_indices(&self) -> Vec<usize> {
        (0..self.free().len()).collect()
    }
    /// Rename a single parameter in place.
    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
        if old == new {
            return Ok(());
        }
        if self.contains_key(new) {
            return Err(LadduError::ParameterConflict {
                name: new.to_string(),
                reason: "rename target already exists".to_string(),
            });
        }
        if let Some(index) = self.index(old) {
            let parameter = self.parameters[index].clone();
            parameter.set_name(new);
            self.name_to_index.remove(old);
            self.name_to_index.insert(new.to_string(), index);
        } else {
            self.assert_parameter_exists(old)?;
        }
        Ok(())
    }
    /// Rename multiple parameters in place.
    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
        for (old, new) in mapping {
            self.rename_parameter(old, new)?;
        }
        Ok(())
    }
    /// Fix a parameter to the supplied value.
    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
        self.assert_parameter_exists(name)?;
        if let Some(parameter) = self.get(name) {
            parameter.set_fixed_value(Some(value));
        }
        Ok(())
    }
    /// Mark a parameter as free.
    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
        self.assert_parameter_exists(name)?;
        if let Some(parameter) = self.get(name) {
            parameter.set_fixed_value(None);
        }
        Ok(())
    }
    /// Return whether a parameter with the given name exists.
    pub fn contains_key(&self, name: &str) -> bool {
        self.name_to_index.contains_key(name)
    }
    /// Return the storage index for a named parameter.
    pub fn index(&self, name: &str) -> Option<usize> {
        self.name_to_index.get(name).copied()
    }
    /// Insert or replace a parameter by name while preserving insertion order.
    pub fn insert(&mut self, parameter: Parameter) -> Option<Parameter> {
        let name = parameter.name();
        if let Some(index) = self.index(&name) {
            Some(std::mem::replace(&mut self.parameters[index], parameter))
        } else {
            let index = self.parameters.len();
            self.parameters.push(parameter);
            self.name_to_index.insert(name, index);
            None
        }
    }
    /// The number of parameters in the set
    pub fn len(&self) -> usize {
        self.parameters.len()
    }
    /// Returns true if the parameter set has no elements
    pub fn is_empty(&self) -> bool {
        self.parameters.is_empty()
    }
    /// Iterate over all parameters in the set
    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
        self.parameters.iter()
    }
    /// Get a parameter by name
    pub fn get(&self, key: &str) -> Option<&Parameter> {
        self.index(key).map(|index| &self.parameters[index])
    }
    /// Get both the storage index and parameter for a given name.
    pub fn get_indexed(&self, key: &str) -> Option<(usize, &Parameter)> {
        self.index(key)
            .map(|index| (index, &self.parameters[index]))
    }
    /// Get all parameter names in order
    pub fn names(&self) -> Vec<String> {
        self.parameters.iter().map(Parameter::name).collect()
    }
    /// Filter the parameter set by a predicate
    pub fn filter(&self, predicate: impl Fn(&Parameter) -> bool) -> Self {
        Self::from_parameters(
            self.parameters
                .iter()
                .filter(|parameter| predicate(parameter))
                .cloned()
                .collect(),
        )
    }
    /// Get a set containing only free parameters
    pub fn free(&self) -> Self {
        self.filter(|p| p.is_free())
    }
    /// Get a set containing only fixed parameters
    pub fn fixed(&self) -> Self {
        self.filter(|p| p.is_fixed())
    }
    /// Get a set containing only initialized parameters
    pub fn initialized(&self) -> Self {
        self.filter(|p| p.initial().is_some())
    }
    /// Get a set containing only uninitialized parameters
    pub fn uninitialized(&self) -> Self {
        self.filter(|p| p.initial().is_none())
    }

    /// Assemble free inputs into a full [`Parameters`] object.
    ///
    /// The resulting values are ordered with all free parameters first, followed by fixed ones.
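    ///
    /// # Example
    ///
    /// A sketch of the ordering (marked `ignore`; the `laddu_core` paths are assumptions):
    ///
    /// ```ignore
    /// use laddu_core::{amplitudes::ParameterMap, parameter};
    ///
    /// let mut map = ParameterMap::default();
    /// map.register_parameter(&parameter!("a"))?;       // free
    /// map.register_parameter(&parameter!("b", 2.0))?;  // fixed at 2.0
    /// map.register_parameter(&parameter!("c"))?;       // free
    ///
    /// // Free values are supplied in storage order ("a", then "c");
    /// // the assembled values are ordered [a, c, b] = [0.1, 0.3, 2.0].
    /// let assembled = map.assemble(&[0.1, 0.3])?;
    /// ```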
    pub fn assemble(&self, free_values: &[f64]) -> LadduResult<Parameters> {
        let expected_free = self.free().len();
        let n_fixed = self.fixed().len();
        let mut values = vec![0.0; expected_free + n_fixed];
        let mut storage_to_assembled = vec![0; self.len()];
        let mut free_iter = free_values.iter();
        let mut free_index = 0;
        let mut fixed_index = expected_free;
        for (storage_index, parameter) in self.parameters.iter().enumerate() {
            if let Some(value) = parameter.fixed() {
                values[fixed_index] = value;
                storage_to_assembled[storage_index] = fixed_index;
                fixed_index += 1;
            } else if let Some(value) = free_iter.next() {
                values[free_index] = *value;
                storage_to_assembled[storage_index] = free_index;
                free_index += 1;
            } else {
                return Err(LadduError::LengthMismatch {
                    context: "parameter values".to_string(),
                    expected: expected_free,
                    actual: free_values.len(),
                });
            }
        }
        if free_iter.next().is_some() {
            return Err(LadduError::LengthMismatch {
                context: "parameter values".to_string(),
                expected: expected_free,
                actual: free_values.len(),
            });
        }
        Ok(Parameters::new(values, expected_free, storage_to_assembled))
    }

    /// Merge this parameter set with `other`, returning the merged set along with maps from each
    /// input's parameter positions to assembled indices in the merged set.
    ///
    /// # Notes
    /// When parameters overlap, the state and value stored in `self` always take precedence over
    /// entries from `other`.
    pub fn merge(&self, other: &Self) -> (Self, Vec<usize>, Vec<usize>) {
        let mut merged = self.clone();
        let mut right_map = Vec::with_capacity(other.len());
        for parameter in other {
            let idx = merged.ensure_parameter(parameter.clone());
            right_map.push(idx);
        }
        let left_map: Vec<usize> = (0..self.len())
            .map(|index| merged.assembled_index(index))
            .collect();
        let right_map = right_map
            .into_iter()
            .map(|index| merged.assembled_index(index))
            .collect();
        (merged, left_map, right_map)
    }

    /// Extend this parameter set with the parameters from `other`, returning the merged set and
    /// the assembled indices of `other`'s parameters within it.
    ///
    /// # Notes
    /// When both managers reference the same parameter, the value and fixed/free status from
    /// `self` are retained.
    pub fn extend_from(&self, other: &Self) -> (Self, Vec<usize>) {
        let mut merged = self.clone();
        let mut indices = Vec::with_capacity(other.len());
        for parameter in other {
            let idx = merged.ensure_parameter(parameter.clone());
            indices.push(idx);
        }
        let indices = indices
            .into_iter()
            .map(|index| merged.assembled_index(index))
            .collect();
        (merged, indices)
    }

    fn ensure_parameter(&mut self, parameter: Parameter) -> usize {
        let name = parameter.name();
        if let Some(idx) = self.index(&name) {
            return idx;
        }
        let idx = self.len();
        self.insert(parameter);
        idx
    }

    fn assembled_index(&self, storage_index: usize) -> usize {
        let n_free = self
            .parameters
            .iter()
            .filter(|parameter| parameter.is_free())
            .count();
        let preceding_in_group = self.parameters[..storage_index]
            .iter()
            .filter(|parameter| self.parameters[storage_index].is_free() == parameter.is_free())
            .count();
        if self.parameters[storage_index].is_free() {
            preceding_in_group
        } else {
            n_free + preceding_in_group
        }
    }

    fn parameter_id(&self, storage_index: usize) -> ParameterID {
        if self.parameters[storage_index].is_fixed() {
            ParameterID::Constant(storage_index)
        } else {
            ParameterID::Parameter(storage_index)
        }
    }

    fn assert_parameter_exists(&self, name: &str) -> LadduResult<()> {
        if self.contains_key(name) {
            Ok(())
        } else {
            Err(LadduError::UnregisteredParameter {
                name: name.to_string(),
                reason: "parameter not found".to_string(),
            })
        }
    }
}

/// This is the only required trait for writing new amplitude-like structures for this
/// crate. Users need only implement the [`register`](Amplitude::register)
/// method to register parameters, cached values, and the amplitude itself with an input
/// [`Resources`] struct and the [`compute`](Amplitude::compute) method to actually carry
/// out the calculation. [`Amplitude`]-implementors are required to implement [`Clone`] and can
/// optionally implement a [`precompute`](Amplitude::precompute) method to calculate and
/// cache values which do not depend on free parameters.
#[typetag::serde(tag = "type")]
pub trait Amplitude: DynClone + Send + Sync {
    /// This method should be used to tell the [`Resources`] manager about all of
    /// the free parameters and cached values used by this [`Amplitude`]. It should end by
    /// returning an [`AmplitudeID`], which can be obtained from the
    /// [`Resources::register_amplitude`] method.
    ///
    /// [`register`](Amplitude::register) is invoked once when an amplitude is first converted into
    /// an [`Expression`]. Use it to allocate parameter/cache state within [`Resources`] without assuming
    /// any dataset context.
    fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID>;
    /// Optional semantic identity key for same-name deduplication.
    ///
    /// Return `Some` only when two independently constructed instances with equal keys are
    /// interchangeable after registration, binding, precomputation, value evaluation, and gradient
    /// evaluation. The key should include the concrete amplitude type and all user-facing
    /// configuration, but must ignore registration-assigned IDs like [`ParameterID`]s and cache IDs.
    fn semantic_key(&self) -> Option<AmplitudeSemanticKey> {
        None
    }
    /// Bind this [`Amplitude`] to a concrete [`Dataset`] by using the provided metadata to wire up
    /// [`Variable`](crate::utils::variables::Variable)s or other dataset-specific state. This will
    /// be invoked when an [`Expression`] is loaded with data, after [`register`](Amplitude::register)
    /// has already succeeded. The default implementation is a no-op for amplitudes that do not
    /// depend on metadata.
    fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
        Ok(())
    }
    /// Optional dependence hint used by expression-IR diagnostics/planning.
    ///
    /// The default returns [`ExpressionDependence::Mixed`] for backward compatibility.
    fn dependence_hint(&self) -> ExpressionDependence {
        ExpressionDependence::Mixed
    }
    /// Optional hint that this amplitude always evaluates to a purely real complex value.
    ///
    /// This must be conservative. Returning `true` allows `expression-ir` to erase
    /// redundant `imag`, `real`, and `conj` work under the assumption that the
    /// amplitude output always has zero imaginary component.
    fn real_valued_hint(&self) -> bool {
        false
    }
    /// This method can be used to do some critical calculations ahead of time and
    /// store them in a [`Cache`]. These values can only depend on event data,
    /// not on any free parameters in the fit. This method is opt-in since it is
    /// not required to make a functioning [`Amplitude`].
    #[allow(unused_variables)]
    fn precompute(&self, event: &NamedEventView<'_>, cache: &mut Cache) {}

    /// Evaluate [`Amplitude::precompute`] over columnar event views in a [`Dataset`].
    #[cfg(feature = "rayon")]
    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
        resources
            .caches
            .par_iter_mut()
            .enumerate()
            .for_each(|(event_index, cache)| {
                let event = dataset.event_view(event_index);
                self.precompute(&event, cache);
            });
    }

    /// Evaluate [`Amplitude::precompute`] over columnar event views in a [`Dataset`].
    #[cfg(not(feature = "rayon"))]
    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
        dataset.for_each_named_event_local(|event_index, event| {
            let cache = &mut resources.caches[event_index];
            self.precompute(&event, cache);
        });
    }
    /// This method constitutes the main machinery of an [`Amplitude`], returning the actual
    /// calculated value for a particular set of [`Parameters`] and event [`Cache`].
    fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64;

    /// This method yields the gradient of a particular [`Amplitude`] at a point specified
    /// by a set of [`Parameters`]. See those structs, as well as
    /// [`Cache`], for documentation on their available methods. For the most part,
    /// [`Parameters`] and the [`Cache`] are key-value storage accessed by [`ParameterID`]s and
    /// several different types of cache
    /// IDs. If the analytic version of the gradient is known, this method can be overridden to
    /// improve performance for some derivative-using methods of minimization. The default
    /// implementation calculates a central finite difference across all parameters, regardless of
    /// whether or not they are used in the [`Amplitude`].
    ///
    /// In the future, it may be possible to automatically implement this with the indices of
    /// registered free parameters, but until then, the [`Amplitude::central_difference_with_indices`]
    /// method can be used to conveniently only calculate central differences for the parameters
    /// which are used by the [`Amplitude`].
    fn compute_gradient(
        &self,
        parameters: &Parameters,
        cache: &Cache,
        gradient: &mut DVector<Complex64>,
    ) {
        self.central_difference_with_indices(
            &Vec::from_iter(0..parameters.len()),
            parameters,
            cache,
            gradient,
        )
    }

    /// A helper function to implement a central difference only on indices which correspond to
    /// free parameters in the [`Amplitude`]. For example, if an [`Amplitude`] contains free
    /// parameters registered to indices 1, 3, and 5 of its internal parameters array, then
    /// running this with those indices will compute a central finite difference derivative for
    /// those coordinates only, since the rest can be safely assumed to be zero.
    fn central_difference_with_indices(
        &self,
        indices: &[usize],
        parameters: &Parameters,
        cache: &Cache,
        gradient: &mut DVector<Complex64>,
    ) {
        let x = parameters.values().to_owned();
        let h: DVector<f64> = x
            .iter()
            .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
            .collect::<Vec<_>>()
            .into();
        for i in indices {
            let mut x_plus = x.clone();
            let mut x_minus = x.clone();
            x_plus[*i] += h[*i];
            x_minus[*i] -= h[*i];
            let f_plus = self.compute(&parameters.with_values(x_plus), cache);
            let f_minus = self.compute(&parameters.with_values(x_minus), cache);
            gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
        }
    }

    /// Convenience helper to wrap an amplitude into an [`Expression`].
    ///
    /// This allows amplitude constructors to return `LadduResult<Expression>` without duplicating
    /// boxing/registration boilerplate.
    fn into_expression(self) -> LadduResult<Expression>
    where
        Self: Sized + 'static,
    {
        Expression::from_amplitude(Box::new(self))
    }
}

/// Utility function to calculate a central finite difference gradient.
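///
/// # Example
///
/// A small sketch (marked `ignore`; the `laddu_core` path is an assumption):
///
/// ```ignore
/// use laddu_core::amplitudes::central_difference;
///
/// // d/dx0 (x0^2 + 3*x1) = 2*x0 and d/dx1 = 3, so at [1.0, 2.0] the gradient is ~[2.0, 3.0].
/// let gradient = central_difference(&[1.0, 2.0], |x| x[0] * x[0] + 3.0 * x[1]);
/// assert!((gradient[0] - 2.0).abs() < 1e-6);
/// assert!((gradient[1] - 3.0).abs() < 1e-6);
/// ```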
pub fn central_difference<F: Fn(&[f64]) -> f64>(parameters: &[f64], func: F) -> DVector<f64> {
    let mut gradient = DVector::zeros(parameters.len());
    let x = parameters.to_owned();
    let h: DVector<f64> = x
        .iter()
        .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
        .collect::<Vec<_>>()
        .into();
    for i in 0..parameters.len() {
        let mut x_plus = x.clone();
        let mut x_minus = x.clone();
        x_plus[i] += h[i];
        x_minus[i] -= h[i];
        let f_plus = func(&x_plus);
        let f_minus = func(&x_minus);
        gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
    }
    gradient
}

dyn_clone::clone_trait_object!(Amplitude);

/// A helper struct that contains the value of each amplitude for a particular event
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex64>);

/// A helper struct that contains the gradient of each amplitude for a particular event
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex64>>);

/// A tag which refers to a registered [`Amplitude`]. This is the base object which can be used to
/// build [`Expression`]s and should be obtained from the [`Resources::register_amplitude`] method.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);

impl Display for AmplitudeID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}(id={})", self.0, self.1)
    }
}

/// A single named field in an [`AmplitudeSemanticKey`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AmplitudeSemanticField {
    name: String,
    value: String,
}

impl AmplitudeSemanticField {
    /// Construct a semantic key field.
    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            value: value.into(),
        }
    }
}

/// A semantic identity key used to opt into deduplicating same-name amplitude instances.
///
/// The key must include enough type/configuration information to prove that two independently
/// constructed amplitudes with the same public name can safely share one registered amplitude.
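///
/// # Example
///
/// A sketch of key construction from within an [`Amplitude::semantic_key`] implementation (the
/// kind and field names are illustrative, not taken from a real amplitude):
///
/// ```ignore
/// use laddu_core::amplitudes::AmplitudeSemanticKey;
///
/// // Two instances configured identically produce equal keys and may be deduplicated.
/// let key = AmplitudeSemanticKey::new("MyBreitWigner")
///     .with_field("mass_parameter", "m")
///     .with_field("width_parameter", "w")
///     .with_field("daughters", "[0, 1]");
/// ```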
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AmplitudeSemanticKey {
    kind: String,
    fields: Vec<AmplitudeSemanticField>,
}

impl AmplitudeSemanticKey {
    /// Construct a semantic key for the given amplitude kind.
    pub fn new(kind: impl Into<String>) -> Self {
        Self {
            kind: kind.into(),
            fields: Vec::new(),
        }
    }

    /// Add a named field to this semantic key.
    pub fn with_field(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.fields.push(AmplitudeSemanticField::new(name, value));
        self
    }

    fn field_value(&self, name: &str) -> Option<&str> {
        self.fields
            .iter()
            .find(|field| field.name == name)
            .map(|field| field.value.as_str())
    }

    fn mismatch_summary(&self, other: &Self) -> String {
        let mut mismatches = Vec::new();
        if self.kind != other.kind {
            mismatches.push(format!(
                "kind differs (existing: {:?}, incoming: {:?})",
                self.kind, other.kind
            ));
        }
        for field in &self.fields {
            match other.field_value(&field.name) {
                Some(value) if value == field.value => {}
                Some(value) => mismatches.push(format!(
                    "{} differs (existing: {}, incoming: {})",
                    field.name, field.value, value
                )),
                None => mismatches.push(format!(
                    "{} is missing from the incoming key (existing: {})",
                    field.name, field.value
                )),
            }
        }
        for field in &other.fields {
            if self.field_value(&field.name).is_none() {
                mismatches.push(format!(
                    "{} is missing from the existing key (incoming: {})",
                    field.name, field.value
                ));
            }
        }
        if mismatches.is_empty() {
            "semantic keys differ".to_string()
        } else {
            mismatches.join("; ")
        }
    }
}

/// A holder struct that owns both an expression tree and the registered amplitudes.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Expression {
    registry: ExpressionRegistry,
    tree: ExpressionNode,
}

impl ReadWrite for Expression {
    fn create_null() -> Self {
        Self {
            registry: ExpressionRegistry::default(),
            tree: ExpressionNode::default(),
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub struct ExpressionRegistry {
    amplitudes: Vec<Box<dyn Amplitude>>,
    amplitude_names: Vec<String>,
    amplitude_ids: Vec<u64>,
    resources: Resources,
}

impl ExpressionRegistry {
    fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
        let mut resources = Resources::default();
        let aid = amplitude.register(&mut resources)?;
        let amp_id = next_amplitude_id();
        Ok(Self {
            amplitudes: vec![amplitude],
            amplitude_names: vec![aid.0],
            amplitude_ids: vec![amp_id],
            resources,
        })
    }

    fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
        let mut resources = Resources::default();
        let mut amplitudes = Vec::new();
        let mut amplitude_names = Vec::new();
        let mut amplitude_ids = Vec::new();
        let mut amplitude_semantic_keys = Vec::new();
        let mut name_to_index = HashMap::new();

        let mut left_map = Vec::with_capacity(self.amplitudes.len());
        for ((amp, name), amp_id) in self
            .amplitudes
            .iter()
            .zip(&self.amplitude_names)
            .zip(&self.amplitude_ids)
        {
            let semantic_key = amp.semantic_key();
            let mut cloned_amp = dyn_clone::clone_box(&**amp);
            let aid = cloned_amp.register(&mut resources)?;
            amplitudes.push(cloned_amp);
            amplitude_names.push(name.clone());
            amplitude_ids.push(*amp_id);
            amplitude_semantic_keys.push(semantic_key);
            name_to_index.insert(name.clone(), aid.1);
            left_map.push(aid.1);
        }

        let mut right_map = Vec::with_capacity(other.amplitudes.len());
        for ((amp, name), amp_id) in other
            .amplitudes
            .iter()
            .zip(&other.amplitude_names)
            .zip(&other.amplitude_ids)
        {
            if let Some(existing) = name_to_index.get(name) {
                let existing_amp_id = amplitude_ids[*existing];
                let incoming_semantic_key = amp.semantic_key();
                if existing_amp_id != *amp_id {
                    match (&amplitude_semantic_keys[*existing], &incoming_semantic_key) {
                        (Some(existing_key), Some(incoming_key))
                            if existing_key == incoming_key => {}
                        (Some(existing_key), Some(incoming_key)) => {
                            return Err(LadduError::Custom(format!(
                                "Amplitude name \"{name}\" refers to different underlying amplitudes; semantic keys differ: {}",
                                existing_key.mismatch_summary(incoming_key)
                            )));
                        }
                        _ => {
                            return Err(LadduError::Custom(format!(
                                "Amplitude name \"{name}\" refers to different underlying amplitudes; rename to avoid conflicts"
                            )));
                        }
                    }
                }
                right_map.push(*existing);
                continue;
            }
            let semantic_key = amp.semantic_key();
            let mut cloned_amp = dyn_clone::clone_box(&**amp);
            let aid = cloned_amp.register(&mut resources)?;
            amplitudes.push(cloned_amp);
            amplitude_names.push(name.clone());
            amplitude_ids.push(*amp_id);
            amplitude_semantic_keys.push(semantic_key);
            name_to_index.insert(name.clone(), aid.1);
            right_map.push(aid.1);
        }

        Ok((
            Self {
                amplitudes,
                amplitude_names,
                amplitude_ids,
                resources,
            },
            left_map,
            right_map,
        ))
    }
}

/// Expression tree used by [`Expression`].
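///
/// # Example
///
/// A sketch of the tree shape for the intensity `|A_0 + A_1|^2`, with amplitudes referenced by
/// registry index (trees are normally assembled through [`Expression`] operations rather than
/// constructed by hand):
///
/// ```ignore
/// use laddu_core::amplitudes::ExpressionNode;
///
/// let intensity = ExpressionNode::NormSqr(Box::new(ExpressionNode::Add(
///     Box::new(ExpressionNode::Amp(0)),
///     Box::new(ExpressionNode::Amp(1)),
/// )));
/// ```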
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default, Debug)]
pub enum ExpressionNode {
    #[default]
    /// An expression equal to zero.
    Zero,
    /// An expression equal to one.
    One,
    /// A real-valued constant.
    Constant(f64),
    /// A complex-valued constant.
    ComplexConstant(Complex64),
    /// A registered [`Amplitude`] referenced by index.
    Amp(usize),
    /// The sum of two [`ExpressionNode`]s.
    Add(Box<ExpressionNode>, Box<ExpressionNode>),
    /// The difference of two [`ExpressionNode`]s.
    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
    /// The product of two [`ExpressionNode`]s.
    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
    /// The division of two [`ExpressionNode`]s.
    Div(Box<ExpressionNode>, Box<ExpressionNode>),
    /// The additive inverse of an [`ExpressionNode`].
    Neg(Box<ExpressionNode>),
    /// The real part of an [`ExpressionNode`].
    Real(Box<ExpressionNode>),
    /// The imaginary part of an [`ExpressionNode`].
    Imag(Box<ExpressionNode>),
    /// The complex conjugate of an [`ExpressionNode`].
    Conj(Box<ExpressionNode>),
    /// The absolute square of an [`ExpressionNode`].
    NormSqr(Box<ExpressionNode>),
    /// The square root of an [`ExpressionNode`].
    Sqrt(Box<ExpressionNode>),
    /// An [`ExpressionNode`] raised to the power of another [`ExpressionNode`].
    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
    /// An [`ExpressionNode`] raised to an integer power.
    PowI(Box<ExpressionNode>, i32),
    /// An [`ExpressionNode`] raised to a real power.
    PowF(Box<ExpressionNode>, f64),
    /// The exponential of an [`ExpressionNode`].
    Exp(Box<ExpressionNode>),
    /// The sine of an [`ExpressionNode`].
    Sin(Box<ExpressionNode>),
    /// The cosine of an [`ExpressionNode`].
    Cos(Box<ExpressionNode>),
    /// The logarithm of an [`ExpressionNode`].
    Log(Box<ExpressionNode>),
    /// The `cis` function (`cos(x) + i sin(x)`) applied to an [`ExpressionNode`].
    Cis(Box<ExpressionNode>),
1254}
1255
1256#[derive(Clone, Debug)]
1257/// Standalone bytecode executor compiled directly from the semantic expression tree.
1258///
1259/// This is retained for direct tree-level helpers on [`ExpressionNode`] and debugging of the
1260/// unfactored semantic expression shape. It is intentionally distinct from the lowered runtime:
1261/// current lowering carries slot reuse, peephole rewrites, root-specific lowering, and
1262/// specialized normalization helpers that would be awkward to force back into this form.
1263struct ExpressionProgram {
1264    ops: Vec<ExpressionOp>,
1265    slot_count: usize,
1266    root_slot: usize,
1267}
1268
1269#[derive(Clone, Debug)]
1270enum ExpressionOp {
1271    LoadZero {
1272        dst: usize,
1273    },
1274    LoadOne {
1275        dst: usize,
1276    },
1277    LoadConstant {
1278        dst: usize,
1279        value: f64,
1280    },
1281    LoadComplexConstant {
1282        dst: usize,
1283        value: Complex64,
1284    },
1285    LoadAmp {
1286        dst: usize,
1287        amp_idx: usize,
1288    },
1289    Add {
1290        dst: usize,
1291        left: usize,
1292        right: usize,
1293    },
1294    Sub {
1295        dst: usize,
1296        left: usize,
1297        right: usize,
1298    },
1299    Mul {
1300        dst: usize,
1301        left: usize,
1302        right: usize,
1303    },
1304    Div {
1305        dst: usize,
1306        left: usize,
1307        right: usize,
1308    },
1309    Neg {
1310        dst: usize,
1311        input: usize,
1312    },
1313    Real {
1314        dst: usize,
1315        input: usize,
1316    },
1317    Imag {
1318        dst: usize,
1319        input: usize,
1320    },
1321    Conj {
1322        dst: usize,
1323        input: usize,
1324    },
1325    NormSqr {
1326        dst: usize,
1327        input: usize,
1328    },
1329    Sqrt {
1330        dst: usize,
1331        input: usize,
1332    },
1333    Pow {
1334        dst: usize,
1335        value: usize,
1336        power: usize,
1337    },
1338    PowI {
1339        dst: usize,
1340        input: usize,
1341        power: i32,
1342    },
1343    PowF {
1344        dst: usize,
1345        input: usize,
1346        power: f64,
1347    },
1348    Exp {
1349        dst: usize,
1350        input: usize,
1351    },
1352    Sin {
1353        dst: usize,
1354        input: usize,
1355    },
1356    Cos {
1357        dst: usize,
1358        input: usize,
1359    },
1360    Log {
1361        dst: usize,
1362        input: usize,
1363    },
1364    Cis {
1365        dst: usize,
1366        input: usize,
1367    },
1368}
1369
1370#[derive(Default)]
1371struct ExpressionProgramBuilder {
1372    ops: Vec<ExpressionOp>,
1373    next_slot: usize,
1374}
1375
1376impl ExpressionProgramBuilder {
1377    fn alloc_slot(&mut self) -> usize {
1378        let slot = self.next_slot;
1379        self.next_slot += 1;
1380        slot
1381    }
1382
1383    fn build(self, root: usize) -> ExpressionProgram {
1384        ExpressionProgram {
1385            ops: self.ops,
1386            slot_count: self.next_slot,
1387            root_slot: root,
1388        }
1389    }
1390
1391    fn emit(&mut self, op: ExpressionOp) {
1392        self.ops.push(op);
1393    }
1394
1395    fn compile(&mut self, node: &ExpressionNode) -> usize {
1396        match node {
1397            ExpressionNode::Zero => {
1398                let dst = self.alloc_slot();
1399                self.emit(ExpressionOp::LoadZero { dst });
1400                dst
1401            }
1402            ExpressionNode::One => {
1403                let dst = self.alloc_slot();
1404                self.emit(ExpressionOp::LoadOne { dst });
1405                dst
1406            }
1407            ExpressionNode::Constant(value) => {
1408                let dst = self.alloc_slot();
1409                self.emit(ExpressionOp::LoadConstant { dst, value: *value });
1410                dst
1411            }
1412            ExpressionNode::ComplexConstant(value) => {
1413                let dst = self.alloc_slot();
1414                self.emit(ExpressionOp::LoadComplexConstant { dst, value: *value });
1415                dst
1416            }
1417            ExpressionNode::Amp(idx) => {
1418                let dst = self.alloc_slot();
1419                self.emit(ExpressionOp::LoadAmp { dst, amp_idx: *idx });
1420                dst
1421            }
1422            ExpressionNode::Add(a, b) => {
1423                let left = self.compile(a);
1424                let right = self.compile(b);
1425                let dst = self.alloc_slot();
1426                self.emit(ExpressionOp::Add { dst, left, right });
1427                dst
1428            }
1429            ExpressionNode::Sub(a, b) => {
1430                let left = self.compile(a);
1431                let right = self.compile(b);
1432                let dst = self.alloc_slot();
1433                self.emit(ExpressionOp::Sub { dst, left, right });
1434                dst
1435            }
1436            ExpressionNode::Mul(a, b) => {
1437                let left = self.compile(a);
1438                let right = self.compile(b);
1439                let dst = self.alloc_slot();
1440                self.emit(ExpressionOp::Mul { dst, left, right });
1441                dst
1442            }
1443            ExpressionNode::Div(a, b) => {
1444                let left = self.compile(a);
1445                let right = self.compile(b);
1446                let dst = self.alloc_slot();
1447                self.emit(ExpressionOp::Div { dst, left, right });
1448                dst
1449            }
1450            ExpressionNode::Neg(a) => {
1451                let input = self.compile(a);
1452                let dst = self.alloc_slot();
1453                self.emit(ExpressionOp::Neg { dst, input });
1454                dst
1455            }
1456            ExpressionNode::Real(a) => {
1457                let input = self.compile(a);
1458                let dst = self.alloc_slot();
1459                self.emit(ExpressionOp::Real { dst, input });
1460                dst
1461            }
1462            ExpressionNode::Imag(a) => {
1463                let input = self.compile(a);
1464                let dst = self.alloc_slot();
1465                self.emit(ExpressionOp::Imag { dst, input });
1466                dst
1467            }
1468            ExpressionNode::Conj(a) => {
1469                let input = self.compile(a);
1470                let dst = self.alloc_slot();
1471                self.emit(ExpressionOp::Conj { dst, input });
1472                dst
1473            }
1474            ExpressionNode::NormSqr(a) => {
1475                let input = self.compile(a);
1476                let dst = self.alloc_slot();
1477                self.emit(ExpressionOp::NormSqr { dst, input });
1478                dst
1479            }
1480            ExpressionNode::Sqrt(a) => {
1481                let input = self.compile(a);
1482                let dst = self.alloc_slot();
1483                self.emit(ExpressionOp::Sqrt { dst, input });
1484                dst
1485            }
1486            ExpressionNode::Pow(a, b) => {
1487                let value = self.compile(a);
1488                let power = self.compile(b);
1489                let dst = self.alloc_slot();
1490                self.emit(ExpressionOp::Pow { dst, value, power });
1491                dst
1492            }
1493            ExpressionNode::PowI(a, power) => {
1494                let input = self.compile(a);
1495                let dst = self.alloc_slot();
1496                self.emit(ExpressionOp::PowI {
1497                    dst,
1498                    input,
1499                    power: *power,
1500                });
1501                dst
1502            }
1503            ExpressionNode::PowF(a, power) => {
1504                let input = self.compile(a);
1505                let dst = self.alloc_slot();
1506                self.emit(ExpressionOp::PowF {
1507                    dst,
1508                    input,
1509                    power: *power,
1510                });
1511                dst
1512            }
1513            ExpressionNode::Exp(a) => {
1514                let input = self.compile(a);
1515                let dst = self.alloc_slot();
1516                self.emit(ExpressionOp::Exp { dst, input });
1517                dst
1518            }
1519            ExpressionNode::Sin(a) => {
1520                let input = self.compile(a);
1521                let dst = self.alloc_slot();
1522                self.emit(ExpressionOp::Sin { dst, input });
1523                dst
1524            }
1525            ExpressionNode::Cos(a) => {
1526                let input = self.compile(a);
1527                let dst = self.alloc_slot();
1528                self.emit(ExpressionOp::Cos { dst, input });
1529                dst
1530            }
1531            ExpressionNode::Log(a) => {
1532                let input = self.compile(a);
1533                let dst = self.alloc_slot();
1534                self.emit(ExpressionOp::Log { dst, input });
1535                dst
1536            }
1537            ExpressionNode::Cis(a) => {
1538                let input = self.compile(a);
1539                let dst = self.alloc_slot();
1540                self.emit(ExpressionOp::Cis { dst, input });
1541                dst
1542            }
1543        }
1544    }
1545}
1546
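// Illustration (not part of the original source): compiling the tree
// `NormSqr(Add(Amp(0), Amp(1)))` walks the nodes in post-order and emits
//
//     LoadAmp { dst: 0, amp_idx: 0 }
//     LoadAmp { dst: 1, amp_idx: 1 }
//     Add     { dst: 2, left: 0, right: 1 }
//     NormSqr { dst: 3, input: 2 }
//
// so the finished program has `slot_count = 4` and `root_slot = 3`.
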
1547impl ExpressionProgram {
1548    fn from_node(node: &ExpressionNode) -> Self {
1549        let mut builder = ExpressionProgramBuilder::default();
1550        let root = builder.compile(node);
1551        builder.build(root)
1552    }
1553
1554    fn fill_values(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) {
1555        debug_assert!(slots.len() >= self.slot_count);
1556        for op in &self.ops {
1557            match *op {
1558                ExpressionOp::LoadZero { dst } => slots[dst] = Complex64::ZERO,
1559                ExpressionOp::LoadOne { dst } => slots[dst] = Complex64::ONE,
1560                ExpressionOp::LoadConstant { dst, value } => slots[dst] = Complex64::from(value),
1561                ExpressionOp::LoadComplexConstant { dst, value } => slots[dst] = value,
1562                ExpressionOp::LoadAmp { dst, amp_idx } => {
1563                    slots[dst] = amplitude_values.get(amp_idx).copied().unwrap_or_default();
1564                }
1565                ExpressionOp::Add { dst, left, right } => {
1566                    slots[dst] = slots[left] + slots[right];
1567                }
1568                ExpressionOp::Sub { dst, left, right } => {
1569                    slots[dst] = slots[left] - slots[right];
1570                }
1571                ExpressionOp::Mul { dst, left, right } => {
1572                    slots[dst] = slots[left] * slots[right];
1573                }
1574                ExpressionOp::Div { dst, left, right } => {
1575                    slots[dst] = slots[left] / slots[right];
1576                }
1577                ExpressionOp::Neg { dst, input } => {
1578                    slots[dst] = -slots[input];
1579                }
1580                ExpressionOp::Real { dst, input } => {
1581                    slots[dst] = Complex64::new(slots[input].re, 0.0);
1582                }
1583                ExpressionOp::Imag { dst, input } => {
1584                    slots[dst] = Complex64::new(slots[input].im, 0.0);
1585                }
1586                ExpressionOp::Conj { dst, input } => {
1587                    slots[dst] = slots[input].conj();
1588                }
1589                ExpressionOp::NormSqr { dst, input } => {
1590                    slots[dst] = Complex64::new(slots[input].norm_sqr(), 0.0);
1591                }
1592                ExpressionOp::Sqrt { dst, input } => {
1593                    slots[dst] = slots[input].sqrt();
1594                }
1595                ExpressionOp::Pow { dst, value, power } => {
1596                    slots[dst] = slots[value].powc(slots[power]);
1597                }
1598                ExpressionOp::PowI { dst, input, power } => {
1599                    slots[dst] = slots[input].powi(power);
1600                }
1601                ExpressionOp::PowF { dst, input, power } => {
1602                    slots[dst] = slots[input].powc(Complex64::new(power, 0.0));
1603                }
1604                ExpressionOp::Exp { dst, input } => {
1605                    slots[dst] = slots[input].exp();
1606                }
1607                ExpressionOp::Sin { dst, input } => {
1608                    slots[dst] = slots[input].sin();
1609                }
1610                ExpressionOp::Cos { dst, input } => {
1611                    slots[dst] = slots[input].cos();
1612                }
1613                ExpressionOp::Log { dst, input } => {
1614                    slots[dst] = slots[input].ln();
1615                }
1616                ExpressionOp::Cis { dst, input } => {
1617                    slots[dst] = (Complex64::new(0.0, 1.0) * slots[input]).exp();
1618                }
1619            }
1620        }
1621    }
1622
1623    fn evaluate_into(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) -> Complex64 {
1624        if self.slot_count == 0 {
1625            return Complex64::ZERO;
1626        }
1627        self.fill_values(amplitude_values, slots);
1628        slots[self.root_slot]
1629    }
1630
1631    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
1632        if self.slot_count == 0 {
1633            return Complex64::ZERO;
1634        }
1635        let mut slots = vec![Complex64::ZERO; self.slot_count];
1636        self.evaluate_into(amplitude_values, &mut slots)
1637    }
1638
1639    pub fn evaluate_gradient_into(
1640        &self,
1641        amplitude_values: &[Complex64],
1642        gradient_values: &[DVector<Complex64>],
1643        value_slots: &mut [Complex64],
1644        gradient_slots: &mut [DVector<Complex64>],
1645    ) -> DVector<Complex64> {
1646        if self.slot_count == 0 {
1647            let dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
1648            return DVector::zeros(dim);
1649        }
1650        self.fill_values(amplitude_values, value_slots);
1651        self.fill_gradients(gradient_values, value_slots, gradient_slots);
1652        gradient_slots[self.root_slot].clone()
1653    }
1654
1655    pub fn evaluate_gradient(
1656        &self,
1657        amplitude_values: &[Complex64],
1658        gradient_values: &[DVector<Complex64>],
1659    ) -> DVector<Complex64> {
1660        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
1661        let mut value_slots = vec![Complex64::ZERO; self.slot_count];
1662        let mut gradient_slots: Vec<DVector<Complex64>> = (0..self.slot_count)
1663            .map(|_| DVector::zeros(grad_dim))
1664            .collect();
1665        self.evaluate_gradient_into(
1666            amplitude_values,
1667            gradient_values,
1668            &mut value_slots,
1669            &mut gradient_slots,
1670        )
1671    }
1672
1673    fn fill_gradients(
1674        &self,
1675        amplitude_gradients: &[DVector<Complex64>],
1676        values: &[Complex64],
1677        gradients: &mut [DVector<Complex64>],
1678    ) {
1679        debug_assert!(gradients.len() >= self.slot_count);
1680        debug_assert!(values.len() >= self.slot_count);
1681        fn borrow_dst(
1682            gradients: &mut [DVector<Complex64>],
1683            dst: usize,
1684        ) -> (&[DVector<Complex64>], &mut DVector<Complex64>) {
1685            let (before, tail) = gradients.split_at_mut(dst);
1686            let (dst_ref, _) = tail.split_first_mut().expect("dst slot should exist");
1687            (before, dst_ref)
1688        }
1689        for op in &self.ops {
1690            match *op {
1691                ExpressionOp::LoadZero { dst }
1692                | ExpressionOp::LoadOne { dst }
1693                | ExpressionOp::LoadConstant { dst, .. }
1694                | ExpressionOp::LoadComplexConstant { dst, .. } => {
1695                    let (_, dst_grad) = borrow_dst(gradients, dst);
1696                    for item in dst_grad.iter_mut() {
1697                        *item = Complex64::ZERO;
1698                    }
1699                }
1700                ExpressionOp::LoadAmp { dst, amp_idx } => {
1701                    let (_, dst_grad) = borrow_dst(gradients, dst);
1702                    if let Some(source) = amplitude_gradients.get(amp_idx) {
1703                        dst_grad.clone_from(source);
1704                    } else {
1705                        for item in dst_grad.iter_mut() {
1706                            *item = Complex64::ZERO;
1707                        }
1708                    }
1709                }
1710                ExpressionOp::Add { dst, left, right } => {
1711                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1712                    dst_grad.clone_from(&before_dst[left]);
1713                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
1714                    {
1715                        *dst_item += *right_item;
1716                    }
1717                }
1718                ExpressionOp::Sub { dst, left, right } => {
1719                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1720                    dst_grad.clone_from(&before_dst[left]);
1721                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
1722                    {
1723                        *dst_item -= *right_item;
1724                    }
1725                }
1726                ExpressionOp::Mul { dst, left, right } => {
1727                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1728                    let f_left = values[left];
1729                    let f_right = values[right];
1730                    dst_grad.clone_from(&before_dst[right]);
1731                    for item in dst_grad.iter_mut() {
1732                        *item *= f_left;
1733                    }
1734                    for (dst_item, left_item) in dst_grad.iter_mut().zip(before_dst[left].iter()) {
1735                        *dst_item += *left_item * f_right;
1736                    }
1737                }
1738                ExpressionOp::Div { dst, left, right } => {
1739                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1740                    let f_left = values[left];
1741                    let f_right = values[right];
1742                    let denom = f_right * f_right;
1743                    dst_grad.clone_from(&before_dst[left]);
1744                    for item in dst_grad.iter_mut() {
1745                        *item *= f_right;
1746                    }
1747                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
1748                    {
1749                        *dst_item -= *right_item * f_left;
1750                    }
1751                    for item in dst_grad.iter_mut() {
1752                        *item /= denom;
1753                    }
1754                }
1755                ExpressionOp::Neg { dst, input } => {
1756                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1757                    dst_grad.clone_from(&before_dst[input]);
1758                    for item in dst_grad.iter_mut() {
1759                        *item = -*item;
1760                    }
1761                }
1762                ExpressionOp::Real { dst, input } => {
1763                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1764                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1765                    {
1766                        *dst_item = Complex64::new(input_item.re, 0.0);
1767                    }
1768                }
1769                ExpressionOp::Imag { dst, input } => {
1770                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1771                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1772                    {
1773                        *dst_item = Complex64::new(input_item.im, 0.0);
1774                    }
1775                }
1776                ExpressionOp::Conj { dst, input } => {
1777                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1778                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1779                    {
1780                        *dst_item = input_item.conj();
1781                    }
1782                }
1783                ExpressionOp::NormSqr { dst, input } => {
1784                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1785                    let conj_value = values[input].conj();
1786                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1787                    {
1788                        *dst_item = Complex64::new(2.0 * (*input_item * conj_value).re, 0.0);
1789                    }
1790                }
1791                ExpressionOp::Sqrt { dst, input } => {
1792                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1793                    let factor = Complex64::new(0.5, 0.0) / values[input].sqrt();
1794                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1795                    {
1796                        *dst_item = *input_item * factor;
1797                    }
1798                }
1799                ExpressionOp::Pow { dst, value, power } => {
1800                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1801                    let base = values[value];
1802                    let exponent = values[power];
1803                    let output = values[dst];
1804                    for ((dst_item, value_item), power_item) in dst_grad
1805                        .iter_mut()
1806                        .zip(before_dst[value].iter())
1807                        .zip(before_dst[power].iter())
1808                    {
1809                        *dst_item =
1810                            output * (*power_item * base.ln() + exponent * *value_item / base);
1811                    }
1812                }
1813                ExpressionOp::PowI { dst, input, power } => {
1814                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1815                    let factor = match power {
1816                        0 => Complex64::ZERO,
1817                        1 => Complex64::ONE,
1818                        _ => {
1819                            let base = values[input];
1820                            let multiplier = Complex64::new(power as f64, 0.0);
1821                            if let Some(derivative_power) = power.checked_sub(1) {
1822                                multiplier * base.powi(derivative_power)
1823                            } else {
1824                                multiplier * base.powi(power) / base
1825                            }
1826                        }
1827                    };
1828                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1829                    {
1830                        *dst_item = *input_item * factor;
1831                    }
1832                }
1833                ExpressionOp::PowF { dst, input, power } => {
1834                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1835                    let factor = if power == 0.0 {
1836                        Complex64::ZERO
1837                    } else {
1838                        Complex64::new(power, 0.0)
1839                            * values[input].powc(Complex64::new(power - 1.0, 0.0))
1840                    };
1841                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1842                    {
1843                        *dst_item = *input_item * factor;
1844                    }
1845                }
1846                ExpressionOp::Exp { dst, input } => {
1847                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1848                    let output = values[dst];
1849                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1850                    {
1851                        *dst_item = *input_item * output;
1852                    }
1853                }
1854                ExpressionOp::Sin { dst, input } => {
1855                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1856                    let factor = values[input].cos();
1857                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1858                    {
1859                        *dst_item = *input_item * factor;
1860                    }
1861                }
1862                ExpressionOp::Cos { dst, input } => {
1863                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1864                    let factor = -values[input].sin();
1865                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1866                    {
1867                        *dst_item = *input_item * factor;
1868                    }
1869                }
1870                ExpressionOp::Log { dst, input } => {
1871                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1872                    let factor = Complex64::ONE / values[input];
1873                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1874                    {
1875                        *dst_item = *input_item * factor;
1876                    }
1877                }
1878                ExpressionOp::Cis { dst, input } => {
1879                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1880                    let factor = Complex64::new(0.0, 1.0) * values[dst];
1881                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1882                    {
1883                        *dst_item = *input_item * factor;
1884                    }
1885                }
1886            }
1887        }
1888    }
1889}
1890
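// Minimal test-style sketch (added illustration, not original code) showing how
// a compiled program is evaluated directly; the single-parameter gradients are
// hand-built for clarity.
#[cfg(test)]
mod expression_program_sketch {
    use super::*;

    #[test]
    fn evaluates_a_product_and_its_gradient() {
        // amp0 * amp1
        let node = ExpressionNode::Mul(
            Box::new(ExpressionNode::Amp(0)),
            Box::new(ExpressionNode::Amp(1)),
        );
        let program = ExpressionProgram::from_node(&node);
        let values = [Complex64::new(2.0, 0.0), Complex64::new(3.0, 0.0)];
        assert_eq!(program.evaluate(&values), Complex64::new(6.0, 0.0));

        // d(amp0)/dθ = 1 and d(amp1)/dθ = 0, so d(amp0 * amp1)/dθ = amp1 = 3.
        let gradients = [
            DVector::from_vec(vec![Complex64::new(1.0, 0.0)]),
            DVector::from_vec(vec![Complex64::new(0.0, 0.0)]),
        ];
        let grad = program.evaluate_gradient(&values, &gradients);
        assert_eq!(grad[0], Complex64::new(3.0, 0.0));
    }
}
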
1891impl ExpressionNode {
1892    fn remap(&self, mapping: &[usize]) -> Self {
1893        match self {
1894            Self::Amp(idx) => Self::Amp(mapping[*idx]),
1895            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1896            Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1897            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1898            Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1899            Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
1900            Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
1901            Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
1902            Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
1903            Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
1904            Self::Zero => Self::Zero,
1905            Self::One => Self::One,
1906            Self::Constant(v) => Self::Constant(*v),
1907            Self::ComplexConstant(v) => Self::ComplexConstant(*v),
1908            Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
1909            Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1910            Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
1911            Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
1912            Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
1913            Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
1914            Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
1915            Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
1916            Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
1917        }
1918    }
1919
1920    fn program(&self) -> ExpressionProgram {
1921        ExpressionProgram::from_node(self)
1922    }
1923
1924    /// Evaluate an [`ExpressionNode`] by compiling it to bytecode on the fly.
1925    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
1926        self.program().evaluate(amplitude_values)
1927    }
1928
1929    /// Evaluate the gradient of an [`ExpressionNode`].
1930    pub fn evaluate_gradient(
1931        &self,
1932        amplitude_values: &[Complex64],
1933        gradient_values: &[DVector<Complex64>],
1934    ) -> DVector<Complex64> {
1935        self.program()
1936            .evaluate_gradient(amplitude_values, gradient_values)
1937    }
1938}
1939
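// Note (sketch of intended use): `node.evaluate(&values)` is equivalent to
// `node.program().evaluate(&values)` but recompiles the bytecode on every call;
// within this module, hot paths can build the `ExpressionProgram` once and
// reuse `evaluate_into` with preallocated scratch slots instead.
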
1940impl From<f64> for Expression {
1941    fn from(value: f64) -> Self {
1942        if value == 0.0 {
1943            Self {
1944                registry: ExpressionRegistry::default(),
1945                tree: ExpressionNode::Zero,
1946            }
1947        } else if value == 1.0 {
1948            Self {
1949                registry: ExpressionRegistry::default(),
1950                tree: ExpressionNode::One,
1951            }
1952        } else {
1953            Self {
1954                registry: ExpressionRegistry::default(),
1955                tree: ExpressionNode::Constant(value),
1956            }
1957        }
1958    }
1959}
1960impl From<&f64> for Expression {
1961    fn from(value: &f64) -> Self {
1962        (*value).into()
1963    }
1964}
1965impl From<Complex64> for Expression {
1966    fn from(value: Complex64) -> Self {
1967        if value == Complex64::ZERO {
1968            Self {
1969                registry: ExpressionRegistry::default(),
1970                tree: ExpressionNode::Zero,
1971            }
1972        } else if value == Complex64::ONE {
1973            Self {
1974                registry: ExpressionRegistry::default(),
1975                tree: ExpressionNode::One,
1976            }
1977        } else {
1978            Self {
1979                registry: ExpressionRegistry::default(),
1980                tree: ExpressionNode::ComplexConstant(value),
1981            }
1982        }
1983    }
1984}
1985impl From<&Complex64> for Expression {
1986    fn from(value: &Complex64) -> Self {
1987        (*value).into()
1988    }
1989}
1990
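// Note: the scalar conversions above map `0.0`/`Complex64::ZERO` to the exact
// `ExpressionNode::Zero` node and `1.0`/`Complex64::ONE` to `ExpressionNode::One`,
// rather than to generic constants; all other values become `Constant` or
// `ComplexConstant`.
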
1991impl Expression {
1992    /// Build an [`Expression`] from a single [`Amplitude`].
1993    pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
1994        let registry = ExpressionRegistry::singleton(amplitude)?;
1995        Ok(Self {
1996            tree: ExpressionNode::Amp(0),
1997            registry,
1998        })
1999    }
2000
2001    /// Create an expression representing zero, the additive identity.
2002    pub fn zero() -> Self {
2003        Self {
2004            registry: ExpressionRegistry::default(),
2005            tree: ExpressionNode::Zero,
2006        }
2007    }
2008
2009    /// Create an expression representing one, the multiplicative identity.
2010    pub fn one() -> Self {
2011        Self {
2012            registry: ExpressionRegistry::default(),
2013            tree: ExpressionNode::One,
2014        }
2015    }
2016
2017    fn binary_op(
2018        a: &Expression,
2019        b: &Expression,
2020        build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
2021    ) -> Expression {
2022        let (registry, left_map, right_map) = a
2023            .registry
2024            .merge(&b.registry)
2025            .expect("merging expression registries should not fail");
2026        let left_tree = a.tree.remap(&left_map);
2027        let right_tree = b.tree.remap(&right_map);
2028        Expression {
2029            registry,
2030            tree: build(Box::new(left_tree), Box::new(right_tree)),
2031        }
2032    }
2033
2034    fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
2035        Expression {
2036            registry: a.registry.clone(),
2037            tree: build(Box::new(a.tree.clone())),
2038        }
2039    }
2040
2041    /// Get the list of parameter names in the order they appear in the underlying resources.
2042    pub fn parameters(&self) -> Vec<String> {
2043        self.registry.resources.parameter_names()
2044    }
2045
2046    /// Get the list of free parameter names.
2047    pub fn free_parameters(&self) -> Vec<String> {
2048        self.registry.resources.free_parameter_names()
2049    }
2050
2051    /// Get the list of fixed parameter names.
2052    pub fn fixed_parameters(&self) -> Vec<String> {
2053        self.registry.resources.fixed_parameter_names()
2054    }
2055
2056    /// Number of free parameters.
2057    pub fn n_free(&self) -> usize {
2058        self.registry.resources.n_free_parameters()
2059    }
2060
2061    /// Number of fixed parameters.
2062    pub fn n_fixed(&self) -> usize {
2063        self.registry.resources.n_fixed_parameters()
2064    }
2065
2066    /// Total number of parameters.
2067    pub fn n_parameters(&self) -> usize {
2068        self.registry.resources.n_parameters()
2069    }
2070
2071    /// Returns a tree-like diagnostic snapshot of this expression's compiled form.
2072    ///
2073    /// This compiles the expression on each call with every registered amplitude active. Use
2074    /// [`Evaluator::compiled_expression`] when you need the compiled form for a loaded evaluator's
2075    /// current active-amplitude mask.
2076    pub fn compiled_expression(&self) -> CompiledExpression {
2077        let active_amplitudes = vec![true; self.registry.amplitudes.len()];
2078        let amplitude_dependencies = self
2079            .registry
2080            .amplitudes
2081            .iter()
2082            .map(|amp| ir::DependenceClass::from(amp.dependence_hint()))
2083            .collect::<Vec<_>>();
2084        let amplitude_realness = self
2085            .registry
2086            .amplitudes
2087            .iter()
2088            .map(|amp| amp.real_valued_hint())
2089            .collect::<Vec<_>>();
2090        let expression_ir = ir::compile_expression_ir_with_real_hints(
2091            &self.tree,
2092            &active_amplitudes,
2093            &amplitude_dependencies,
2094            &amplitude_realness,
2095        );
2096        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
2097    }
2098
2099    /// Fix a parameter used by this expression's evaluator resources.
2100    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
2101        self.registry.resources.fix_parameter(name, value)
2102    }
2103
2104    /// Mark a parameter used by this expression's evaluator resources as free.
2105    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
2106        self.registry.resources.free_parameter(name)
2107    }
2108
2109    /// Rename a single parameter of this expression's evaluator resources in place.
2110    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
2111        self.registry.resources.rename_parameter(old, new)
2112    }
2113
2114    /// Rename several parameters of this expression's evaluator resources in place.
2115    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
2116        self.registry.resources.rename_parameters(mapping)
2117    }
2118
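    // Sketch of parameter management, assuming this expression exposes a
    // parameter named "mass" (hypothetical name):
    //
    //     expr.fix_parameter("mass", 1.2)?;  // freeze "mass" at 1.2
    //     expr.free_parameter("mass")?;      // release it again
    //     expr.rename_parameter("mass", "resonance_mass")?;
    //
    // `parameters()`, `free_parameters()`, and `fixed_parameters()` report the
    // resulting bookkeeping.
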
2119    /// Load an [`Expression`] against a dataset, binding amplitudes and reserving caches.
2120    pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
2121        let mut resources = self.registry.resources.clone();
2122        let metadata = dataset.metadata();
2123        resources.reserve_cache(dataset.n_events_local());
2124        resources.refresh_active_indices();
2125        let parameter_map = resources.parameter_map.clone();
2126        let mut amplitudes: Vec<Box<dyn Amplitude>> = self
2127            .registry
2128            .amplitudes
2129            .iter()
2130            .map(|amp| dyn_clone::clone_box(&**amp))
2131            .collect();
2132        {
2133            for amplitude in amplitudes.iter_mut() {
2134                amplitude.bind(metadata)?;
2135                amplitude.precompute_all(dataset, &mut resources);
2136            }
2137        }
2138        let ir_compile_start = Instant::now();
2139        let expression_ir = {
2140            let mut active_amplitudes = vec![false; amplitudes.len()];
2141            for &index in resources.active_indices() {
2142                active_amplitudes[index] = true;
2143            }
2144            let amplitude_dependencies = amplitudes
2145                .iter()
2146                .map(|amp| ir::DependenceClass::from(amp.dependence_hint()))
2147                .collect::<Vec<_>>();
2148            let amplitude_realness = amplitudes
2149                .iter()
2150                .map(|amp| amp.real_valued_hint())
2151                .collect::<Vec<_>>();
2152            ir::compile_expression_ir_with_real_hints(
2153                &self.tree,
2154                &active_amplitudes,
2155                &amplitude_dependencies,
2156                &amplitude_realness,
2157            )
2158        };
2159        let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
2160        let cached_integrals_start = Instant::now();
2161        let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
2162            &expression_ir,
2163            &amplitudes,
2164            &resources,
2165            dataset,
2166            parameter_map.free().len(),
2167        )?;
2168        let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
2169        let lowering_start = Instant::now();
2170        let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
2171            &expression_ir,
2172            &cached_integrals,
2173        )?);
2174        let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
2175        let execution_sets = expression_ir.normalization_execution_sets().clone();
2176        let cached_integral_key =
2177            Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
2178        let cached_integral_state = Arc::new(CachedIntegralCacheState {
2179            key: cached_integral_key.clone(),
2180            expression_ir,
2181            values: cached_integrals,
2182            execution_sets,
2183        });
2184        let specialization_state = ExpressionSpecializationState {
2185            cached_integrals: cached_integral_state.clone(),
2186            lowered_artifacts: lowered_artifacts.clone(),
2187        };
2188        let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
2189        let lowered_artifact_cache =
2190            HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
2191        Ok(Evaluator {
2192            amplitudes,
2193            resources: Arc::new(RwLock::new(resources)),
2194            dataset: dataset.clone(),
2195            expression: self.tree.clone(),
2196            ir_planning: ExpressionIrPlanningState {
2197                cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
2198                specialization_cache: Arc::new(RwLock::new(specialization_cache)),
2199                specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
2200                    cache_hits: 0,
2201                    cache_misses: 1,
2202                })),
2203                lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
2204                active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
2205                specialization_status: Arc::new(RwLock::new(Some(
2206                    ExpressionSpecializationStatus {
2207                        origin: ExpressionSpecializationOrigin::InitialLoad,
2208                    },
2209                ))),
2210                compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
2211                    initial_ir_compile_nanos,
2212                    initial_cached_integrals_nanos,
2213                    initial_lowering_nanos,
2214                    specialization_lowering_cache_misses: 1,
2215                    ..Default::default()
2216                })),
2217            },
2218            registry: self.registry.clone(),
2219        })
2220    }
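
    // Sketch of the load flow, assuming an `expr: Expression` and a
    // `dataset: Arc<Dataset>` already exist:
    //
    //     let evaluator = expr.load(&dataset)?;
    //
    // `load` clones the registered amplitudes, binds them to the dataset
    // metadata, runs their per-event precomputation, compiles the expression IR
    // for the currently active amplitudes, precomputes the separable cached
    // integrals, lowers the runtime artifacts, and seeds the specialization
    // caches with that initial state.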
2221
2222    /// Takes the real part of the given [`Expression`].
2223    pub fn real(&self) -> Self {
2224        Self::unary_op(self, ExpressionNode::Real)
2225    }
2226    /// Takes the imaginary part of the given [`Expression`].
2227    pub fn imag(&self) -> Self {
2228        Self::unary_op(self, ExpressionNode::Imag)
2229    }
2230    /// Takes the complex conjugate of the given [`Expression`].
2231    pub fn conj(&self) -> Self {
2232        Self::unary_op(self, ExpressionNode::Conj)
2233    }
2234    /// Takes the absolute square of the given [`Expression`].
2235    pub fn norm_sqr(&self) -> Self {
2236        Self::unary_op(self, ExpressionNode::NormSqr)
2237    }
2238    /// Takes the square root of the given [`Expression`].
2239    pub fn sqrt(&self) -> Self {
2240        Self::unary_op(self, ExpressionNode::Sqrt)
2241    }
2242    /// Raises the given [`Expression`] to an expression-valued power.
2243    pub fn pow(&self, power: &Expression) -> Self {
2244        Self::binary_op(self, power, ExpressionNode::Pow)
2245    }
2246    /// Raises the given [`Expression`] to an integer power.
2247    pub fn powi(&self, power: i32) -> Self {
2248        Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
2249    }
2250    /// Raises the given [`Expression`] to a real-valued power.
2251    pub fn powf(&self, power: f64) -> Self {
2252        Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
2253    }
2254    /// Takes the exponential of the given [`Expression`].
2255    pub fn exp(&self) -> Self {
2256        Self::unary_op(self, ExpressionNode::Exp)
2257    }
2258    /// Takes the sine of the given [`Expression`].
2259    pub fn sin(&self) -> Self {
2260        Self::unary_op(self, ExpressionNode::Sin)
2261    }
2262    /// Takes the cosine of the given [`Expression`].
2263    pub fn cos(&self) -> Self {
2264        Self::unary_op(self, ExpressionNode::Cos)
2265    }
2266    /// Takes the natural logarithm of the given [`Expression`].
2267    pub fn log(&self) -> Self {
2268        Self::unary_op(self, ExpressionNode::Log)
2269    }
2270    /// Takes the complex phase factor `exp(i * expression)` of the given [`Expression`].
2271    pub fn cis(&self) -> Self {
2272        Self::unary_op(self, ExpressionNode::Cis)
2273    }
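
    // Sketch of building an intensity model with these helpers, assuming `a`
    // and `b` are `Expression`s wrapping registered amplitudes and `phi` is a
    // real phase in radians (names are illustrative):
    //
    //     let rotated   = &b * Expression::from(phi).cis(); // b * exp(i * phi)
    //     let intensity = (&a + &rotated).norm_sqr();       // |a + b exp(i*phi)|^2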
2274
2275    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
2276    fn write_tree(
2277        &self,
2278        t: &ExpressionNode,
2279        f: &mut std::fmt::Formatter<'_>,
2280        parent_prefix: &str,
2281        immediate_prefix: &str,
2282        parent_suffix: &str,
2283    ) -> std::fmt::Result {
2284        let display_string = match t {
2285            ExpressionNode::Amp(idx) => {
2286                let name = self
2287                    .registry
2288                    .amplitude_names
2289                    .get(*idx)
2290                    .cloned()
2291                    .unwrap_or_else(|| "<unregistered>".to_string());
2292                format!("{name}(id={idx})")
2293            }
2294            ExpressionNode::Add(_, _) => "+".to_string(),
2295            ExpressionNode::Sub(_, _) => "-".to_string(),
2296            ExpressionNode::Mul(_, _) => "×".to_string(),
2297            ExpressionNode::Div(_, _) => "÷".to_string(),
2298            ExpressionNode::Neg(_) => "-".to_string(),
2299            ExpressionNode::Real(_) => "Re".to_string(),
2300            ExpressionNode::Imag(_) => "Im".to_string(),
2301            ExpressionNode::Conj(_) => "*".to_string(),
2302            ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
2303            ExpressionNode::Zero => "0 (exact)".to_string(),
2304            ExpressionNode::One => "1 (exact)".to_string(),
2305            ExpressionNode::Constant(v) => v.to_string(),
2306            ExpressionNode::ComplexConstant(v) => v.to_string(),
2307            ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
2308            ExpressionNode::Pow(_, _) => "Pow".to_string(),
2309            ExpressionNode::PowI(_, power) => format!("PowI({power})"),
2310            ExpressionNode::PowF(_, power) => format!("PowF({power})"),
2311            ExpressionNode::Exp(_) => "Exp".to_string(),
2312            ExpressionNode::Sin(_) => "Sin".to_string(),
2313            ExpressionNode::Cos(_) => "Cos".to_string(),
2314            ExpressionNode::Log(_) => "Log".to_string(),
2315            ExpressionNode::Cis(_) => "Cis".to_string(),
2316        };
2317        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
2318        match t {
2319            ExpressionNode::Amp(_)
2320            | ExpressionNode::Zero
2321            | ExpressionNode::One
2322            | ExpressionNode::Constant(_)
2323            | ExpressionNode::ComplexConstant(_) => {}
2324            ExpressionNode::Add(a, b)
2325            | ExpressionNode::Sub(a, b)
2326            | ExpressionNode::Mul(a, b)
2327            | ExpressionNode::Div(a, b)
2328            | ExpressionNode::Pow(a, b) => {
2329                let terms = [a, b];
2330                let mut it = terms.iter().peekable();
2331                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2332                while let Some(child) = it.next() {
2333                    match it.peek() {
2334                        Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│  "),
2335                        None => self.write_tree(child, f, &child_prefix, "└─ ", "   "),
2336                    }?;
2337                }
2338            }
2339            ExpressionNode::Neg(a)
2340            | ExpressionNode::Real(a)
2341            | ExpressionNode::Imag(a)
2342            | ExpressionNode::Conj(a)
2343            | ExpressionNode::NormSqr(a)
2344            | ExpressionNode::Sqrt(a)
2345            | ExpressionNode::PowI(a, _)
2346            | ExpressionNode::PowF(a, _)
2347            | ExpressionNode::Exp(a)
2348            | ExpressionNode::Sin(a)
2349            | ExpressionNode::Cos(a)
2350            | ExpressionNode::Log(a)
2351            | ExpressionNode::Cis(a) => {
2352                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2353                self.write_tree(a, f, &child_prefix, "└─ ", "   ")?;
2354            }
2355        }
2356        Ok(())
2357    }
2358}
2359
2360impl Debug for Expression {
2361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2362        self.write_tree(&self.tree, f, "", "", "")
2363    }
2364}
2365
2366impl Display for Expression {
2367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2368        self.write_tree(&self.tree, f, "", "", "")
2369    }
2370}
2371
2372#[rustfmt::skip]
2373impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
2374    Expression::binary_op(a, b, ExpressionNode::Add)
2375});
2376#[rustfmt::skip]
2377impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
2378    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
2379});
2380#[rustfmt::skip]
2381impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
2382    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
2383});
2384#[rustfmt::skip]
2385impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
2386    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
2387});
2388#[rustfmt::skip]
2389impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
2390    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
2391});
2392
2393#[rustfmt::skip]
2394impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
2395    Expression::binary_op(a, b, ExpressionNode::Sub)
2396});
2397#[rustfmt::skip]
2398impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
2399    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
2400});
2401#[rustfmt::skip]
2402impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
2403    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
2404});
2405#[rustfmt::skip]
2406impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
2407    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
2408});
2409#[rustfmt::skip]
2410impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
2411    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
2412});
2413
2414#[rustfmt::skip]
2415impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
2416    Expression::binary_op(a, b, ExpressionNode::Mul)
2417});
2418#[rustfmt::skip]
2419impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
2420    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
2421});
2422#[rustfmt::skip]
2423impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
2424    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
2425});
2426#[rustfmt::skip]
2427impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
2428    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
2429});
2430#[rustfmt::skip]
2431impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
2432    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
2433});
2434
2435#[rustfmt::skip]
2436impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
2437    Expression::binary_op(a, b, ExpressionNode::Div)
2438});
2439#[rustfmt::skip]
2440impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
2441    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
2442});
2443#[rustfmt::skip]
2444impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
2445    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
2446});
2447#[rustfmt::skip]
2448impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
2449    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
2450});
2451#[rustfmt::skip]
2452impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
2453    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
2454});
2455
2456#[rustfmt::skip]
2457impl_op_ex!(- |a: &Expression| -> Expression {
2458    Expression::unary_op(a, ExpressionNode::Neg)
2459});
2460// NOTE: no need to add an impl for negating f64 or complex!
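//
// Sketch of the resulting ergonomics (illustrative, assuming `a` and `b` are
// `Expression`s): `&a + &b`, `a - 2.0`, `Complex64::new(0.0, 1.0) * b`, and
// `-(&a / &b)` all build new `Expression`s; each binary operator merges the two
// registries, remaps the amplitude indices of both trees into the merged
// registry, and wraps them in the corresponding `ExpressionNode` variant.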
2461
2462#[derive(Clone, Debug)]
2463#[doc(hidden)]
2464pub struct ExpressionValueProgramSnapshot {
2465    lowered_program: lowered::LoweredProgram,
2466}
2467
2468#[derive(Clone, Debug, PartialEq)]
2469/// A node in a compiled expression diagnostic snapshot.
2470pub enum CompiledExpressionNode {
2471    /// A complex constant.
2472    Constant(Complex64),
2473    /// A registered amplitude by index and display name.
2474    Amplitude {
2475        /// The amplitude index used by the compiled evaluator.
2476        index: usize,
2477        /// The registered amplitude name.
2478        name: String,
2479    },
2480    /// A unary operation and its input node.
2481    Unary {
2482        /// The display label for the operation.
2483        op: String,
2484        /// The input node index.
2485        input: usize,
2486    },
2487    /// A binary operation and its input nodes.
2488    Binary {
2489        /// The display label for the operation.
2490        op: String,
2491        /// The left input node index.
2492        left: usize,
2493        /// The right input node index.
2494        right: usize,
2495    },
2496}
2497
2498#[derive(Clone, Debug, PartialEq)]
2499/// Tree-like diagnostic view of the compiled expression DAG.
2500///
2501/// The compiled expression is a directed acyclic graph because common subexpressions can be
2502/// deduplicated during compilation. The display format expands the graph from the root once and
2503/// marks later visits to the same node with `(ref)`.
2504pub struct CompiledExpression {
2505    nodes: Vec<CompiledExpressionNode>,
2506    root: usize,
2507}
2508
2509impl CompiledExpression {
2510    fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
2511        let nodes = ir
2512            .nodes()
2513            .iter()
2514            .map(|node| match node {
2515                ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
2516                ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
2517                    index: *index,
2518                    name: amplitude_names
2519                        .get(*index)
2520                        .cloned()
2521                        .unwrap_or_else(|| "<unregistered>".to_string()),
2522                },
2523                ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
2524                    op: compiled_unary_op_label(*op),
2525                    input: *input,
2526                },
2527                ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
2528                    op: compiled_binary_op_label(*op),
2529                    left: *left,
2530                    right: *right,
2531                },
2532            })
2533            .collect();
2534        Self {
2535            nodes,
2536            root: ir.root(),
2537        }
2538    }
2539
2540    /// Returns the compiled expression node list in evaluator execution order.
2541    pub fn nodes(&self) -> &[CompiledExpressionNode] {
2542        &self.nodes
2543    }
2544
2545    /// Returns the root node index.
2546    pub fn root(&self) -> usize {
2547        self.root
2548    }
2549
2550    fn node_label(&self, index: usize) -> String {
2551        let Some(node) = self.nodes.get(index) else {
2552            return format!("#{index} <missing>");
2553        };
2554        let label = match node {
2555            CompiledExpressionNode::Constant(value) => format!("const {value}"),
2556            CompiledExpressionNode::Amplitude { index, name } => {
2557                format!("{name}(id={index})")
2558            }
2559            CompiledExpressionNode::Unary { op, .. }
2560            | CompiledExpressionNode::Binary { op, .. } => op.clone(),
2561        };
2562        format!("#{index} {label}")
2563    }
2564
2565    /// Credit to Daniel Janus: <https://blog.danieljanus.pl/2023/07/20/iterating-trees/>
2566    fn write_tree(
2567        &self,
2568        index: usize,
2569        f: &mut std::fmt::Formatter<'_>,
2570        parent_prefix: &str,
2571        immediate_prefix: &str,
2572        parent_suffix: &str,
2573        expanded: &mut [bool],
2574    ) -> std::fmt::Result {
2575        let already_expanded = expanded.get(index).copied().unwrap_or(false);
2576        if let Some(slot) = expanded.get_mut(index) {
2577            *slot = true;
2578        }
2579        let ref_suffix = if already_expanded { " (ref)" } else { "" };
2580        writeln!(
2581            f,
2582            "{}{}{}{}",
2583            parent_prefix,
2584            immediate_prefix,
2585            self.node_label(index),
2586            ref_suffix
2587        )?;
2588        if already_expanded {
2589            return Ok(());
2590        }
2591        let Some(node) = self.nodes.get(index) else {
2592            return Ok(());
2593        };
2594        match node {
2595            CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
2596            CompiledExpressionNode::Unary { input, .. } => {
2597                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2598                self.write_tree(*input, f, &child_prefix, "└─ ", "   ", expanded)?;
2599            }
2600            CompiledExpressionNode::Binary { left, right, .. } => {
2601                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2602                self.write_tree(*left, f, &child_prefix, "├─ ", "│  ", expanded)?;
2603                self.write_tree(*right, f, &child_prefix, "└─ ", "   ", expanded)?;
2604            }
2605        }
2606        Ok(())
2607    }
2608}
2609
2610impl Display for CompiledExpression {
2611    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2612        if self.nodes.is_empty() {
2613            return writeln!(f, "<empty>");
2614        }
2615        let mut expanded = vec![false; self.nodes.len()];
2616        self.write_tree(self.root, f, "", "", "", &mut expanded)
2617    }
2618}
2619
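// Illustrative rendering (shape only; node indices depend on the compiled DAG):
// printing the `CompiledExpression` for `NormSqr(a + a)`, where `a` is a single
// registered amplitude and the two references were deduplicated to one node,
// would look like
//
//     #2 NormSqr
//     └─ #1 +
//        ├─ #0 a(id=0)
//        └─ #0 a(id=0) (ref)
//
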
2620fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
2621    match op {
2622        ir::IrUnaryOp::Neg => "-".to_string(),
2623        ir::IrUnaryOp::Real => "Re".to_string(),
2624        ir::IrUnaryOp::Imag => "Im".to_string(),
2625        ir::IrUnaryOp::Conj => "*".to_string(),
2626        ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
2627        ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
2628        ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
2629        ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
2630        ir::IrUnaryOp::Exp => "Exp".to_string(),
2631        ir::IrUnaryOp::Sin => "Sin".to_string(),
2632        ir::IrUnaryOp::Cos => "Cos".to_string(),
2633        ir::IrUnaryOp::Log => "Log".to_string(),
2634        ir::IrUnaryOp::Cis => "Cis".to_string(),
2635    }
2636}
2637
2638fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
2639    match op {
2640        ir::IrBinaryOp::Add => "+".to_string(),
2641        ir::IrBinaryOp::Sub => "-".to_string(),
2642        ir::IrBinaryOp::Mul => "×".to_string(),
2643        ir::IrBinaryOp::Div => "÷".to_string(),
2644        ir::IrBinaryOp::Pow => "Pow".to_string(),
2645    }
2646}
2647#[derive(Clone, Copy, Debug, PartialEq, Eq)]
2648/// Origin of the currently installed expression specialization.
2649pub enum ExpressionSpecializationOrigin {
2650    /// The specialization installed during evaluator construction.
2651    InitialLoad,
2652    /// The specialization was rebuilt because no cached entry matched the current state.
2653    CacheMissRebuild,
2654    /// The specialization was restored from an existing cache entry.
2655    CacheHitRestore,
2656}
2657#[derive(Clone, Copy, Debug, PartialEq, Eq)]
2658/// Current specialization status for the evaluator runtime.
2659pub struct ExpressionSpecializationStatus {
2660    /// How the active specialization was obtained most recently.
2661    pub origin: ExpressionSpecializationOrigin,
2662}
2663#[derive(Clone, Debug, PartialEq, Eq)]
2664/// Diagnostic snapshot of the active runtime specialization state.
2665pub struct ExpressionRuntimeDiagnostics {
2666    /// Whether IR planning state is present for the evaluator.
2667    pub ir_planning_enabled: bool,
2668    /// Whether a lowered value-only program is available for the active specialization.
2669    pub lowered_value_program_present: bool,
2670    /// Whether a lowered gradient-only program is available for the active specialization.
2671    pub lowered_gradient_program_present: bool,
2672    /// Whether a lowered fused value+gradient program is available for the active specialization.
2673    pub lowered_value_gradient_program_present: bool,
2674    /// Number of cached parameter-factor descriptors in the active specialization.
2675    pub cached_parameter_factor_count: usize,
2676    /// Number of cached parameter factors with lowered runtimes available.
2677    pub lowered_cached_parameter_factor_count: usize,
2678    /// Whether a lowered residual normalization runtime is available.
2679    pub residual_runtime_present: bool,
2680    /// Number of cached specialization entries currently retained.
2681    pub specialization_cache_entries: usize,
2682    /// Number of cached lowered-artifact entries currently retained.
2683    pub lowered_artifact_cache_entries: usize,
2684    /// Origin of the currently installed specialization, when available.
2685    pub specialization_status: Option<ExpressionSpecializationStatus>,
2686}
2687#[derive(Clone)]
2688/// IR-planning state derived from the semantic expression tree plus the current active mask.
2689///
2690/// Invariants:
2691/// - `expression_ir` is never a source of truth; it is always derived from `Evaluator::expression`.
2692/// - `cached_integrals` are specialization-dependent and must be treated as invalid once the
2693///   active mask or dataset identity changes.
2694struct ExpressionIrPlanningState {
2695    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
2696    specialization_cache:
2697        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
2698    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
2699    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
2700    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
2701    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
2702    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
2703}
2704/// Evaluator for [`Expression`] that mirrors the existing evaluator behavior.
2705#[allow(missing_docs)]
2706#[derive(Clone)]
2707pub struct Evaluator {
2708    pub amplitudes: Vec<Box<dyn Amplitude>>,
2709    pub resources: Arc<RwLock<Resources>>,
2710    pub dataset: Arc<Dataset>,
2711    pub expression: ExpressionNode,
2712    ir_planning: ExpressionIrPlanningState,
2713    registry: ExpressionRegistry,
2714}
2715
2716#[allow(missing_docs)]
2717impl Evaluator {
2718    /// Internal benchmarking/debug counters for specialization cache reuse.
2719    pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
2720        *self.ir_planning.specialization_metrics.read()
2721    }
2722    /// Reset specialization cache counters while leaving cached specializations intact.
2723    pub fn reset_expression_specialization_metrics(&self) {
2724        *self.ir_planning.specialization_metrics.write() =
2725            ExpressionSpecializationMetrics::default();
2726    }
2727    /// Internal benchmarking/debug metrics for staged IR compile and lowering costs.
2728    pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
2729        *self.ir_planning.compile_metrics.read()
2730    }
2731    /// Internal diagnostics surface for active runtime specialization state.
2732    pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
2733        let active_artifacts = self.active_lowered_artifacts();
2734        let cached_parameter_factor_count = self
2735            .ir_planning
2736            .cached_integrals
2737            .read()
2738            .as_ref()
2739            .map(|state| state.values.len())
2740            .unwrap_or(0);
2741        let lowered_cached_parameter_factor_count = active_artifacts
2742            .as_ref()
2743            .map(|artifacts| {
2744                artifacts
2745                    .lowered_parameter_factors
2746                    .iter()
2747                    .filter(|factor| factor.is_some())
2748                    .count()
2749            })
2750            .unwrap_or(0);
2751        let residual_runtime_present = active_artifacts
2752            .as_ref()
2753            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2754            .is_some();
2755        ExpressionRuntimeDiagnostics {
2756            ir_planning_enabled: true,
2757            lowered_value_program_present: active_artifacts.is_some(),
2758            lowered_gradient_program_present: active_artifacts.is_some(),
2759            lowered_value_gradient_program_present: active_artifacts.is_some(),
2760            cached_parameter_factor_count,
2761            lowered_cached_parameter_factor_count,
2762            residual_runtime_present,
2763            specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
2764            lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
2765            specialization_status: *self.ir_planning.specialization_status.read(),
2766        }
2767    }
2768    /// Reset post-load compile/lowering counters while preserving initial-load metrics.
2769    pub fn reset_expression_compile_metrics(&self) {
2770        let mut metrics = self.ir_planning.compile_metrics.write();
2771        metrics.specialization_cache_hits = 0;
2772        metrics.specialization_cache_misses = 0;
2773        metrics.specialization_ir_compile_nanos = 0;
2774        metrics.specialization_cached_integrals_nanos = 0;
2775        metrics.specialization_lowering_nanos = 0;
2776        metrics.specialization_lowering_cache_hits = 0;
2777        metrics.specialization_lowering_cache_misses = 0;
2778        metrics.specialization_cache_restore_nanos = 0;
2779    }
2780    #[cfg(test)]
2781    fn expression_ir(&self) -> ir::ExpressionIR {
2782        self.ir_planning
2783            .cached_integrals
2784            .read()
2785            .as_ref()
2786            .map(|state| state.expression_ir.clone())
2787            .expect("cached integral state should exist for evaluator IR access")
2788    }
2789    fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
2790        self.active_lowered_artifacts()
2791            .expect("active lowered artifacts should exist for the current specialization")
2792            .lowered_runtime
2793            .clone()
2794    }
2795    fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
2796        self.ir_planning.active_lowered_artifacts.read().clone()
2797    }
2798    fn lowered_runtime_slot_count(&self) -> usize {
2799        let runtime = self.lowered_runtime();
2800        [
2801            runtime.value_program().scratch_slots(),
2802            runtime.gradient_program().scratch_slots(),
2803            runtime.value_gradient_program().scratch_slots(),
2804        ]
2805        .into_iter()
2806        .max()
2807        .unwrap_or(0)
2808    }
2809    fn lowered_value_runtime_slot_count(&self) -> usize {
2810        self.lowered_runtime().value_program().scratch_slots()
2811    }
2812
2813    #[doc(hidden)]
2814    pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
2815        ExpressionValueProgramSnapshot {
2816            lowered_program: self.lowered_runtime().value_program().clone(),
2817        }
2818    }
2819
2820    #[doc(hidden)]
2821    pub fn expression_value_program_snapshot_for_active_mask(
2822        &self,
2823        active_mask: &[bool],
2824    ) -> LadduResult<ExpressionValueProgramSnapshot> {
2825        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
2826        let lowered_program =
2827            lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
2828                LadduError::Custom(format!(
2829                    "Failed to lower value-only active-mask runtime: {error:?}"
2830                ))
2831            })?;
2832        Ok(ExpressionValueProgramSnapshot { lowered_program })
2833    }
2834
2835    #[doc(hidden)]
2836    pub fn expression_value_program_snapshot_slot_count(
2837        &self,
2838        snapshot: &ExpressionValueProgramSnapshot,
2839    ) -> usize {
2840        let _ = self;
2841        snapshot.lowered_program.scratch_slots()
2842    }
2843
2844    /// Returns a tree-like diagnostic snapshot of the compiled expression for the evaluator's
2845    /// current active-amplitude mask.
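    ///
    /// A minimal usage sketch (not from the original source); `evaluator` is assumed to exist and
    /// printing with `Debug` is an assumption about `CompiledExpression`.
    ///
    /// ```ignore
    /// let compiled = evaluator.compiled_expression();
    /// println!("{compiled:?}");
    /// ```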
2846    pub fn compiled_expression(&self) -> CompiledExpression {
2847        let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
2848        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
2849    }
2850
2851    /// Returns the expression represented by this evaluator.
2852    pub fn expression(&self) -> Expression {
2853        Expression {
2854            tree: self.expression.clone(),
2855            registry: self.registry.clone(),
2856        }
2857    }
2858    fn lowered_gradient_runtime_slot_count(&self) -> usize {
2859        self.lowered_runtime().gradient_program().scratch_slots()
2860    }
2861    fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
2862        self.lowered_runtime()
2863            .value_gradient_program()
2864            .scratch_slots()
2865    }
2866
2867    fn expression_value_slot_count(&self) -> usize {
2868        self.lowered_value_runtime_slot_count()
2869    }
2870    fn expression_gradient_slot_count(&self) -> usize {
2871        self.lowered_gradient_runtime_slot_count()
2872    }
2873    fn expression_value_gradient_slot_count(&self) -> usize {
2874        self.lowered_value_gradient_runtime_slot_count()
2875    }
2876
2877    #[doc(hidden)]
2878    pub fn expression_value_gradient_slot_count_public(&self) -> usize {
2879        self.expression_value_gradient_slot_count()
2880    }
2881    #[cfg(test)]
2882    fn specialization_cache_len(&self) -> usize {
2883        self.ir_planning.specialization_cache.read().len()
2884    }
2885    #[cfg(test)]
2886    fn lowered_artifact_cache_len(&self) -> usize {
2887        self.ir_planning.lowered_artifact_cache.read().len()
2888    }
2889    fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
2890        debug_assert!(Self::lowered_artifact_signature_matches(
2891            &specialization.lowered_artifacts,
2892            &specialization.cached_integrals.values,
2893        ));
2894        *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
2895        *self.ir_planning.active_lowered_artifacts.write() =
2896            Some(specialization.lowered_artifacts.clone());
2897        debug_assert_eq!(
2898            self.active_lowered_artifacts()
2899                .as_ref()
2900                .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
2901            Some(true)
2902        );
2903        debug_assert_eq!(
2904            self.lowered_runtime().value_program().scratch_slots(),
2905            specialization
2906                .lowered_artifacts
2907                .lowered_runtime
2908                .value_program()
2909                .scratch_slots()
2910        );
2911    }
2912    fn lower_expression_runtime_artifacts(
2913        expression_ir: &ir::ExpressionIR,
2914        values: &[PrecomputedCachedIntegral],
2915    ) -> LadduResult<LoweredArtifactCacheState> {
2916        let parameter_node_indices = values
2917            .iter()
2918            .map(|value| value.parameter_node_index)
2919            .collect();
2920        let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
2921        let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
2922        let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
2923        let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
2924            expression_ir,
2925        )
2926        .map_err(|error| {
2927            LadduError::Custom(format!(
2928                "Failed to lower expression runtime for specialized IR: {error:?}"
2929            ))
2930        })?;
2931        Ok(LoweredArtifactCacheState {
2932            parameter_node_indices,
2933            mul_node_indices,
2934            lowered_parameter_factors,
2935            residual_runtime,
2936            lowered_runtime,
2937        })
2938    }
2939    fn lowered_artifact_signature_matches(
2940        artifacts: &LoweredArtifactCacheState,
2941        values: &[PrecomputedCachedIntegral],
2942    ) -> bool {
2943        artifacts.parameter_node_indices.len() == values.len()
2944            && artifacts.mul_node_indices.len() == values.len()
2945            && artifacts
2946                .parameter_node_indices
2947                .iter()
2948                .copied()
2949                .eq(values.iter().map(|value| value.parameter_node_index))
2950            && artifacts
2951                .mul_node_indices
2952                .iter()
2953                .copied()
2954                .eq(values.iter().map(|value| value.mul_node_index))
2955    }
2956    fn build_expression_specialization(
2957        &self,
2958        resources: &Resources,
2959        key: CachedIntegralCacheKey,
2960    ) -> LadduResult<ExpressionSpecializationState> {
2961        let ir_compile_start = Instant::now();
2962        let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
2963        let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
2964        let cached_integrals_start = Instant::now();
2965        let values = Self::precompute_cached_integrals_at_load(
2966            &expression_ir,
2967            &self.amplitudes,
2968            resources,
2969            &self.dataset,
2970            self.resources.read().n_free_parameters(),
2971        )?;
2972        let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
2973        let execution_sets = expression_ir.normalization_execution_sets().clone();
2974        let active_mask_key = resources.active.clone();
2975        let cached_lowered_artifacts = {
2976            let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
2977            lowered_artifact_cache
2978                .get(&active_mask_key)
2979                .cloned()
2980                .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
2981        };
2982        let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
2983            self.ir_planning
2984                .compile_metrics
2985                .write()
2986                .specialization_lowering_cache_hits += 1;
2987            artifacts
2988        } else {
2989            let lowering_start = Instant::now();
2990            let artifacts = Arc::new(Self::lower_expression_runtime_artifacts(
2991                &expression_ir,
2992                &values,
2993            )?);
2994            let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
2995            self.ir_planning
2996                .lowered_artifact_cache
2997                .write()
2998                .insert(active_mask_key, artifacts.clone());
2999            let mut compile_metrics = self.ir_planning.compile_metrics.write();
3000            compile_metrics.specialization_lowering_cache_misses += 1;
3001            compile_metrics.specialization_lowering_nanos += lowering_nanos;
3002            artifacts
3003        };
3004        let mut compile_metrics = self.ir_planning.compile_metrics.write();
3005        compile_metrics.specialization_cache_misses += 1;
3006        compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
3007        compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
3008        Ok(ExpressionSpecializationState {
3009            cached_integrals: Arc::new(CachedIntegralCacheState {
3010                key,
3011                expression_ir,
3012                values,
3013                execution_sets,
3014            }),
3015            lowered_artifacts,
3016        })
3017    }
3018    fn ensure_expression_specialization(
3019        &self,
3020        resources: &Resources,
3021    ) -> LadduResult<ExpressionSpecializationState> {
3022        let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
3023        if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
3024            if state.key == key {
3025                return Ok(ExpressionSpecializationState {
3026                    cached_integrals: state.clone(),
3027                    lowered_artifacts: self
3028                        .active_lowered_artifacts()
3029                        .expect("active lowered artifacts should exist for cached specialization"),
3030                });
3031            }
3032        }
3033        let cached_specialization = {
3034            let specialization_cache = self.ir_planning.specialization_cache.read();
3035            specialization_cache.get(&key).cloned()
3036        };
3037        if let Some(specialization) = cached_specialization {
3038            let restore_start = Instant::now();
3039            self.ir_planning.specialization_metrics.write().cache_hits += 1;
3040            self.install_expression_specialization(&specialization);
3041            *self.ir_planning.specialization_status.write() =
3042                Some(ExpressionSpecializationStatus {
3043                    origin: ExpressionSpecializationOrigin::CacheHitRestore,
3044                });
3045            let restore_nanos = restore_start.elapsed().as_nanos() as u64;
3046            let mut compile_metrics = self.ir_planning.compile_metrics.write();
3047            compile_metrics.specialization_cache_hits += 1;
3048            compile_metrics.specialization_cache_restore_nanos += restore_nanos;
3049            return Ok(specialization);
3050        }
3051        let specialization = self.build_expression_specialization(resources, key.clone())?;
3052        self.ir_planning.specialization_metrics.write().cache_misses += 1;
3053        self.ir_planning
3054            .specialization_cache
3055            .write()
3056            .insert(key, specialization.clone());
3057        self.install_expression_specialization(&specialization);
3058        let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
3059            ExpressionSpecializationOrigin::InitialLoad
3060        } else {
3061            ExpressionSpecializationOrigin::CacheMissRebuild
3062        };
3063        *self.ir_planning.specialization_status.write() =
3064            Some(ExpressionSpecializationStatus { origin });
3065        Ok(specialization)
3066    }
3067    fn rebuild_runtime_specializations(&self, resources: &Resources) {
3068        let _ = self.ensure_expression_specialization(resources);
3069    }
3070    fn refresh_runtime_specializations(&self) {
3071        let resources = self.resources.read();
3072        self.rebuild_runtime_specializations(&resources);
3073    }
3074    fn cached_integral_cache_key(
3075        active_mask: Vec<bool>,
3076        dataset: &Dataset,
3077    ) -> CachedIntegralCacheKey {
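        // The key couples the active-amplitude mask to the dataset's local event count, weighted
        // event sum, and events buffer address, so it changes whenever the mask or the dataset
        // identity changes (see the invariants on `ExpressionIrPlanningState`).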
3078        CachedIntegralCacheKey {
3079            active_mask,
3080            n_events_local: dataset.n_events_local(),
3081            events_local_len: dataset.events_local().len(),
3082            weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
3083            events_ptr: dataset.events_local().as_ptr() as usize,
3084        }
3085    }
3086    fn precompute_cached_integrals_at_load(
3087        expression_ir: &ir::ExpressionIR,
3088        amplitudes: &[Box<dyn Amplitude>],
3089        resources: &Resources,
3090        dataset: &Dataset,
3091        n_free_parameters: usize,
3092    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
3093        let descriptors = expression_ir.cached_integral_descriptors();
3094        if descriptors.is_empty() {
3095            return Ok(Vec::new());
3096        }
3097        let execution_sets = expression_ir.normalization_execution_sets();
3098        let seed_parameters = vec![0.0; n_free_parameters];
3099        let parameters = resources.parameter_map.assemble(&seed_parameters)?;
3100        let mut amplitude_values = vec![Complex64::ZERO; amplitudes.len()];
3101        let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
3102        let active_set = resources.active_indices();
3103        let cache_active_indices = execution_sets
3104            .cached_cache_amplitudes
3105            .iter()
3106            .copied()
3107            .filter(|index| active_set.binary_search(index).is_ok())
3108            .collect::<Vec<_>>();
3109        let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
3110        for (cache, event) in resources.caches.iter().zip(dataset.events_local().iter()) {
3111            amplitude_values.fill(Complex64::ZERO);
3112            for &amp_idx in &cache_active_indices {
3113                amplitude_values[amp_idx] = amplitudes[amp_idx].compute(&parameters, cache);
3114            }
3115            expression_ir.evaluate_into(&amplitude_values, &mut value_slots);
3116            let weight = event.weight();
3117            for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
3118                weighted_cache_sums[descriptor_index] +=
3119                    value_slots[descriptor.cache_node_index] * weight;
3120            }
3121        }
3122        Ok(descriptors
3123            .iter()
3124            .zip(weighted_cache_sums)
3125            .map(
3126                |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
3127                    mul_node_index: descriptor.mul_node_index,
3128                    parameter_node_index: descriptor.parameter_node_index,
3129                    cache_node_index: descriptor.cache_node_index,
3130                    coefficient: descriptor.coefficient,
3131                    weighted_cache_sum,
3132                },
3133            )
3134            .collect())
3135    }
3136    fn lower_cached_parameter_factors(
3137        expression_ir: &ir::ExpressionIR,
3138    ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
3139        expression_ir
3140            .cached_integral_descriptors()
3141            .iter()
3142            .map(|descriptor| {
3143                lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
3144                    expression_ir,
3145                    descriptor.parameter_node_index,
3146                )
3147                .ok()
3148            })
3149            .collect()
3150    }
3151    fn lower_residual_runtime(
3152        expression_ir: &ir::ExpressionIR,
3153        descriptors: &[PrecomputedCachedIntegral],
3154    ) -> Option<lowered::LoweredExpressionRuntime> {
3155        let mut zeroed_nodes = vec![false; expression_ir.node_count()];
3156        for descriptor in descriptors {
3157            if descriptor.mul_node_index < zeroed_nodes.len() {
3158                zeroed_nodes[descriptor.mul_node_index] = true;
3159            }
3160        }
3161        lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
3162            expression_ir,
3163            &zeroed_nodes,
3164        )
3165        .ok()
3166    }
3167
3168    #[inline]
3169    fn fill_amplitude_values(
3170        &self,
3171        amplitude_values: &mut [Complex64],
3172        active_indices: &[usize],
3173        parameters: &Parameters,
3174        cache: &Cache,
3175    ) {
3176        amplitude_values.fill(Complex64::ZERO);
3177        for &amp_idx in active_indices {
3178            amplitude_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
3179        }
3180    }
3181
3182    #[inline]
3183    fn fill_amplitude_gradients(
3184        &self,
3185        gradient_values: &mut [DVector<Complex64>],
3186        active_mask: &[bool],
3187        parameters: &Parameters,
3188        cache: &Cache,
3189    ) {
3190        for ((amp, active), grad) in self
3191            .amplitudes
3192            .iter()
3193            .zip(active_mask.iter())
3194            .zip(gradient_values.iter_mut())
3195        {
3196            grad.fill(Complex64::ZERO);
3197            if *active {
3198                amp.compute_gradient(parameters, cache, grad);
3199            }
3200        }
3201    }
3202
3203    #[inline]
3204    fn fill_amplitude_values_and_gradients(
3205        &self,
3206        amplitude_values: &mut [Complex64],
3207        gradient_values: &mut [DVector<Complex64>],
3208        active_indices: &[usize],
3209        active_mask: &[bool],
3210        parameters: &Parameters,
3211        cache: &Cache,
3212    ) {
3213        self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
3214        self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
3215    }
3216
3217    #[doc(hidden)]
3218    pub fn fill_amplitude_values_and_gradients_public(
3219        &self,
3220        amplitude_values: &mut [Complex64],
3221        gradient_values: &mut [DVector<Complex64>],
3222        active_indices: &[usize],
3223        active_mask: &[bool],
3224        parameters: &Parameters,
3225        cache: &Cache,
3226    ) {
3227        self.fill_amplitude_values_and_gradients(
3228            amplitude_values,
3229            gradient_values,
3230            active_indices,
3231            active_mask,
3232            parameters,
3233            cache,
3234        );
3235    }
3236
3237    #[cfg(feature = "execution-context-prototype")]
3238    #[inline]
3239    fn evaluate_cache_gradient_with_scratch(
3240        &self,
3241        amplitude_values: &mut [Complex64],
3242        gradient_values: &mut [DVector<Complex64>],
3243        value_slots: &mut [Complex64],
3244        gradient_slots: &mut [DVector<Complex64>],
3245        active_indices: &[usize],
3246        active_mask: &[bool],
3247        parameters: &Parameters,
3248        cache: &Cache,
3249    ) -> DVector<Complex64> {
3250        self.fill_amplitude_values_and_gradients(
3251            amplitude_values,
3252            gradient_values,
3253            active_indices,
3254            active_mask,
3255            parameters,
3256            cache,
3257        );
3258        self.evaluate_expression_gradient_with_scratch(
3259            amplitude_values,
3260            gradient_values,
3261            value_slots,
3262            gradient_slots,
3263        )
3264    }
3265
3266    #[cfg(feature = "execution-context-prototype")]
3267    #[allow(dead_code)]
3268    #[inline]
3269    fn evaluate_cache_value_gradient_with_scratch(
3270        &self,
3271        amplitude_values: &mut [Complex64],
3272        gradient_values: &mut [DVector<Complex64>],
3273        value_slots: &mut [Complex64],
3274        gradient_slots: &mut [DVector<Complex64>],
3275        active_indices: &[usize],
3276        active_mask: &[bool],
3277        parameters: &Parameters,
3278        cache: &Cache,
3279    ) -> (Complex64, DVector<Complex64>) {
3280        self.fill_amplitude_values_and_gradients(
3281            amplitude_values,
3282            gradient_values,
3283            active_indices,
3284            active_mask,
3285            parameters,
3286            cache,
3287        );
3288        self.evaluate_expression_value_gradient_with_scratch(
3289            amplitude_values,
3290            gradient_values,
3291            value_slots,
3292            gradient_slots,
3293        )
3294    }
3295
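    /// Maximum scratch-slot count required across the lowered value, gradient, and
    /// value+gradient programs for the active specialization.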
3296    pub fn expression_slot_count(&self) -> usize {
3297        self.lowered_runtime_slot_count()
3298    }
3299    fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
3300        let amplitude_dependencies = self
3301            .amplitudes
3302            .iter()
3303            .map(|amp| ir::DependenceClass::from(amp.dependence_hint()))
3304            .collect::<Vec<_>>();
3305        let amplitude_realness = self
3306            .amplitudes
3307            .iter()
3308            .map(|amp| amp.real_valued_hint())
3309            .collect::<Vec<_>>();
3310        ir::compile_expression_ir_with_real_hints(
3311            &self.expression,
3312            active_mask,
3313            &amplitude_dependencies,
3314            &amplitude_realness,
3315        )
3316    }
3317    fn lower_expression_runtime_for_active_mask(
3318        &self,
3319        active_mask: &[bool],
3320    ) -> LadduResult<lowered::LoweredExpressionRuntime> {
3321        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
3322        lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
3323            LadduError::Custom(format!(
3324                "Failed to lower active-mask runtime specialization: {error:?}"
3325            ))
3326        })
3327    }
3328    fn ensure_cached_integral_cache_state(
3329        &self,
3330        resources: &Resources,
3331    ) -> LadduResult<Arc<CachedIntegralCacheState>> {
3332        Ok(self
3333            .ensure_expression_specialization(resources)?
3334            .cached_integrals)
3335    }
3336
3337    fn evaluate_expression_runtime_value_with_scratch(
3338        &self,
3339        amplitude_values: &[Complex64],
3340        scratch: &mut [Complex64],
3341    ) -> Complex64 {
3342        let lowered_runtime = self.lowered_runtime();
3343        lowered_runtime
3344            .value_program()
3345            .evaluate_into(amplitude_values, scratch)
3346    }
3347
3348    #[doc(hidden)]
3349    pub fn evaluate_expression_value_with_program_snapshot(
3350        &self,
3351        program_snapshot: &ExpressionValueProgramSnapshot,
3352        amplitude_values: &[Complex64],
3353        scratch: &mut [Complex64],
3354    ) -> Complex64 {
3355        program_snapshot
3356            .lowered_program
3357            .evaluate_into(amplitude_values, scratch)
3358    }
3359
3360    fn evaluate_expression_runtime_gradient_with_scratch(
3361        &self,
3362        amplitude_values: &[Complex64],
3363        gradient_values: &[DVector<Complex64>],
3364        value_scratch: &mut [Complex64],
3365        gradient_scratch: &mut [DVector<Complex64>],
3366    ) -> DVector<Complex64> {
3367        let lowered_runtime = self.lowered_runtime();
3368        lowered_runtime.gradient_program().evaluate_gradient_into(
3369            amplitude_values,
3370            gradient_values,
3371            value_scratch,
3372            gradient_scratch,
3373        )
3374    }
3375
3376    fn evaluate_expression_runtime_value_gradient_with_scratch(
3377        &self,
3378        amplitude_values: &[Complex64],
3379        gradient_values: &[DVector<Complex64>],
3380        value_scratch: &mut [Complex64],
3381        gradient_scratch: &mut [DVector<Complex64>],
3382    ) -> (Complex64, DVector<Complex64>) {
3383        let lowered_runtime = self.lowered_runtime();
3384        lowered_runtime
3385            .value_gradient_program()
3386            .evaluate_value_gradient_into(
3387                amplitude_values,
3388                gradient_values,
3389                value_scratch,
3390                gradient_scratch,
3391            )
3392    }
3393
3394    fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
3395        let lowered_runtime = self.lowered_runtime();
3396        let program = lowered_runtime.value_program();
3397        let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
3398        program.evaluate_into(amplitude_values, &mut scratch)
3399    }
3400
3401    fn evaluate_expression_runtime_gradient(
3402        &self,
3403        amplitude_values: &[Complex64],
3404        gradient_values: &[DVector<Complex64>],
3405    ) -> DVector<Complex64> {
3406        let lowered_runtime = self.lowered_runtime();
3407        let program = lowered_runtime.gradient_program();
3408        let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
3409        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
3410        let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
3411        program.evaluate_gradient_into_flat(
3412            amplitude_values,
3413            gradient_values,
3414            &mut value_scratch,
3415            &mut gradient_scratch,
3416            grad_dim,
3417        )
3418    }
3419    /// Dependence classification for the compiled expression root.
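    ///
    /// A minimal usage sketch (not from the original source); `evaluator` is assumed to exist.
    ///
    /// ```ignore
    /// match evaluator.expression_root_dependence()? {
    ///     ExpressionDependence::ParameterOnly => println!("root depends only on parameters"),
    ///     ExpressionDependence::CacheOnly => println!("root depends only on cached event values"),
    ///     ExpressionDependence::Mixed => println!("root mixes parameter and cache dependence"),
    /// }
    /// ```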
3420    pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
3421        let resources = self.resources.read();
3422        Ok(self
3423            .ensure_cached_integral_cache_state(&resources)?
3424            .expression_ir
3425            .root_dependence()
3426            .into())
3427    }
3428    /// Dependence classification for each compiled expression node.
3429    pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
3430        let resources = self.resources.read();
3431        Ok(self
3432            .ensure_cached_integral_cache_state(&resources)?
3433            .expression_ir
3434            .node_dependence_annotations()
3435            .iter()
3436            .copied()
3437            .map(Into::into)
3438            .collect())
3439    }
3440    /// Warning-level diagnostics for potentially inconsistent dependence hints.
3441    pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
3442        let resources = self.resources.read();
3443        Ok(self
3444            .ensure_cached_integral_cache_state(&resources)?
3445            .expression_ir
3446            .dependence_warnings()
3447            .to_vec())
3448    }
3449    /// Explain/debug view of IR normalization planning decomposition.
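    ///
    /// A minimal usage sketch (not from the original source); `evaluator` is assumed to exist.
    ///
    /// ```ignore
    /// let plan = evaluator.expression_normalization_plan_explain()?;
    /// for warning in &plan.warnings {
    ///     eprintln!("normalization planning warning: {warning}");
    /// }
    /// println!(
    ///     "{} separable terms cached, {} residual terms",
    ///     plan.cached_separable_nodes.len(),
    ///     plan.residual_terms.len()
    /// );
    /// ```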
3450    pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
3451        let resources = self.resources.read();
3452        Ok(self
3453            .ensure_cached_integral_cache_state(&resources)?
3454            .expression_ir
3455            .normalization_plan_explain()
3456            .into())
3457    }
3458    /// Explain/debug view of amplitude execution sets used by normalization evaluation.
3459    pub fn expression_normalization_execution_sets(
3460        &self,
3461    ) -> LadduResult<NormalizationExecutionSetsExplain> {
3462        let resources = self.resources.read();
3463        Ok(self
3464            .ensure_cached_integral_cache_state(&resources)?
3465            .execution_sets
3466            .clone()
3467            .into())
3468    }
3469    /// Cached integral terms precomputed at evaluator load.
3470    pub fn expression_precomputed_cached_integrals(
3471        &self,
3472    ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
3473        let resources = self.resources.read();
3474        Ok(self
3475            .ensure_cached_integral_cache_state(&resources)?
3476            .values
3477            .clone())
3478    }
3479    /// Derivative rules for cached separable terms evaluated at the given parameter point.
3480    ///
3481    /// Each returned term corresponds to a cached separable descriptor and contributes
3482    /// `weighted_gradient` to `d(normalization)/dp` prior to residual-term combination.
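    ///
    /// A minimal aggregation sketch (not from the original source); `evaluator` and `parameters`
    /// are assumed to exist, and only the real part of each contribution is accumulated, matching
    /// the cached-term sums used internally.
    ///
    /// ```ignore
    /// let terms = evaluator.expression_precomputed_cached_integral_gradient_terms(&parameters)?;
    /// let mut cached_gradient = nalgebra::DVector::<f64>::zeros(parameters.len());
    /// for term in &terms {
    ///     // `weighted_gradient` already folds in the descriptor coefficient and cache sum.
    ///     cached_gradient += term.weighted_gradient.map(|value| value.re);
    /// }
    /// ```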
3483    pub fn expression_precomputed_cached_integral_gradient_terms(
3484        &self,
3485        parameters: &[f64],
3486    ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
3487        let resources = self.resources.read();
3488        let state = self.ensure_cached_integral_cache_state(&resources)?;
3489        if state.values.is_empty() {
3490            return Ok(Vec::new());
3491        }
3492
3493        let Some(cache) = resources.caches.first() else {
3494            return Ok(state
3495                .values
3496                .iter()
3497                .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
3498                    mul_node_index: descriptor.mul_node_index,
3499                    parameter_node_index: descriptor.parameter_node_index,
3500                    cache_node_index: descriptor.cache_node_index,
3501                    coefficient: descriptor.coefficient,
3502                    weighted_gradient: DVector::zeros(parameters.len()),
3503                })
3504                .collect());
3505        };
3506
3507        let parameter_values = resources.parameter_map.assemble(parameters)?;
3508        let mut amplitude_values = vec![Complex64::ZERO; self.amplitudes.len()];
3509        self.fill_amplitude_values(
3510            &mut amplitude_values,
3511            resources.active_indices(),
3512            &parameter_values,
3513            cache,
3514        );
3515        let mut amplitude_gradients = (0..self.amplitudes.len())
3516            .map(|_| DVector::zeros(parameters.len()))
3517            .collect::<Vec<_>>();
3518        self.fill_amplitude_gradients(
3519            &mut amplitude_gradients,
3520            &resources.active,
3521            &parameter_values,
3522            cache,
3523        );
3524        let lowered_artifacts = self.active_lowered_artifacts();
3525        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
3526        let mut gradient_slots = (0..state.expression_ir.node_count())
3527            .map(|_| DVector::zeros(parameters.len()))
3528            .collect::<Vec<_>>();
3529        let max_lowered_slots = lowered_artifacts
3530            .as_ref()
3531            .map(|artifacts| {
3532                artifacts
3533                    .lowered_parameter_factors
3534                    .iter()
3535                    .filter_map(|runtime| {
3536                        runtime
3537                            .as_ref()
3538                            .and_then(|runtime| runtime.gradient_program())
3539                            .map(|program| program.scratch_slots())
3540                    })
3541                    .max()
3542                    .unwrap_or(0)
3543            })
3544            .unwrap_or(0);
3545        let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
3546        let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
3547        let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
3548            artifacts.lowered_parameter_factors.len() == state.values.len()
3549                && artifacts.lowered_parameter_factors.iter().all(|runtime| {
3550                    runtime
3551                        .as_ref()
3552                        .and_then(|runtime| runtime.gradient_program())
3553                        .is_some()
3554                })
3555        });
3556
3557        if !use_lowered {
3558            let _ = state.expression_ir.evaluate_gradient_into(
3559                &amplitude_values,
3560                &amplitude_gradients,
3561                &mut value_slots,
3562                &mut gradient_slots,
3563            );
3564        }
3565
3566        if use_lowered {
3567            let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
3568            Ok(state
3569                .values
3570                .iter()
3571                .cloned()
3572                .zip(lowered_artifacts.lowered_parameter_factors.iter())
3573                .map(|(descriptor, runtime)| {
3574                    let parameter_gradient = runtime
3575                        .as_ref()
3576                        .and_then(|runtime| runtime.gradient_program())
3577                        .map(|program| {
3578                            program.evaluate_gradient_into(
3579                                &amplitude_values,
3580                                &amplitude_gradients,
3581                                &mut lowered_value_slots[..program.scratch_slots()],
3582                                &mut lowered_gradient_slots[..program.scratch_slots()],
3583                            )
3584                        })
3585                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
3586                    let weighted_gradient = parameter_gradient.map(|value| {
3587                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
3588                    });
3589                    PrecomputedCachedIntegralGradientTerm {
3590                        mul_node_index: descriptor.mul_node_index,
3591                        parameter_node_index: descriptor.parameter_node_index,
3592                        cache_node_index: descriptor.cache_node_index,
3593                        coefficient: descriptor.coefficient,
3594                        weighted_gradient,
3595                    }
3596                })
3597                .collect())
3598        } else {
3599            Ok(state
3600                .values
3601                .iter()
3602                .map(|descriptor| {
3603                    let parameter_gradient = gradient_slots
3604                        .get(descriptor.parameter_node_index)
3605                        .cloned()
3606                        .unwrap_or_else(|| DVector::zeros(parameters.len()));
3607                    let weighted_gradient = parameter_gradient.map(|value| {
3608                        value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
3609                    });
3610                    PrecomputedCachedIntegralGradientTerm {
3611                        mul_node_index: descriptor.mul_node_index,
3612                        parameter_node_index: descriptor.parameter_node_index,
3613                        cache_node_index: descriptor.cache_node_index,
3614                        coefficient: descriptor.coefficient,
3615                        weighted_gradient,
3616                    }
3617                })
3618                .collect())
3619        }
3620    }
3621    fn evaluate_cached_weighted_value_sum_ir(
3622        &self,
3623        state: &CachedIntegralCacheState,
3624        amplitude_values: &[Complex64],
3625    ) -> f64 {
3626        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
3627        let _ = state
3628            .expression_ir
3629            .evaluate_into(amplitude_values, &mut value_slots);
3630        state
3631            .values
3632            .iter()
3633            .map(|descriptor| {
3634                let parameter_factor = value_slots[descriptor.parameter_node_index];
3635                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
3636                    .re
3637            })
3638            .sum()
3639    }
3640    fn evaluate_cached_weighted_value_sum_lowered(
3641        &self,
3642        state: &CachedIntegralCacheState,
3643        lowered_artifacts: &LoweredArtifactCacheState,
3644        amplitude_values: &[Complex64],
3645    ) -> Option<f64> {
3646        let max_slots = lowered_artifacts
3647            .lowered_parameter_factors
3648            .iter()
3649            .filter_map(|runtime| {
3650                runtime
3651                    .as_ref()
3652                    .and_then(|runtime| runtime.value_program())
3653                    .map(|program| program.scratch_slots())
3654            })
3655            .max()
3656            .unwrap_or(0);
3657        let mut value_slots = vec![Complex64::ZERO; max_slots];
3658        let mut total = 0.0;
3659        for (descriptor, runtime) in state
3660            .values
3661            .iter()
3662            .zip(lowered_artifacts.lowered_parameter_factors.iter())
3663        {
3664            let parameter_factor = runtime
3665                .as_ref()
3666                .and_then(|runtime| runtime.value_program())
3667                .map(|program| {
3668                    program.evaluate_into(
3669                        amplitude_values,
3670                        &mut value_slots[..program.scratch_slots()],
3671                    )
3672                })?;
3673            total +=
3674                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
3675                    .re;
3676        }
3677        Some(total)
3678    }
3679    fn evaluate_cached_weighted_gradient_sum_ir(
3680        &self,
3681        state: &CachedIntegralCacheState,
3682        amplitude_values: &[Complex64],
3683        amplitude_gradients: &[DVector<Complex64>],
3684        grad_dim: usize,
3685    ) -> DVector<f64> {
3686        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
3687        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
3688        let _ = state.expression_ir.evaluate_gradient_into(
3689            amplitude_values,
3690            amplitude_gradients,
3691            &mut value_slots,
3692            &mut gradient_slots,
3693        );
3694        state
3695            .values
3696            .iter()
3697            .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
3698                let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
3699                let coefficient = descriptor.coefficient as f64;
3700                for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
3701                    *accum_item +=
3702                        (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
3703                }
3704                accum
3705            })
3706    }
3707    fn evaluate_cached_weighted_gradient_sum_lowered(
3708        &self,
3709        state: &CachedIntegralCacheState,
3710        lowered_artifacts: &LoweredArtifactCacheState,
3711        amplitude_values: &[Complex64],
3712        amplitude_gradients: &[DVector<Complex64>],
3713        grad_dim: usize,
3714    ) -> Option<DVector<f64>> {
3715        let max_value_slots = lowered_artifacts
3716            .lowered_parameter_factors
3717            .iter()
3718            .filter_map(|runtime| {
3719                runtime
3720                    .as_ref()
3721                    .and_then(|runtime| runtime.gradient_program())
3722                    .map(|program| program.scratch_slots())
3723            })
3724            .max()
3725            .unwrap_or(0);
3726        let mut value_slots = vec![Complex64::ZERO; max_value_slots];
3727        let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
3728        let mut total = DVector::zeros(grad_dim);
3729        for (descriptor, runtime) in state
3730            .values
3731            .iter()
3732            .zip(lowered_artifacts.lowered_parameter_factors.iter())
3733        {
3734            let parameter_gradient = runtime
3735                .as_ref()
3736                .and_then(|runtime| runtime.gradient_program())
3737                .map(|program| {
3738                    program.evaluate_gradient_into_flat(
3739                        amplitude_values,
3740                        amplitude_gradients,
3741                        &mut value_slots[..program.scratch_slots()],
3742                        &mut gradient_slots[..program.scratch_slots() * grad_dim],
3743                        grad_dim,
3744                    )
3745                })?;
3746            let coefficient = descriptor.coefficient as f64;
3747            for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
3748                *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
3749            }
3750        }
3751        Some(total)
3752    }
3753    fn evaluate_residual_value_ir(
3754        &self,
3755        state: &CachedIntegralCacheState,
3756        amplitude_values: &[Complex64],
3757    ) -> Complex64 {
3758        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
3759        for descriptor in &state.values {
3760            if descriptor.mul_node_index < zeroed_nodes.len() {
3761                zeroed_nodes[descriptor.mul_node_index] = true;
3762            }
3763        }
3764        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
3765        state.expression_ir.evaluate_into_with_zeroed_nodes(
3766            amplitude_values,
3767            &mut value_slots,
3768            &zeroed_nodes,
3769        )
3770    }
3771    fn evaluate_residual_gradient_ir(
3772        &self,
3773        state: &CachedIntegralCacheState,
3774        amplitude_values: &[Complex64],
3775        amplitude_gradients: &[DVector<Complex64>],
3776        grad_dim: usize,
3777    ) -> DVector<Complex64> {
3778        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
3779        for descriptor in &state.values {
3780            if descriptor.mul_node_index < zeroed_nodes.len() {
3781                zeroed_nodes[descriptor.mul_node_index] = true;
3782            }
3783        }
3784        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
3785        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
3786        state
3787            .expression_ir
3788            .evaluate_gradient_into_with_zeroed_nodes(
3789                amplitude_values,
3790                amplitude_gradients,
3791                &mut value_slots,
3792                &mut gradient_slots,
3793                &zeroed_nodes,
3794            )
3795    }
3796
3797    fn evaluate_weighted_value_sum_local_components(
3798        &self,
3799        parameters: &[f64],
3800    ) -> LadduResult<(f64, f64)> {
3801        let resources = self.resources.read();
3802        let parameters = resources.parameter_map.assemble(parameters)?;
3803        let amplitude_len = self.amplitudes.len();
3804        let state = self.ensure_cached_integral_cache_state(&resources)?;
3805        let lowered_artifacts = self.active_lowered_artifacts();
3806        let residual_value_slot_count = lowered_artifacts
3807            .as_ref()
3808            .and_then(|artifacts| {
3809                artifacts
3810                    .residual_runtime
3811                    .as_ref()
3812                    .map(|runtime| runtime.value_program())
3813                    .map(|program| program.scratch_slots())
3814            })
3815            .unwrap_or_else(|| self.expression_slot_count());
3816        let residual_value_program = lowered_artifacts
3817            .as_ref()
3818            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
3819            .map(|runtime| runtime.value_program());
3820        let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
3821        let residual_active_indices = &state.execution_sets.residual_amplitudes;
3822        debug_assert!(cached_parameter_indices.iter().all(|&index| resources
3823            .active
3824            .get(index)
3825            .copied()
3826            .unwrap_or(false)));
3827        debug_assert!(residual_active_indices.iter().all(|&index| resources
3828            .active
3829            .get(index)
3830            .copied()
3831            .unwrap_or(false)));
3832        let cached_value_sum = {
3833            if let Some(cache) = resources.caches.first() {
3834                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3835                self.fill_amplitude_values(
3836                    &mut amplitude_values,
3837                    cached_parameter_indices,
3838                    &parameters,
3839                    cache,
3840                );
3841                lowered_artifacts
3842                    .as_ref()
3843                    .and_then(|artifacts| {
3844                        self.evaluate_cached_weighted_value_sum_lowered(
3845                            &state,
3846                            artifacts,
3847                            &amplitude_values,
3848                        )
3849                    })
3850                    .unwrap_or_else(|| {
3851                        self.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values)
3852                    })
3853            } else {
3854                0.0
3855            }
3856        };
3857
3858        #[cfg(feature = "rayon")]
3859        let residual_sum: f64 = {
3860            resources
3861                .caches
3862                .par_iter()
3863                .zip(self.dataset.events_local().par_iter())
3864                .map_init(
3865                    || {
3866                        (
3867                            vec![Complex64::ZERO; amplitude_len],
3868                            vec![Complex64::ZERO; residual_value_slot_count],
3869                        )
3870                    },
3871                    |(amplitude_values, value_slots), (cache, event)| {
3872                        self.fill_amplitude_values(
3873                            amplitude_values,
3874                            residual_active_indices,
3875                            &parameters,
3876                            cache,
3877                        );
3878                        {
3879                            let value = residual_value_program
3880                                .as_ref()
3881                                .map(|program| {
3882                                    program.evaluate_into(
3883                                        amplitude_values,
3884                                        &mut value_slots[..program.scratch_slots()],
3885                                    )
3886                                })
3887                                .unwrap_or_else(|| {
3888                                    self.evaluate_residual_value_ir(&state, amplitude_values)
3889                                });
3890                            event.weight * value.re
3891                        }
3892                    },
3893                )
3894                .sum()
3895        };
3896
3897        #[cfg(not(feature = "rayon"))]
3898        let residual_sum: f64 = {
3899            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3900            let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
3901            resources
3902                .caches
3903                .iter()
3904                .zip(self.dataset.events_local().iter())
3905                .map(|(cache, event)| {
3906                    self.fill_amplitude_values(
3907                        &mut amplitude_values,
3908                        &residual_active_indices,
3909                        &parameters,
3910                        cache,
3911                    );
3912                    {
3913                        let value = residual_value_program
3914                            .as_ref()
3915                            .map(|program| {
3916                                program.evaluate_into(
3917                                    &amplitude_values,
3918                                    &mut value_slots[..program.scratch_slots()],
3919                                )
3920                            })
3921                            .unwrap_or_else(|| {
3922                                self.evaluate_residual_value_ir(&state, &amplitude_values)
3923                            });
3924                        event.weight * value.re
3925                    }
3926                })
3927                .sum()
3928        };
3929        Ok((residual_sum, cached_value_sum))
3930    }
3931
3932    /// Weighted sum over local events of the real expression value.
3933    ///
3934    /// This returns `sum_e(weight_e * Re(L_e))`.
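    ///
    /// A minimal usage sketch (not from the original source); `evaluator` is assumed to exist and
    /// `parameters` holds one value per free parameter.
    ///
    /// ```ignore
    /// let weighted_sum = evaluator.evaluate_weighted_value_sum_local(&parameters)?;
    /// println!("sum_e(weight_e * Re(L_e)) = {weighted_sum}");
    /// ```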
3935    pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
3936        let (residual_sum, cached_value_sum) =
3937            self.evaluate_weighted_value_sum_local_components(parameters)?;
3938        Ok(residual_sum + cached_value_sum)
3939    }
3940
3941    #[cfg(feature = "mpi")]
3942    /// Weighted sum over all ranks of the real expression value.
3943    ///
3944    /// This returns `sum_{r,e}(weight_{r,e} * Re(L_{r,e}))`.
3945    pub fn evaluate_weighted_value_sum_mpi(
3946        &self,
3947        parameters: &[f64],
3948        world: &SimpleCommunicator,
3949    ) -> LadduResult<f64> {
3950        let (residual_sum_local, cached_value_sum_local) =
3951            self.evaluate_weighted_value_sum_local_components(parameters)?;
3952        let mut residual_sum = 0.0;
3953        world.all_reduce_into(
3954            &residual_sum_local,
3955            &mut residual_sum,
3956            mpi::collective::SystemOperation::sum(),
3957        );
3958        let mut cached_value_sum = 0.0;
3959        world.all_reduce_into(
3960            &cached_value_sum_local,
3961            &mut cached_value_sum,
3962            mpi::collective::SystemOperation::sum(),
3963        );
3964        Ok(residual_sum + cached_value_sum)
3965    }
3966
3967    /// Residual and cached-integral components of the weighted gradient sum over local events.
3968    ///
3969    /// The two components together sum to `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
3970    fn evaluate_weighted_gradient_sum_local_components(
3971        &self,
3972        parameters: &[f64],
3973    ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
3974        let resources = self.resources.read();
3975        let parameters = resources.parameter_map.assemble(parameters)?;
3976        let amplitude_len = self.amplitudes.len();
3977        let grad_dim = parameters.len();
3978        let state = self.ensure_cached_integral_cache_state(&resources)?;
3979        let lowered_artifacts = self.active_lowered_artifacts();
3980        let active_index_set = resources.active_indices();
3981        let cached_parameter_indices = state
3982            .execution_sets
3983            .cached_parameter_amplitudes
3984            .iter()
3985            .copied()
3986            .filter(|index| active_index_set.binary_search(index).is_ok())
3987            .collect::<Vec<_>>();
3988        let residual_active_indices = state
3989            .execution_sets
3990            .residual_amplitudes
3991            .iter()
3992            .copied()
3993            .filter(|index| active_index_set.binary_search(index).is_ok())
3994            .collect::<Vec<_>>();
3995        let mut cached_parameter_mask = vec![false; amplitude_len];
3996        for &index in &cached_parameter_indices {
3997            cached_parameter_mask[index] = true;
3998        }
3999        let mut residual_active_mask = vec![false; amplitude_len];
4000        for &index in &residual_active_indices {
4001            residual_active_mask[index] = true;
4002        }
4003        let residual_gradient_program = lowered_artifacts
4004            .as_ref()
4005            .and_then(|artifacts| artifacts.residual_runtime.as_ref())
4006            .map(|runtime| runtime.gradient_program());
4007        let residual_gradient_slot_count = residual_gradient_program
4008            .as_ref()
4009            .map(|program| program.scratch_slots())
4010            .unwrap_or_else(|| state.expression_ir.node_count());
4011        let cached_term_sum = {
4012            if let Some(cache) = resources.caches.first() {
4013                let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4014                self.fill_amplitude_values(
4015                    &mut amplitude_values,
4016                    &cached_parameter_indices,
4017                    &parameters,
4018                    cache,
4019                );
4020                let mut amplitude_gradients = (0..amplitude_len)
4021                    .map(|_| DVector::zeros(grad_dim))
4022                    .collect::<Vec<_>>();
4023                self.fill_amplitude_gradients(
4024                    &mut amplitude_gradients,
4025                    &cached_parameter_mask,
4026                    &parameters,
4027                    cache,
4028                );
4029                lowered_artifacts
4030                    .as_ref()
4031                    .and_then(|artifacts| {
4032                        self.evaluate_cached_weighted_gradient_sum_lowered(
4033                            &state,
4034                            artifacts,
4035                            &amplitude_values,
4036                            &amplitude_gradients,
4037                            grad_dim,
4038                        )
4039                    })
4040                    .unwrap_or_else(|| {
4041                        self.evaluate_cached_weighted_gradient_sum_ir(
4042                            &state,
4043                            &amplitude_values,
4044                            &amplitude_gradients,
4045                            grad_dim,
4046                        )
4047                    })
4048            } else {
4049                DVector::zeros(grad_dim)
4050            }
4051        };
4052
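        // Residual (non-separable) terms still require a full pass over every local event;
        // the rayon and single-threaded paths below compute the same weighted sum.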
4053        #[cfg(feature = "rayon")]
4054        let residual_sum = {
4055            resources
4056                .caches
4057                .par_iter()
4058                .zip(self.dataset.events_local().par_iter())
4059                .map_init(
4060                    || {
4061                        (
4062                            vec![Complex64::ZERO; amplitude_len],
4063                            vec![DVector::zeros(grad_dim); amplitude_len],
4064                            vec![Complex64::ZERO; residual_gradient_slot_count],
4065                            vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
4066                        )
4067                    },
4068                    |(amplitude_values, gradient_values, value_slots, gradient_slots),
4069                     (cache, event)| {
4070                        self.fill_amplitude_values_and_gradients(
4071                            amplitude_values,
4072                            gradient_values,
4073                            &residual_active_indices,
4074                            &residual_active_mask,
4075                            &parameters,
4076                            cache,
4077                        );
4078                        let gradient = residual_gradient_program
4079                            .as_ref()
4080                            .map(|program| {
4081                                program.evaluate_gradient_into_flat(
4082                                    amplitude_values,
4083                                    gradient_values,
4084                                    value_slots,
4085                                    gradient_slots,
4086                                    grad_dim,
4087                                )
4088                            })
4089                            .unwrap_or_else(|| {
4090                                self.evaluate_residual_gradient_ir(
4091                                    &state,
4092                                    amplitude_values,
4093                                    gradient_values,
4094                                    grad_dim,
4095                                )
4096                            });
4097                        gradient.map(|value| value.re).scale(event.weight)
4098                    },
4099                )
4100                .reduce(
4101                    || DVector::zeros(grad_dim),
4102                    |mut accum, value| {
4103                        accum += value;
4104                        accum
4105                    },
4106                )
4107        };
4108
4109        #[cfg(not(feature = "rayon"))]
4110        let residual_sum = {
4111            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4112            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4113            let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
4114            let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
4115            resources
4116                .caches
4117                .iter()
4118                .zip(self.dataset.events_local().iter())
4119                .map(|(cache, event)| {
4120                    self.fill_amplitude_values_and_gradients(
4121                        &mut amplitude_values,
4122                        &mut gradient_values,
4123                        &residual_active_indices,
4124                        &residual_active_mask,
4125                        &parameters,
4126                        cache,
4127                    );
4128                    let gradient = residual_gradient_program
4129                        .as_ref()
4130                        .map(|program| {
4131                            program.evaluate_gradient_into_flat(
4132                                &amplitude_values,
4133                                &gradient_values,
4134                                &mut value_slots,
4135                                &mut gradient_slots,
4136                                grad_dim,
4137                            )
4138                        })
4139                        .unwrap_or_else(|| {
4140                            self.evaluate_residual_gradient_ir(
4141                                &state,
4142                                &amplitude_values,
4143                                &gradient_values,
4144                                grad_dim,
4145                            )
4146                        });
4147                    gradient.map(|value| value.re).scale(event.weight)
4148                })
4149                .sum()
4150        };
4151        Ok((residual_sum, cached_term_sum))
4152    }
4153
4154    /// Weighted sum over local events of the real gradient of the expression.
4155    ///
4156    /// This returns `sum_e(weight_e * Re(dL_e/dp))` for all free parameters.
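    ///
    /// A minimal usage sketch (illustrative only; assumes an `Evaluator` named `evaluator`
    /// and a `parameter_values` slice prepared elsewhere):
    ///
    /// ```ignore
    /// let grad = evaluator.evaluate_weighted_gradient_sum_local(&parameter_values)?;
    /// // `grad` holds one real entry per parameter of the weighted gradient sum.
    /// println!("weighted gradient sum: {grad}");
    /// ```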
4157    pub fn evaluate_weighted_gradient_sum_local(
4158        &self,
4159        parameters: &[f64],
4160    ) -> LadduResult<DVector<f64>> {
4161        let (residual_sum, cached_term_sum) =
4162            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
4163        Ok(residual_sum + cached_term_sum)
4164    }
4165
4166    #[cfg(feature = "mpi")]
4167    /// Weighted sum over all ranks of the real gradient of the expression.
4168    ///
4169    /// This returns `sum_{r,e}(weight_{r,e} * Re(dL_{r,e}/dp))`.
4170    pub fn evaluate_weighted_gradient_sum_mpi(
4171        &self,
4172        parameters: &[f64],
4173        world: &SimpleCommunicator,
4174    ) -> LadduResult<DVector<f64>> {
4175        let (residual_sum_local, cached_term_sum_local) =
4176            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
4177        let mut residual_sum = vec![0.0; residual_sum_local.len()];
4178        world.all_reduce_into(
4179            residual_sum_local.as_slice(),
4180            &mut residual_sum,
4181            mpi::collective::SystemOperation::sum(),
4182        );
4183        let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
4184        world.all_reduce_into(
4185            cached_term_sum_local.as_slice(),
4186            &mut cached_term_sum,
4187            mpi::collective::SystemOperation::sum(),
4188        );
4189        let mut total = DVector::from_vec(residual_sum);
4190        total += DVector::from_vec(cached_term_sum);
4191        Ok(total)
4192    }
4193
    /// Evaluate the stored expression for one event's amplitude values, reusing caller-provided scratch storage.
4194    pub fn evaluate_expression_value_with_scratch(
4195        &self,
4196        amplitude_values: &[Complex64],
4197        scratch: &mut [Complex64],
4198    ) -> Complex64 {
4199        self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
4200    }
4201
    /// Evaluate the gradient of the stored expression for one event's amplitude values and gradients, reusing caller-provided scratch storage.
4202    pub fn evaluate_expression_gradient_with_scratch(
4203        &self,
4204        amplitude_values: &[Complex64],
4205        gradient_values: &[DVector<Complex64>],
4206        value_scratch: &mut [Complex64],
4207        gradient_scratch: &mut [DVector<Complex64>],
4208    ) -> DVector<Complex64> {
4209        self.evaluate_expression_runtime_gradient_with_scratch(
4210            amplitude_values,
4211            gradient_values,
4212            value_scratch,
4213            gradient_scratch,
4214        )
4215    }
4216
    /// Evaluate the stored expression and its gradient in one pass for one event's amplitude values and gradients, reusing caller-provided scratch storage.
4217    pub fn evaluate_expression_value_gradient_with_scratch(
4218        &self,
4219        amplitude_values: &[Complex64],
4220        gradient_values: &[DVector<Complex64>],
4221        value_scratch: &mut [Complex64],
4222        gradient_scratch: &mut [DVector<Complex64>],
4223    ) -> (Complex64, DVector<Complex64>) {
4224        self.evaluate_expression_runtime_value_gradient_with_scratch(
4225            amplitude_values,
4226            gradient_values,
4227            value_scratch,
4228            gradient_scratch,
4229        )
4230    }
4231
    /// Evaluate the stored expression for one event's amplitude values.
4232    pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
4233        self.evaluate_expression_runtime_value(amplitude_values)
4234    }
4235
    /// Evaluate the gradient of the stored expression for one event's amplitude values and gradients.
4236    pub fn evaluate_expression_gradient(
4237        &self,
4238        amplitude_values: &[Complex64],
4239        gradient_values: &[DVector<Complex64>],
4240    ) -> DVector<Complex64> {
4241        self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
4242    }
4243
4244    /// Get the list of parameter names in the order they appear in the [`Evaluator::evaluate`]
4245    /// method.
4246    pub fn parameters(&self) -> Vec<String> {
4247        self.resources.read().parameter_names()
4248    }
4249
4250    /// Get the list of free parameter names.
4251    pub fn free_parameters(&self) -> Vec<String> {
4252        self.resources.read().free_parameter_names()
4253    }
4254
4255    /// Get the list of fixed parameter names.
4256    pub fn fixed_parameters(&self) -> Vec<String> {
4257        self.resources.read().fixed_parameter_names()
4258    }
4259
4260    /// Number of free parameters.
4261    pub fn n_free(&self) -> usize {
4262        self.resources.read().n_free_parameters()
4263    }
4264
4265    /// Number of fixed parameters.
4266    pub fn n_fixed(&self) -> usize {
4267        self.resources.read().n_fixed_parameters()
4268    }
4269
4270    /// Total number of parameters.
4271    pub fn n_parameters(&self) -> usize {
4272        self.resources.read().n_parameters()
4273    }
4274
    /// Fix the named parameter to the given value.
4275    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
4276        self.resources.read().fix_parameter(name, value)
4277    }
4278
    /// Free the named parameter so it can vary again.
4279    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
4280        self.resources.read().free_parameter(name)
4281    }
4282
    /// Rename a parameter from `old` to `new`.
4283    pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
4284        self.resources.write().rename_parameter(old, new)
4285    }
4286
    /// Rename several parameters using an old-name to new-name mapping.
4287    pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
4288        self.resources.write().rename_parameters(mapping)
4289    }
4290
4291    /// Activate an [`Amplitude`] by name, skipping missing entries.
4292    pub fn activate<T: AsRef<str>>(&self, name: T) {
4293        self.resources.write().activate(name);
4294        self.refresh_runtime_specializations();
4295    }
4296    /// Activate an [`Amplitude`] by name and return an error if it is missing.
4297    pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
4298        self.resources.write().activate_strict(name)?;
4299        self.refresh_runtime_specializations();
4300        Ok(())
4301    }
4302
4303    /// Activate several [`Amplitude`]s by name, skipping missing entries.
4304    pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
4305        self.resources.write().activate_many(names);
4306        self.refresh_runtime_specializations();
4307    }
4308    /// Activate several [`Amplitude`]s by name and return an error if any are missing.
4309    pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
4310        self.resources.write().activate_many_strict(names)?;
4311        self.refresh_runtime_specializations();
4312        Ok(())
4313    }
4314
4315    /// Activate all registered [`Amplitude`]s.
4316    pub fn activate_all(&self) {
4317        self.resources.write().activate_all();
4318        self.refresh_runtime_specializations();
4319    }
4320
4321    /// Deactivate an [`Amplitude`] by name, skipping missing entries.
4322    pub fn deactivate<T: AsRef<str>>(&self, name: T) {
4323        self.resources.write().deactivate(name);
4324        self.refresh_runtime_specializations();
4325    }
4326
4327    /// Deactivate an [`Amplitude`] by name and return an error if it is missing.
4328    pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
4329        self.resources.write().deactivate_strict(name)?;
4330        self.refresh_runtime_specializations();
4331        Ok(())
4332    }
4333
4334    /// Deactivate several [`Amplitude`]s by name, skipping missing entries.
4335    pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
4336        self.resources.write().deactivate_many(names);
4337        self.refresh_runtime_specializations();
4338    }
4339    /// Deactivate several [`Amplitude`]s by name and return an error if any are missing.
4340    pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
4341        self.resources.write().deactivate_many_strict(names)?;
4342        self.refresh_runtime_specializations();
4343        Ok(())
4344    }
4345
4346    /// Deactivate all registered [`Amplitude`]s.
4347    pub fn deactivate_all(&self) {
4348        self.resources.write().deactivate_all();
4349        self.refresh_runtime_specializations();
4350    }
4351
4352    /// Isolate an [`Amplitude`] by name (deactivate the rest), skipping missing entries.
4353    pub fn isolate<T: AsRef<str>>(&self, name: T) {
4354        self.resources.write().isolate(name);
4355        self.refresh_runtime_specializations();
4356    }
4357
4358    /// Isolate an [`Amplitude`] by name (deactivate the rest) and return an error if it is missing.
4359    pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
4360        self.resources.write().isolate_strict(name)?;
4361        self.refresh_runtime_specializations();
4362        Ok(())
4363    }
4364
4365    /// Isolate several [`Amplitude`]s by name (deactivate the rest), skipping missing entries.
4366    pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
4367        self.resources.write().isolate_many(names);
4368        self.refresh_runtime_specializations();
4369    }
4370
4371    /// Isolate several [`Amplitude`]s by name (deactivate the rest) and return an error if any are missing.
4372    pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
4373        self.resources.write().isolate_many_strict(names)?;
4374        self.refresh_runtime_specializations();
4375        Ok(())
4376    }
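
    // Usage sketch for the activation API above (illustrative only: the amplitude names are
    // made up and an `Evaluator` named `evaluator` is assumed to be in scope):
    //
    //     evaluator.deactivate_all();
    //     evaluator.activate_many(&["S0+", "D2+"]); // skips names that are not registered
    //     evaluator.isolate_strict("S0+")?;         // errors if "S0+" is not registered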
4377
4378    /// Return a copy of the current active-amplitude mask.
4379    pub fn active_mask(&self) -> Vec<bool> {
4380        self.resources.read().active.clone()
4381    }
4382
4383    /// Apply a precomputed active-amplitude mask.
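    ///
    /// Round-trip sketch (illustrative only; assumes an `Evaluator` named `evaluator` in scope):
    ///
    /// ```ignore
    /// let saved = evaluator.active_mask();
    /// evaluator.deactivate_all();
    /// // ... run some isolated evaluations ...
    /// evaluator.set_active_mask(&saved)?; // restore the previous selection
    /// ```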
4384    pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
4385        let resources = {
4386            let mut resources = self.resources.write();
4387            if mask.len() != resources.active.len() {
4388                return Err(LadduError::LengthMismatch {
4389                    context: "active amplitude mask".to_string(),
4390                    expected: resources.active.len(),
4391                    actual: mask.len(),
4392                });
4393            }
4394            resources.active.clone_from_slice(mask);
4395            resources.refresh_active_indices();
4396            resources.clone()
4397        };
4398        self.rebuild_runtime_specializations(&resources);
4399        Ok(())
4400    }
4401
4402    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
4403    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
4404    ///
4405    /// # Notes
4406    ///
4407    /// This method is not intended for direct use in analyses; it exists so that methods with
4408    /// `mpi`-feature-gated variants can share a non-MPI implementation. Most users will want to call [`Evaluator::evaluate`] instead.
4409    pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
4410        let resources = self.resources.read();
4411        let parameters = resources.parameter_map.assemble(parameters)?;
4412        let amplitude_len = self.amplitudes.len();
4413        let active_indices = resources.active_indices().to_vec();
4414        let slot_count = self.expression_value_slot_count();
4415        let program_snapshot = self.expression_value_program_snapshot();
4416        #[cfg(feature = "rayon")]
4417        {
4418            Ok(resources
4419                .caches
4420                .par_iter()
4421                .map_init(
4422                    || {
4423                        (
4424                            vec![Complex64::ZERO; amplitude_len],
4425                            vec![Complex64::ZERO; slot_count],
4426                        )
4427                    },
4428                    |(amplitude_values, expr_slots), cache| {
4429                        self.fill_amplitude_values(
4430                            amplitude_values,
4431                            &active_indices,
4432                            &parameters,
4433                            cache,
4434                        );
4435                        self.evaluate_expression_value_with_program_snapshot(
4436                            &program_snapshot,
4437                            amplitude_values,
4438                            expr_slots,
4439                        )
4440                    },
4441                )
4442                .collect())
4443        }
4444        #[cfg(not(feature = "rayon"))]
4445        {
4446            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4447            let mut expr_slots = vec![Complex64::ZERO; slot_count];
4448            Ok(resources
4449                .caches
4450                .iter()
4451                .map(|cache| {
4452                    self.fill_amplitude_values(
4453                        &mut amplitude_values,
4454                        &active_indices,
4455                        &parameters,
4456                        cache,
4457                    );
4458                    self.evaluate_expression_value_with_program_snapshot(
4459                        &program_snapshot,
4460                        &amplitude_values,
4461                        &mut expr_slots,
4462                    )
4463                })
4464                .collect())
4465        }
4466    }
4467
4468    /// Evaluate local events using an explicit active-amplitude mask without mutating evaluator state.
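    ///
    /// Sketch (illustrative only; assumes an `Evaluator` with at least one registered amplitude
    /// and a `parameter_values` slice prepared elsewhere):
    ///
    /// ```ignore
    /// let mut mask = evaluator.active_mask();
    /// mask[0] = false; // evaluate as if the first amplitude were deactivated
    /// let values = evaluator.evaluate_local_with_active_mask(&parameter_values, &mask)?;
    /// ```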
4469    pub fn evaluate_local_with_active_mask(
4470        &self,
4471        parameters: &[f64],
4472        active_mask: &[bool],
4473    ) -> LadduResult<Vec<Complex64>> {
4474        let resources = self.resources.read();
4475        if active_mask.len() != resources.active.len() {
4476            return Err(LadduError::LengthMismatch {
4477                context: "active amplitude mask".to_string(),
4478                expected: resources.active.len(),
4479                actual: active_mask.len(),
4480            });
4481        }
4482        let parameters = resources.parameter_map.assemble(parameters)?;
4483        let amplitude_len = self.amplitudes.len();
4484        let active_indices = active_mask
4485            .iter()
4486            .enumerate()
4487            .filter_map(|(index, &active)| if active { Some(index) } else { None })
4488            .collect::<Vec<_>>();
4489        let program_snapshot =
4490            self.expression_value_program_snapshot_for_active_mask(active_mask)?;
4491        let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
4492        #[cfg(feature = "rayon")]
4493        {
4494            Ok(resources
4495                .caches
4496                .par_iter()
4497                .map_init(
4498                    || {
4499                        (
4500                            vec![Complex64::ZERO; amplitude_len],
4501                            vec![Complex64::ZERO; slot_count],
4502                        )
4503                    },
4504                    |(amplitude_values, expr_slots), cache| {
4505                        self.fill_amplitude_values(
4506                            amplitude_values,
4507                            &active_indices,
4508                            &parameters,
4509                            cache,
4510                        );
4511                        self.evaluate_expression_value_with_program_snapshot(
4512                            &program_snapshot,
4513                            amplitude_values,
4514                            expr_slots,
4515                        )
4516                    },
4517                )
4518                .collect())
4519        }
4520        #[cfg(not(feature = "rayon"))]
4521        {
4522            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4523            let mut expr_slots = vec![Complex64::ZERO; slot_count];
4524            Ok(resources
4525                .caches
4526                .iter()
4527                .map(|cache| {
4528                    self.fill_amplitude_values(
4529                        &mut amplitude_values,
4530                        &active_indices,
4531                        &parameters,
4532                        cache,
4533                    );
4534                    self.evaluate_expression_value_with_program_snapshot(
4535                        &program_snapshot,
4536                        &amplitude_values,
4537                        &mut expr_slots,
4538                    )
4539                })
4540                .collect())
4541        }
4542    }
4543
4544    /// Evaluate the stored expression over local events using a reusable execution context.
4545    #[cfg(feature = "execution-context-prototype")]
4546    pub fn evaluate_local_with_ctx(
4547        &self,
4548        parameters: &[f64],
4549        execution_context: &ExecutionContext,
4550    ) -> Vec<Complex64> {
4551        let resources = self.resources.read();
4552        let parameters = resources
4553            .parameter_map
4554            .assemble(parameters)
4555            .expect("parameter slice must match evaluator resources");
4556        let amplitude_len = self.amplitudes.len();
4557        let active_indices = resources.active_indices().to_vec();
4558        let slot_count = self.expression_value_slot_count();
4559        let program_snapshot = self.expression_value_program_snapshot();
4560        #[cfg(feature = "rayon")]
4561        {
4562            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
4563                return execution_context.install(|| {
4564                    resources
4565                        .caches
4566                        .par_iter()
4567                        .map_init(
4568                            || {
4569                                (
4570                                    vec![Complex64::ZERO; amplitude_len],
4571                                    vec![Complex64::ZERO; slot_count],
4572                                )
4573                            },
4574                            |(amplitude_values, expr_slots), cache| {
4575                                self.fill_amplitude_values(
4576                                    amplitude_values,
4577                                    &active_indices,
4578                                    &parameters,
4579                                    cache,
4580                                );
4581                                self.evaluate_expression_value_with_program_snapshot(
4582                                    &program_snapshot,
4583                                    amplitude_values,
4584                                    expr_slots,
4585                                )
4586                            },
4587                        )
4588                        .collect()
4589                });
4590            }
4591        }
4592        execution_context.with_scratch(|scratch| {
4593            let (amplitude_values, expr_slots) =
4594                scratch.reserve_value_workspaces(amplitude_len, slot_count);
4595            resources
4596                .caches
4597                .iter()
4598                .map(|cache| {
4599                    self.fill_amplitude_values(
4600                        amplitude_values,
4601                        &active_indices,
4602                        &parameters,
4603                        cache,
4604                    );
4605                    self.evaluate_expression_value_with_program_snapshot(
4606                        &program_snapshot,
4607                        amplitude_values,
4608                        expr_slots,
4609                    )
4610                })
4611                .collect()
4612        })
4613    }
4614
4615    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
4616    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
4617    ///
4618    /// # Notes
4619    ///
4620    /// This method is not intended for direct use in analyses; it exists so that methods with
4621    /// `mpi`-feature-gated variants can share a non-MPI implementation. Most users will want to call [`Evaluator::evaluate`] instead.
4622    #[cfg(feature = "mpi")]
4623    fn evaluate_mpi(
4624        &self,
4625        parameters: &[f64],
4626        world: &SimpleCommunicator,
4627    ) -> LadduResult<Vec<Complex64>> {
4628        let local_evaluation = self.evaluate_local(parameters)?;
4629        let n_events = self.dataset.n_events();
4630        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
4631        let (counts, displs) = world.get_counts_displs(n_events);
4632        {
4633            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
4634            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
4635            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4636            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
4637        }
4638        Ok(buffer)
4639    }
4640
4641    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
4642    fn evaluate_mpi_with_ctx(
4643        &self,
4644        parameters: &[f64],
4645        world: &SimpleCommunicator,
4646        execution_context: &ExecutionContext,
4647    ) -> Vec<Complex64> {
4648        let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
4649        let n_events = self.dataset.n_events();
4650        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
4651        let (counts, displs) = world.get_counts_displs(n_events);
4652        {
4653            // NOTE: gather is required here because the public MPI API returns full per-event outputs.
4654            // Do not replace with all-reduce unless semantics change to scalar aggregates only.
4655            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4656            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
4657        }
4658        buffer
4659    }
4660
4661    /// Evaluate the stored [`Expression`] over the events in the [`Dataset`] stored by the
4662    /// [`Evaluator`] with the given values for free parameters.
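    ///
    /// A minimal usage sketch (illustrative only: how the `Evaluator` is constructed and what
    /// the parameter values mean depends on the registered amplitudes and is not shown here):
    ///
    /// ```ignore
    /// let per_event = evaluator.evaluate(&parameter_values)?;
    /// // one Complex64 per event in the dataset
    /// ```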
4663    pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
4664        #[cfg(feature = "mpi")]
4665        {
4666            if let Some(world) = crate::mpi::get_world() {
4667                return self.evaluate_mpi(parameters, &world);
4668            }
4669        }
4670        self.evaluate_local(parameters)
4671    }
4672
4673    /// Evaluate the stored expression with a reusable execution context.
4674    ///
4675    /// This is intended for repeated calls with the same context instance.
4676    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
4677    /// [`ExecutionContext`](crate::ExecutionContext).
4678    #[cfg(feature = "execution-context-prototype")]
4679    pub fn evaluate_with_ctx(
4680        &self,
4681        parameters: &[f64],
4682        execution_context: &ExecutionContext,
4683    ) -> Vec<Complex64> {
4684        #[cfg(feature = "mpi")]
4685        {
4686            if let Some(world) = crate::mpi::get_world() {
4687                return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
4688            }
4689        }
4690        self.evaluate_local_with_ctx(parameters, execution_context)
4691    }
4692
4693    /// See [`Evaluator::evaluate_local`]. This method evaluates over a subset of events rather
4694    /// than all events in the total dataset.
4695    pub fn evaluate_batch_local(
4696        &self,
4697        parameters: &[f64],
4698        indices: &[usize],
4699    ) -> LadduResult<Vec<Complex64>> {
4700        let resources = self.resources.read();
4701        let parameters = resources.parameter_map.assemble(parameters)?;
4702        let amplitude_len = self.amplitudes.len();
4703        let active_indices = resources.active_indices().to_vec();
4704        let slot_count = self.expression_value_slot_count();
4705        let program_snapshot = self.expression_value_program_snapshot();
4706        #[cfg(feature = "rayon")]
4707        {
4708            Ok(indices
4709                .par_iter()
4710                .map_init(
4711                    || {
4712                        (
4713                            vec![Complex64::ZERO; amplitude_len],
4714                            vec![Complex64::ZERO; slot_count],
4715                        )
4716                    },
4717                    |(amplitude_values, expr_slots), &idx| {
4718                        let cache = &resources.caches[idx];
4719                        self.fill_amplitude_values(
4720                            amplitude_values,
4721                            &active_indices,
4722                            &parameters,
4723                            cache,
4724                        );
4725                        self.evaluate_expression_value_with_program_snapshot(
4726                            &program_snapshot,
4727                            amplitude_values,
4728                            expr_slots,
4729                        )
4730                    },
4731                )
4732                .collect())
4733        }
4734        #[cfg(not(feature = "rayon"))]
4735        {
4736            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4737            let mut expr_slots = vec![Complex64::ZERO; slot_count];
4738            Ok(indices
4739                .iter()
4740                .map(|&idx| {
4741                    let cache = &resources.caches[idx];
4742                    self.fill_amplitude_values(
4743                        &mut amplitude_values,
4744                        &active_indices,
4745                        &parameters,
4746                        cache,
4747                    );
4748                    self.evaluate_expression_value_with_program_snapshot(
4749                        &program_snapshot,
4750                        &amplitude_values,
4751                        &mut expr_slots,
4752                    )
4753                })
4754                .collect())
4755        }
4756    }
4757
4758    /// See [`Evaluator::evaluate_mpi`]. This method evaluates over a subset of events rather
4759    /// than all events in the total dataset.
4760    #[cfg(feature = "mpi")]
4761    fn evaluate_batch_mpi(
4762        &self,
4763        parameters: &[f64],
4764        indices: &[usize],
4765        world: &SimpleCommunicator,
4766    ) -> LadduResult<Vec<Complex64>> {
4767        let total = self.dataset.n_events();
4768        let locals = world.locals_from_globals(indices, total);
4769        let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
4770        Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
4771    }
4772
4773    /// Evaluate the stored [`Expression`] over a subset of events in the [`Dataset`] stored by the
4774    /// [`Evaluator`] with the given values for free parameters. See also [`Evaluator::evaluate`].
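    ///
    /// Sketch (illustrative only; the event indices are made up):
    ///
    /// ```ignore
    /// let subset = [0_usize, 2, 5];
    /// let values = evaluator.evaluate_batch(&parameter_values, &subset)?;
    /// // values for the selected events only
    /// ```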
4775    pub fn evaluate_batch(
4776        &self,
4777        parameters: &[f64],
4778        indices: &[usize],
4779    ) -> LadduResult<Vec<Complex64>> {
4780        #[cfg(feature = "mpi")]
4781        {
4782            if let Some(world) = crate::mpi::get_world() {
4783                return self.evaluate_batch_mpi(parameters, indices, &world);
4784            }
4785        }
4786        self.evaluate_batch_local(parameters, indices)
4787    }
4788
4789    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
4790    /// [`Evaluator`] with the given values for free parameters (non-MPI version).
4791    ///
4792    /// # Notes
4793    ///
4794    /// This method is not intended for direct use in analyses; it exists so that methods with
4795    /// `mpi`-feature-gated variants can share a non-MPI implementation. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
4796    pub fn evaluate_gradient_local(
4797        &self,
4798        parameters: &[f64],
4799    ) -> LadduResult<Vec<DVector<Complex64>>> {
4800        let resources = self.resources.read();
4801        let parameters = resources.parameter_map.assemble(parameters)?;
4802        let amplitude_len = self.amplitudes.len();
4803        let grad_dim = parameters.len();
4804        let active_indices = resources.active_indices().to_vec();
4805        let lowered_runtime = self.lowered_runtime();
4806        let gradient_program = lowered_runtime.gradient_program();
4807        let slot_count = self.expression_gradient_slot_count();
4808        #[cfg(feature = "rayon")]
4809        {
4810            Ok(resources
4811                .caches
4812                .par_iter()
4813                .map_init(
4814                    || {
4815                        (
4816                            vec![Complex64::ZERO; amplitude_len],
4817                            vec![DVector::zeros(grad_dim); amplitude_len],
4818                            vec![Complex64::ZERO; slot_count],
4819                            vec![Complex64::ZERO; slot_count * grad_dim],
4820                        )
4821                    },
4822                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4823                        self.fill_amplitude_values_and_gradients(
4824                            amplitude_values,
4825                            gradient_values,
4826                            &active_indices,
4827                            &resources.active,
4828                            &parameters,
4829                            cache,
4830                        );
4831                        gradient_program.evaluate_gradient_into_flat(
4832                            amplitude_values,
4833                            gradient_values,
4834                            value_slots,
4835                            gradient_slots,
4836                            grad_dim,
4837                        )
4838                    },
4839                )
4840                .collect())
4841        }
4842        #[cfg(not(feature = "rayon"))]
4843        {
4844            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4845            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4846            let mut value_slots = vec![Complex64::ZERO; slot_count];
4847            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4848            Ok(resources
4849                .caches
4850                .iter()
4851                .map(|cache| {
4852                    self.fill_amplitude_values_and_gradients(
4853                        &mut amplitude_values,
4854                        &mut gradient_values,
4855                        &active_indices,
4856                        &resources.active,
4857                        &parameters,
4858                        cache,
4859                    );
4860                    gradient_program.evaluate_gradient_into_flat(
4861                        &amplitude_values,
4862                        &gradient_values,
4863                        &mut value_slots,
4864                        &mut gradient_slots,
4865                        grad_dim,
4866                    )
4867                })
4868                .collect())
4869        }
4870    }
4871
4872    /// Evaluate the gradient over local events using a reusable execution context.
4873    #[cfg(feature = "execution-context-prototype")]
4874    pub fn evaluate_gradient_local_with_ctx(
4875        &self,
4876        parameters: &[f64],
4877        execution_context: &ExecutionContext,
4878    ) -> Vec<DVector<Complex64>> {
4879        let resources = self.resources.read();
4880        let parameters = resources
4881            .parameter_map
4882            .assemble(parameters)
4883            .expect("parameter slice must match evaluator resources");
4884        let amplitude_len = self.amplitudes.len();
4885        let grad_dim = parameters.len();
4886        let active_indices = resources.active_indices().to_vec();
4887        let slot_count = self.expression_slot_count();
4888        #[cfg(feature = "rayon")]
4889        {
4890            if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
4891                return execution_context.install(|| {
4892                    resources
4893                        .caches
4894                        .par_iter()
4895                        .map_init(
4896                            || {
4897                                (
4898                                    vec![Complex64::ZERO; amplitude_len],
4899                                    vec![DVector::zeros(grad_dim); amplitude_len],
4900                                    vec![Complex64::ZERO; slot_count],
4901                                    vec![DVector::zeros(grad_dim); slot_count],
4902                                )
4903                            },
4904                            |(amplitude_values, gradient_values, value_slots, gradient_slots),
4905                             cache| {
4906                                self.evaluate_cache_gradient_with_scratch(
4907                                    amplitude_values,
4908                                    gradient_values,
4909                                    value_slots,
4910                                    gradient_slots,
4911                                    &active_indices,
4912                                    &resources.active,
4913                                    &parameters,
4914                                    cache,
4915                                )
4916                            },
4917                        )
4918                        .collect()
4919                });
4920            }
4921        }
4922        execution_context.with_scratch(|scratch| {
4923            let (amplitude_values, value_slots, gradient_values, gradient_slots) =
4924                scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
4925            resources
4926                .caches
4927                .iter()
4928                .map(|cache| {
4929                    self.evaluate_cache_gradient_with_scratch(
4930                        amplitude_values,
4931                        gradient_values,
4932                        value_slots,
4933                        gradient_slots,
4934                        &active_indices,
4935                        &resources.active,
4936                        &parameters,
4937                        cache,
4938                    )
4939                })
4940                .collect()
4941        })
4942    }
4943
4944    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
4945    /// [`Evaluator`] with the given values for free parameters (MPI-compatible version).
4946    ///
4947    /// # Notes
4948    ///
4949    /// This method is not intended for direct use in analyses; it exists so that methods with
4950    /// `mpi`-feature-gated variants can share a non-MPI implementation. Most users will want to call [`Evaluator::evaluate_gradient`] instead.
4951    #[cfg(feature = "mpi")]
4952    fn evaluate_gradient_mpi(
4953        &self,
4954        parameters: &[f64],
4955        world: &SimpleCommunicator,
4956    ) -> LadduResult<Vec<DVector<Complex64>>> {
4957        let local_evaluation = self.evaluate_gradient_local(parameters)?;
4958        let n_events = self.dataset.n_events();
4959        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
4960        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
4961        {
4962            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
4963            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
4964            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4965            world.all_gather_varcount_into(
4966                &local_evaluation
4967                    .iter()
4968                    .flat_map(|v| v.data.as_vec())
4969                    .copied()
4970                    .collect::<Vec<_>>(),
4971                &mut partitioned_buffer,
4972            );
4973        }
4974        Ok(buffer
4975            .chunks(parameters.len())
4976            .map(DVector::from_row_slice)
4977            .collect())
4978    }
4979
4980    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
4981    fn evaluate_gradient_mpi_with_ctx(
4982        &self,
4983        parameters: &[f64],
4984        world: &SimpleCommunicator,
4985        execution_context: &ExecutionContext,
4986    ) -> Vec<DVector<Complex64>> {
4987        let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
4988        let n_events = self.dataset.n_events();
4989        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
4990        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
4991        {
4992            // NOTE: gather is required here because the public MPI API returns full per-event gradients.
4993            // Do not replace with all-reduce unless semantics change to aggregate-only outputs.
4994            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4995            world.all_gather_varcount_into(
4996                &local_evaluation
4997                    .iter()
4998                    .flat_map(|v| v.data.as_vec())
4999                    .copied()
5000                    .collect::<Vec<_>>(),
5001                &mut partitioned_buffer,
5002            );
5003        }
5004        buffer
5005            .chunks(parameters.len())
5006            .map(DVector::from_row_slice)
5007            .collect()
5008    }
5009
5010    /// Evaluate the gradient of the stored [`Expression`] over the events in the [`Dataset`] stored by the
5011    /// [`Evaluator`] with the given values for free parameters.
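    ///
    /// Sketch (illustrative only; assumes `evaluator` and `parameter_values` exist; the gradient
    /// length follows the evaluator's parameter layout):
    ///
    /// ```ignore
    /// let per_event_gradients = evaluator.evaluate_gradient(&parameter_values)?;
    /// for gradient in &per_event_gradients {
    ///     // `gradient` is a DVector<Complex64> with one entry per parameter
    /// }
    /// ```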
5012    pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
5013        #[cfg(feature = "mpi")]
5014        {
5015            if let Some(world) = crate::mpi::get_world() {
5016                return self.evaluate_gradient_mpi(parameters, &world);
5017            }
5018        }
5019        self.evaluate_gradient_local(parameters)
5020    }
5021
5022    /// Evaluate the gradient with a reusable execution context.
5023    ///
5024    /// This is intended for repeated calls with the same context instance.
5025    /// Thread behavior follows [`ThreadPolicy`](crate::ThreadPolicy) configured on
5026    /// [`ExecutionContext`](crate::ExecutionContext).
5027    #[cfg(feature = "execution-context-prototype")]
5028    pub fn evaluate_gradient_with_ctx(
5029        &self,
5030        parameters: &[f64],
5031        execution_context: &ExecutionContext,
5032    ) -> Vec<DVector<Complex64>> {
5033        #[cfg(feature = "mpi")]
5034        {
5035            if let Some(world) = crate::mpi::get_world() {
5036                return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
5037            }
5038        }
5039        self.evaluate_gradient_local_with_ctx(parameters, execution_context)
5040    }
5041
5042    /// See [`Evaluator::evaluate_gradient_local`]. This method evaluates over a subset
5043    /// of events rather than all events in the total dataset.
5044    pub fn evaluate_gradient_batch_local(
5045        &self,
5046        parameters: &[f64],
5047        indices: &[usize],
5048    ) -> LadduResult<Vec<DVector<Complex64>>> {
5049        let resources = self.resources.read();
5050        let parameters = resources.parameter_map.assemble(parameters)?;
5051        let amplitude_len = self.amplitudes.len();
5052        let grad_dim = parameters.len();
5053        let active_indices = resources.active_indices().to_vec();
5054        let lowered_runtime = self.lowered_runtime();
5055        let gradient_program = lowered_runtime.gradient_program();
5056        let slot_count = self.expression_gradient_slot_count();
5057        #[cfg(feature = "rayon")]
5058        {
5059            Ok(indices
5060                .par_iter()
5061                .map_init(
5062                    || {
5063                        (
5064                            vec![Complex64::ZERO; amplitude_len],
5065                            vec![DVector::zeros(grad_dim); amplitude_len],
5066                            vec![Complex64::ZERO; slot_count],
5067                            vec![Complex64::ZERO; slot_count * grad_dim],
5068                        )
5069                    },
5070                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
5071                        let cache = &resources.caches[idx];
5072                        self.fill_amplitude_values_and_gradients(
5073                            amplitude_values,
5074                            gradient_values,
5075                            &active_indices,
5076                            &resources.active,
5077                            &parameters,
5078                            cache,
5079                        );
5080                        gradient_program.evaluate_gradient_into_flat(
5081                            amplitude_values,
5082                            gradient_values,
5083                            value_slots,
5084                            gradient_slots,
5085                            grad_dim,
5086                        )
5087                    },
5088                )
5089                .collect())
5090        }
5091        #[cfg(not(feature = "rayon"))]
5092        {
5093            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5094            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5095            let mut value_slots = vec![Complex64::ZERO; slot_count];
5096            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5097            Ok(indices
5098                .iter()
5099                .map(|&idx| {
5100                    let cache = &resources.caches[idx];
5101                    self.fill_amplitude_values_and_gradients(
5102                        &mut amplitude_values,
5103                        &mut gradient_values,
5104                        &active_indices,
5105                        &resources.active,
5106                        &parameters,
5107                        cache,
5108                    );
5109                    gradient_program.evaluate_gradient_into_flat(
5110                        &amplitude_values,
5111                        &gradient_values,
5112                        &mut value_slots,
5113                        &mut gradient_slots,
5114                        grad_dim,
5115                    )
5116                })
5117                .collect())
5118        }
5119    }
5120
5121    /// See [`Evaluator::evaluate_gradient_mpi`]. This method evaluates over a subset
5122    /// of events rather than all events in the total dataset.
5123    #[cfg(feature = "mpi")]
5124    fn evaluate_gradient_batch_mpi(
5125        &self,
5126        parameters: &[f64],
5127        indices: &[usize],
5128        world: &SimpleCommunicator,
5129    ) -> LadduResult<Vec<DVector<Complex64>>> {
5130        let total = self.dataset.n_events();
5131        let locals = world.locals_from_globals(indices, total);
5132        let flattened_local_evaluation = self
5133            .evaluate_gradient_batch_local(parameters, &locals)?
5134            .iter()
5135            .flat_map(|g| g.data.as_vec().to_vec())
5136            .collect::<Vec<Complex64>>();
5137        Ok(world
5138            .all_gather_batched_partitioned(
5139                &flattened_local_evaluation,
5140                indices,
5141                total,
5142                Some(parameters.len()),
5143            )
5144            .chunks(parameters.len())
5145            .map(DVector::from_row_slice)
5146            .collect())
5147    }
5148
5149    /// Evaluate the gradient of the stored [`Expression`] over a subset of the
5150    /// events in the [`Dataset`] stored by the [`Evaluator`] with the given values
5151    /// for free parameters. See also [`Evaluator::evaluate_gradient`].
5152    pub fn evaluate_gradient_batch(
5153        &self,
5154        parameters: &[f64],
5155        indices: &[usize],
5156    ) -> LadduResult<Vec<DVector<Complex64>>> {
5157        #[cfg(feature = "mpi")]
5158        {
5159            if let Some(world) = crate::mpi::get_world() {
5160                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
5161            }
5162        }
5163        self.evaluate_gradient_batch_local(parameters, indices)
5164    }
5165
5166    /// Evaluate the stored expression and its gradient over local events in one fused pass.
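    ///
    /// Sketch (illustrative only; assumes `evaluator` and `parameter_values` exist): the fused
    /// pass yields one `(value, gradient)` pair per local event, filling amplitude values and
    /// gradients once instead of running the value and gradient passes separately.
    ///
    /// ```ignore
    /// for (value, gradient) in evaluator.evaluate_with_gradient_local(&parameter_values)? {
    ///     // `value` corresponds to `evaluate_local`, `gradient` to `evaluate_gradient_local`
    /// }
    /// ```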
5167    pub fn evaluate_with_gradient_local(
5168        &self,
5169        parameters: &[f64],
5170    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
5171        let resources = self.resources.read();
5172        let parameters = resources.parameter_map.assemble(parameters)?;
5173        let amplitude_len = self.amplitudes.len();
5174        let grad_dim = parameters.len();
5175        let active_indices = resources.active_indices().to_vec();
5176        let lowered_runtime = self.lowered_runtime();
5177        let value_gradient_program = lowered_runtime.value_gradient_program();
5178        let slot_count = self.expression_value_gradient_slot_count();
5179        #[cfg(feature = "rayon")]
5180        {
5181            Ok(resources
5182                .caches
5183                .par_iter()
5184                .map_init(
5185                    || {
5186                        (
5187                            vec![Complex64::ZERO; amplitude_len],
5188                            vec![DVector::zeros(grad_dim); amplitude_len],
5189                            vec![Complex64::ZERO; slot_count],
5190                            vec![Complex64::ZERO; slot_count * grad_dim],
5191                        )
5192                    },
5193                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
5194                        self.fill_amplitude_values_and_gradients(
5195                            amplitude_values,
5196                            gradient_values,
5197                            &active_indices,
5198                            &resources.active,
5199                            &parameters,
5200                            cache,
5201                        );
5202                        value_gradient_program.evaluate_value_gradient_into_flat(
5203                            amplitude_values,
5204                            gradient_values,
5205                            value_slots,
5206                            gradient_slots,
5207                            grad_dim,
5208                        )
5209                    },
5210                )
5211                .collect())
5212        }
5213        #[cfg(not(feature = "rayon"))]
5214        {
5215            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5216            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5217            let mut value_slots = vec![Complex64::ZERO; slot_count];
5218            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5219            Ok(resources
5220                .caches
5221                .iter()
5222                .map(|cache| {
5223                    self.fill_amplitude_values_and_gradients(
5224                        &mut amplitude_values,
5225                        &mut gradient_values,
5226                        &active_indices,
5227                        &resources.active,
5228                        &parameters,
5229                        cache,
5230                    );
5231                    value_gradient_program.evaluate_value_gradient_into_flat(
5232                        &amplitude_values,
5233                        &gradient_values,
5234                        &mut value_slots,
5235                        &mut gradient_slots,
5236                        grad_dim,
5237                    )
5238                })
5239                .collect())
5240        }
5241    }
5242
5243    /// Evaluate local events and their gradients using an explicit active-amplitude mask without mutating evaluator state.
5244    pub fn evaluate_with_gradient_local_with_active_mask(
5245        &self,
5246        parameters: &[f64],
5247        active_mask: &[bool],
5248    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
5249        let resources = self.resources.read();
5250        if active_mask.len() != resources.active.len() {
5251            return Err(LadduError::LengthMismatch {
5252                context: "active amplitude mask".to_string(),
5253                expected: resources.active.len(),
5254                actual: active_mask.len(),
5255            });
5256        }
5257        let parameters = resources.parameter_map.assemble(parameters)?;
5258        let amplitude_len = self.amplitudes.len();
5259        let grad_dim = parameters.len();
5260        let active_indices = active_mask
5261            .iter()
5262            .enumerate()
5263            .filter_map(|(index, &active)| if active { Some(index) } else { None })
5264            .collect::<Vec<_>>();
5265        let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
5266        let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
5267        #[cfg(feature = "rayon")]
5268        {
5269            Ok(resources
5270                .caches
5271                .par_iter()
5272                .map_init(
5273                    || {
5274                        (
5275                            vec![Complex64::ZERO; amplitude_len],
5276                            vec![DVector::zeros(grad_dim); amplitude_len],
5277                            vec![Complex64::ZERO; slot_count],
5278                            vec![Complex64::ZERO; slot_count * grad_dim],
5279                        )
5280                    },
5281                    |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
5282                        self.fill_amplitude_values_and_gradients(
5283                            amplitude_values,
5284                            gradient_values,
5285                            &active_indices,
5286                            active_mask,
5287                            &parameters,
5288                            cache,
5289                        );
5290                        lowered_runtime
5291                            .value_gradient_program()
5292                            .evaluate_value_gradient_into_flat(
5293                                amplitude_values,
5294                                gradient_values,
5295                                value_slots,
5296                                gradient_slots,
5297                                grad_dim,
5298                            )
5299                    },
5300                )
5301                .collect())
5302        }
5303        #[cfg(not(feature = "rayon"))]
5304        {
5305            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5306            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5307            let mut value_slots = vec![Complex64::ZERO; slot_count];
5308            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5309            Ok(resources
5310                .caches
5311                .iter()
5312                .map(|cache| {
5313                    self.fill_amplitude_values_and_gradients(
5314                        &mut amplitude_values,
5315                        &mut gradient_values,
5316                        &active_indices,
5317                        active_mask,
5318                        &parameters,
5319                        cache,
5320                    );
5321                    lowered_runtime
5322                        .value_gradient_program()
5323                        .evaluate_value_gradient_into_flat(
5324                            &amplitude_values,
5325                            &gradient_values,
5326                            &mut value_slots,
5327                            &mut gradient_slots,
5328                            grad_dim,
5329                        )
5330                })
5331                .collect())
5332        }
5333    }
5334
5335    /// Evaluate the stored expression and its gradient over a local subset of events in one fused pass.
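    /// # Example
    ///
    /// A sketch of the intended call (the `evaluator` and `params` names are
    /// illustrative); `indices` selects positions among the local events,
    /// here the first and third:
    ///
    /// ```ignore
    /// let results = evaluator.evaluate_with_gradient_batch_local(&params, &[0, 2])?;
    /// assert_eq!(results.len(), 2);
    /// ```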
5336    pub fn evaluate_with_gradient_batch_local(
5337        &self,
5338        parameters: &[f64],
5339        indices: &[usize],
5340    ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
5341        let resources = self.resources.read();
5342        let parameters = resources.parameter_map.assemble(parameters)?;
5343        let amplitude_len = self.amplitudes.len();
5344        let grad_dim = parameters.len();
5345        let active_indices = resources.active_indices().to_vec();
5346        let lowered_runtime = self.lowered_runtime();
5347        let value_gradient_program = lowered_runtime.value_gradient_program();
5348        let slot_count = self.expression_value_gradient_slot_count();
5349        #[cfg(feature = "rayon")]
5350        {
5351            Ok(indices
5352                .par_iter()
5353                .map_init(
5354                    || {
5355                        (
5356                            vec![Complex64::ZERO; amplitude_len],
5357                            vec![DVector::zeros(grad_dim); amplitude_len],
5358                            vec![Complex64::ZERO; slot_count],
5359                            vec![Complex64::ZERO; slot_count * grad_dim],
5360                        )
5361                    },
5362                    |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
5363                        let cache = &resources.caches[idx];
5364                        self.fill_amplitude_values_and_gradients(
5365                            amplitude_values,
5366                            gradient_values,
5367                            &active_indices,
5368                            &resources.active,
5369                            &parameters,
5370                            cache,
5371                        );
5372                        value_gradient_program.evaluate_value_gradient_into_flat(
5373                            amplitude_values,
5374                            gradient_values,
5375                            value_slots,
5376                            gradient_slots,
5377                            grad_dim,
5378                        )
5379                    },
5380                )
5381                .collect())
5382        }
5383        #[cfg(not(feature = "rayon"))]
5384        {
5385            let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5386            let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5387            let mut value_slots = vec![Complex64::ZERO; slot_count];
5388            let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5389            Ok(indices
5390                .iter()
5391                .map(|&idx| {
5392                    let cache = &resources.caches[idx];
5393                    self.fill_amplitude_values_and_gradients(
5394                        &mut amplitude_values,
5395                        &mut gradient_values,
5396                        &active_indices,
5397                        &resources.active,
5398                        &parameters,
5399                        cache,
5400                    );
5401                    value_gradient_program.evaluate_value_gradient_into_flat(
5402                        &amplitude_values,
5403                        &gradient_values,
5404                        &mut value_slots,
5405                        &mut gradient_slots,
5406                        grad_dim,
5407                    )
5408                })
5409                .collect())
5410        }
5411    }
5412}
5413
5414/// A testing [`Amplitude`] that evaluates to `(re + i·im)` scaled by the energy of the event's first four-momentum.
5415#[derive(Clone, Serialize, Deserialize)]
5416pub struct TestAmplitude {
5417    name: String,
5418    re: Parameter,
5419    pid_re: ParameterID,
5420    im: Parameter,
5421    pid_im: ParameterID,
5422    beam_energy: crate::ScalarID,
5423}
5424
5425impl TestAmplitude {
5426    /// Create a new testing [`Amplitude`], returned wrapped in an [`Expression`].
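    /// # Example
    ///
    /// A sketch mirroring the usage in the tests below (the parameter names
    /// are illustrative):
    ///
    /// ```ignore
    /// let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))?;
    /// ```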
5427    #[allow(clippy::new_ret_no_self)]
5428    pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
5429        Self {
5430            name: name.to_string(),
5431            re,
5432            pid_re: Default::default(),
5433            im,
5434            pid_im: Default::default(),
5435            beam_energy: Default::default(),
5436        }
5437        .into_expression()
5438    }
5439}
5440
5441#[typetag::serde]
5442impl Amplitude for TestAmplitude {
5443    fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
5444        self.pid_re = resources.register_parameter(&self.re)?;
5445        self.pid_im = resources.register_parameter(&self.im)?;
5446        self.beam_energy = resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
5447        resources.register_amplitude(&self.name)
5448    }
5449
5450    fn precompute(&self, event: &NamedEventView<'_>, cache: &mut Cache) {
5451        let beam = event.p4_at(0);
5452        cache.store_scalar(self.beam_energy, beam.e());
5453    }
5454
5455    fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64 {
5456        Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
5457            * cache.get_scalar(self.beam_energy)
5458    }
5459
5460    fn compute_gradient(
5461        &self,
5462        parameters: &Parameters,
5463        cache: &Cache,
5464        gradient: &mut DVector<Complex64>,
5465    ) {
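        // The amplitude value is (re + i·im) · beam_energy, so the partial
        // derivatives are ∂/∂re = beam_energy and ∂/∂im = i · beam_energy;
        // entries are written only for parameters that are currently free.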
5466        let beam_energy = cache.get_scalar(self.beam_energy);
5467        if let Some(ind) = parameters.free_index(self.pid_re) {
5468            gradient[ind] = Complex64::ONE * beam_energy;
5469        }
5470        if let Some(ind) = parameters.free_index(self.pid_im) {
5471            gradient[ind] = Complex64::I * beam_energy;
5472        }
5473    }
5474}
5475
5476#[cfg(test)]
5477mod tests {
5478    use crate::data::{test_dataset, test_event, DatasetMetadata, EventData};
5479
5480    use super::*;
5481    use crate::resources::{Cache, ParameterID, Parameters, Resources, ScalarID};
5482    use crate::utils::vectors::Vec4;
5483    use approx::assert_relative_eq;
5484    #[cfg(feature = "mpi")]
5485    use mpi_test::mpi_test;
5486    use serde::{Deserialize, Serialize};
5487
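    // Test helper: a complex scalar amplitude returning `re + i·im` directly
    // from its two parameters; it declares no dependence or real-valued hints,
    // so its classification defaults to Mixed.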
5488    #[derive(Clone, Serialize, Deserialize)]
5489    pub struct ComplexScalar {
5490        name: String,
5491        re: Parameter,
5492        pid_re: ParameterID,
5493        im: Parameter,
5494        pid_im: ParameterID,
5495    }
5496
5497    impl ComplexScalar {
5498        #[allow(clippy::new_ret_no_self)]
5499        pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
5500            Self {
5501                name: name.to_string(),
5502                re,
5503                pid_re: Default::default(),
5504                im,
5505                pid_im: Default::default(),
5506            }
5507            .into_expression()
5508        }
5509    }
5510
5511    #[typetag::serde]
5512    impl Amplitude for ComplexScalar {
5513        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
5514            self.pid_re = resources.register_parameter(&self.re)?;
5515            self.pid_im = resources.register_parameter(&self.im)?;
5516            resources.register_amplitude(&self.name)
5517        }
5518
5519        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
5520            Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
5521        }
5522
5523        fn compute_gradient(
5524            &self,
5525            parameters: &Parameters,
5526            _cache: &Cache,
5527            gradient: &mut DVector<Complex64>,
5528        ) {
5529            if let Some(ind) = parameters.free_index(self.pid_re) {
5530                gradient[ind] = Complex64::ONE;
5531            }
5532            if let Some(ind) = parameters.free_index(self.pid_im) {
5533                gradient[ind] = Complex64::I;
5534            }
5535        }
5536    }
5537
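    // Test helper: a real-valued, parameter-only scalar. It overrides
    // `dependence_hint` and `real_valued_hint` so IR planning can treat it as
    // the parameter factor of a separable term.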
5538    #[derive(Clone, Serialize, Deserialize)]
5539    pub struct ParameterOnlyScalar {
5540        name: String,
5541        value: Parameter,
5542        pid: ParameterID,
5543    }
5544
5545    impl ParameterOnlyScalar {
5546        #[allow(clippy::new_ret_no_self)]
5547        pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
5548            Self {
5549                name: name.to_string(),
5550                value,
5551                pid: Default::default(),
5552            }
5553            .into_expression()
5554        }
5555    }
5556
5557    #[typetag::serde]
5558    impl Amplitude for ParameterOnlyScalar {
5559        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
5560            self.pid = resources.register_parameter(&self.value)?;
5561            resources.register_amplitude(&self.name)
5562        }
5563
5564        fn dependence_hint(&self) -> ExpressionDependence {
5565            ExpressionDependence::ParameterOnly
5566        }
5567
5568        fn real_valued_hint(&self) -> bool {
5569            true
5570        }
5571
5572        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
5573            Complex64::new(parameters.get(self.pid), 0.0)
5574        }
5575    }
5576
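    // Test helper: a real-valued, cache-only scalar that precomputes the beam
    // energy for each event and reads it back at evaluation time.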
5577    #[derive(Clone, Serialize, Deserialize)]
5578    pub struct CacheOnlyScalar {
5579        name: String,
5580        beam_energy: ScalarID,
5581    }
5582
5583    impl CacheOnlyScalar {
5584        #[allow(clippy::new_ret_no_self)]
5585        pub fn new(name: &str) -> LadduResult<Expression> {
5586            Self {
5587                name: name.to_string(),
5588                beam_energy: Default::default(),
5589            }
5590            .into_expression()
5591        }
5592    }
5593
5594    #[typetag::serde]
5595    impl Amplitude for CacheOnlyScalar {
5596        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
5597            self.beam_energy =
5598                resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
5599            resources.register_amplitude(&self.name)
5600        }
5601
5602        fn dependence_hint(&self) -> ExpressionDependence {
5603            ExpressionDependence::CacheOnly
5604        }
5605
5606        fn real_valued_hint(&self) -> bool {
5607            true
5608        }
5609
5610        fn precompute(&self, event: &NamedEventView<'_>, cache: &mut Cache) {
5611            cache.store_scalar(self.beam_energy, event.p4_at(0).e());
5612        }
5613
5614        fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
5615            Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
5616        }
5617    }
5618
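    // Deterministic fixtures covering the three normalization shapes exercised
    // below: fully separable (parameter · cache terms only), partially
    // separable (one separable term plus a mixed residual amplitude), and
    // non-separable (a product of mixed amplitudes).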
5619    #[derive(Clone, Copy)]
5620    enum DeterministicFixtureKind {
5621        Separable,
5622        Partial,
5623        NonSeparable,
5624    }
5625
5626    struct DeterministicFixture {
5627        expression: Expression,
5628        dataset: Arc<Dataset>,
5629        parameters: Vec<f64>,
5630    }
5631
5632    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
5633    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
5634
5635    fn deterministic_fixture_dataset() -> Arc<Dataset> {
5636        let metadata = Arc::new(DatasetMetadata::default());
5637        let events = vec![
5638            Arc::new(EventData {
5639                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
5640                aux: vec![],
5641                weight: 0.5,
5642            }),
5643            Arc::new(EventData {
5644                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
5645                aux: vec![],
5646                weight: -1.25,
5647            }),
5648            Arc::new(EventData {
5649                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
5650                aux: vec![],
5651                weight: 2.0,
5652            }),
5653            Arc::new(EventData {
5654                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
5655                aux: vec![],
5656                weight: 0.75,
5657            }),
5658        ];
5659        Arc::new(Dataset::new_with_metadata(events, metadata))
5660    }
5661
5662    fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
5663        let dataset = deterministic_fixture_dataset();
5664        match kind {
5665            DeterministicFixtureKind::Separable => {
5666                let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
5667                    .expect("separable p1 should build");
5668                let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
5669                    .expect("separable p2 should build");
5670                let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
5671                let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
5672                DeterministicFixture {
5673                    expression: (&p1 * &c1) + &(&p2 * &c2),
5674                    dataset,
5675                    parameters: vec![0.4, -0.3],
5676                }
5677            }
5678            DeterministicFixtureKind::Partial => {
5679                let p =
5680                    ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
5681                let c = CacheOnlyScalar::new("c").expect("partial c should build");
5682                let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
5683                    .expect("partial m should build");
5684                DeterministicFixture {
5685                    expression: (&p * &c) + &m,
5686                    dataset,
5687                    parameters: vec![0.55, 0.2, -0.15],
5688                }
5689            }
5690            DeterministicFixtureKind::NonSeparable => {
5691                let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
5692                    .expect("non-separable m1 should build");
5693                let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
5694                    .expect("non-separable m2 should build");
5695                DeterministicFixture {
5696                    expression: &m1 * &m2,
5697                    dataset,
5698                    parameters: vec![0.25, -0.4, 0.6, 0.1],
5699                }
5700            }
5701        }
5702    }
5703
5704    fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
5705        let evaluator = fixture
5706            .expression
5707            .load(&fixture.dataset)
5708            .expect("fixture evaluator should load");
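        // Event-wise baseline: accumulate Σ_e w_e · Re(value_e) and
        // Σ_e w_e · Re(∇value_e) from the per-event evaluation paths, then
        // compare against the fused weighted-sum entry points below.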
5709        let expected_value = evaluator
5710            .evaluate_local(&fixture.parameters)
5711            .expect("evaluation should succeed")
5712            .iter()
5713            .zip(fixture.dataset.events_local().iter())
5714            .fold(0.0, |accum, (value, event)| {
5715                accum + event.weight() * value.re
5716            });
5717        let expected_gradient = evaluator
5718            .evaluate_gradient_local(&fixture.parameters)
5719            .expect("evaluation should succeed")
5720            .iter()
5721            .zip(fixture.dataset.events_local().iter())
5722            .fold(
5723                DVector::zeros(fixture.parameters.len()),
5724                |mut accum, (gradient, event)| {
5725                    accum += gradient.map(|value| value.re).scale(event.weight());
5726                    accum
5727                },
5728            );
5729        let actual_value = evaluator
5730            .evaluate_weighted_value_sum_local(&fixture.parameters)
5731            .expect("evaluation should succeed");
5732        let actual_gradient = evaluator
5733            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
5734            .expect("evaluation should succeed");
5735        assert_relative_eq!(
5736            actual_value,
5737            expected_value,
5738            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5739            max_relative = DETERMINISTIC_STRICT_REL_TOL
5740        );
5741        for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
5742            assert_relative_eq!(
5743                *actual_item,
5744                *expected_item,
5745                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5746                max_relative = DETERMINISTIC_STRICT_REL_TOL
5747            );
5748        }
5749    }
5750    fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
5751        let evaluator = fixture
5752            .expression
5753            .load(&fixture.dataset)
5754            .expect("fixture evaluator should load");
5755        let state = {
5756            let resources = evaluator.resources.read();
5757            evaluator.ensure_cached_integral_cache_state(&resources)
5758        }
5759        .expect("state should be available");
5760        assert!(
5761            !state.values.is_empty(),
5762            "fixture should exercise cached normalization terms"
5763        );
5764        assert!(
5765            !state.execution_sets.residual_amplitudes.is_empty(),
5766            "fixture should exercise residual normalization amplitudes"
5767        );
5768
5769        let (residual_value_sum, cached_value_sum) = evaluator
5770            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
5771            .expect("evaluation should succeed");
5772        assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
5773        assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
5774        let combined_value = evaluator
5775            .evaluate_weighted_value_sum_local(&fixture.parameters)
5776            .expect("evaluation should succeed");
5777        assert_relative_eq!(
5778            residual_value_sum + cached_value_sum,
5779            combined_value,
5780            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5781            max_relative = DETERMINISTIC_STRICT_REL_TOL
5782        );
5783
5784        let (residual_gradient_sum, cached_gradient_sum) = evaluator
5785            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
5786            .expect("evaluation should succeed");
5787        let combined_gradient = evaluator
5788            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
5789            .expect("evaluation should succeed");
5790        assert!(residual_gradient_sum
5791            .iter()
5792            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
5793        assert!(cached_gradient_sum
5794            .iter()
5795            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
5796        for ((residual_item, cached_item), combined_item) in residual_gradient_sum
5797            .iter()
5798            .zip(cached_gradient_sum.iter())
5799            .zip(combined_gradient.iter())
5800        {
5801            assert_relative_eq!(
5802                residual_item + cached_item,
5803                *combined_item,
5804                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5805                max_relative = DETERMINISTIC_STRICT_REL_TOL
5806            );
5807        }
5808    }
5809
5810    #[test]
5811    fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
5812        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5813        let evaluator = fixture
5814            .expression
5815            .load(&fixture.dataset)
5816            .expect("fixture evaluator should load");
5817        let original_mask = evaluator.active_mask();
5818
5819        let original_value = evaluator
5820            .evaluate_weighted_value_sum_local(&fixture.parameters)
5821            .expect("evaluation should succeed");
5822
5823        evaluator.isolate_many(&["p", "c"]);
5824        assert_ne!(evaluator.active_mask(), original_mask);
5825
5826        evaluator
5827            .set_active_mask(&original_mask)
5828            .expect("original fixture active mask should restore");
5829        assert_eq!(evaluator.active_mask(), original_mask);
5830        let actual_value = evaluator
5831            .evaluate_weighted_value_sum_local(&fixture.parameters)
5832            .expect("evaluation should succeed");
5833        assert_relative_eq!(
5834            actual_value,
5835            original_value,
5836            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5837            max_relative = DETERMINISTIC_STRICT_REL_TOL
5838        );
5839    }
5840
5841    #[test]
5842    fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
5843        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
5844        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5845        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
5846
5847        assert_weighted_sum_matches_eventwise_baseline(&separable);
5848        assert_weighted_sum_matches_eventwise_baseline(&partial);
5849        assert_weighted_sum_matches_eventwise_baseline(&non_separable);
5850    }
5851    #[test]
5852    fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
5853        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
5854        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5855        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
5856
5857        let separable_evaluator = separable
5858            .expression
5859            .load(&separable.dataset)
5860            .expect("separable evaluator should load");
5861        let partial_evaluator = partial
5862            .expression
5863            .load(&partial.dataset)
5864            .expect("partial evaluator should load");
5865        let non_separable_evaluator = non_separable
5866            .expression
5867            .load(&non_separable.dataset)
5868            .expect("non-separable evaluator should load");
5869
5870        assert_eq!(
5871            separable_evaluator
5872                .expression_precomputed_cached_integrals()
5873                .expect("integrals should be computed")
5874                .len(),
5875            2
5876        );
5877        assert_eq!(
5878            partial_evaluator
5879                .expression_precomputed_cached_integrals()
5880                .expect("integrals should be computed")
5881                .len(),
5882            1
5883        );
5884        assert!(non_separable_evaluator
5885            .expression_precomputed_cached_integrals()
5886            .expect("integrals should be computed")
5887            .is_empty());
5888    }
5889    #[test]
5890    fn test_partial_fixture_combined_normalization_components_match_total() {
5891        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5892        assert_mixed_normalization_components_match_combined_path(&partial);
5893    }
5894    #[test]
5895    fn test_non_separable_fixture_normalization_components_stay_residual_only() {
5896        let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
5897        let evaluator = fixture
5898            .expression
5899            .load(&fixture.dataset)
5900            .expect("fixture evaluator should load");
5901        let resources = evaluator.resources.read();
5902        let state = evaluator
5903            .ensure_cached_integral_cache_state(&resources)
5904            .expect("state should be available");
5905        assert!(state.values.is_empty());
5906
5907        let (residual_value_sum, cached_value_sum) = evaluator
5908            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
5909            .expect("evaluation should succeed");
5910        assert_relative_eq!(
5911            cached_value_sum,
5912            0.0,
5913            epsilon = DETERMINISTIC_STRICT_ABS_TOL
5914        );
5915        assert_relative_eq!(
5916            residual_value_sum,
5917            evaluator
5918                .evaluate_weighted_value_sum_local(&fixture.parameters)
5919                .expect("evaluation should succeed"),
5920            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5921            max_relative = DETERMINISTIC_STRICT_REL_TOL
5922        );
5923
5924        let (residual_gradient_sum, cached_gradient_sum) = evaluator
5925            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
5926            .expect("evaluation should succeed");
5927        assert!(cached_gradient_sum
5928            .iter()
5929            .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
5930        let combined_gradient = evaluator
5931            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
5932            .expect("evaluation should succeed");
5933        for (residual_item, combined_item) in
5934            residual_gradient_sum.iter().zip(combined_gradient.iter())
5935        {
5936            assert_relative_eq!(
5937                *residual_item,
5938                *combined_item,
5939                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5940                max_relative = DETERMINISTIC_STRICT_REL_TOL
5941            );
5942        }
5943    }
5944
5945    #[test]
5946    fn test_batch_evaluation() {
5947        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
5948        let mut event1 = test_event();
5949        event1.p4s[0].t = 10.0;
5950        let mut event2 = test_event();
5951        event2.p4s[0].t = 11.0;
5952        let mut event3 = test_event();
5953        event3.p4s[0].t = 12.0;
5954        let dataset = Arc::new(Dataset::new_with_metadata(
5955            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5956            Arc::new(DatasetMetadata::default()),
5957        ));
5958        let evaluator = expr.load(&dataset).unwrap();
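        // TestAmplitude evaluates to (re + i·im) · beam_energy, so with
        // parameters (1.1, 2.2) the selected events (beam energies 10 and 12)
        // should yield (1.1 + 2.2i) · 10 and (1.1 + 2.2i) · 12.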
5959        let result = evaluator
5960            .evaluate_batch(&[1.1, 2.2], &[0, 2])
5961            .expect("evaluation should succeed");
5962        assert_eq!(result.len(), 2);
5963        assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
5964        assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
5965        let result_grad = evaluator
5966            .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
5967            .expect("evaluation should succeed");
5968        assert_eq!(result_grad.len(), 2);
5969        assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
5970        assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
5971        assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
5972        assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
5973    }
5974
5975    #[test]
5976    fn test_load_compiles_expression_ir_once() {
5977        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5978            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5979        .norm_sqr();
5980        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5981        let evaluator = expr.load(&dataset).unwrap();
5982        assert!(evaluator.expression_slot_count() > 0);
5983    }
5984    #[test]
5985    fn test_expression_ir_value_matches_lowered_runtime() {
5986        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5987            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5988            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
5989        .conj()
5990        .norm_sqr();
5991        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5992        let evaluator = expr.load(&dataset).unwrap();
5993        let resources = evaluator.resources.read();
5994        let parameters = resources
5995            .parameter_map
5996            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
5997            .expect("parameters should assemble");
5998        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5999        evaluator.fill_amplitude_values(
6000            &mut amplitude_values,
6001            resources.active_indices(),
6002            &parameters,
6003            &resources.caches[0],
6004        );
6005        let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
6006        let lowered_runtime = evaluator.lowered_runtime();
6007        let lowered_program = lowered_runtime.value_program();
6008        let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
6009        let lowered_value =
6010            evaluator.evaluate_expression_value_with_scratch(&amplitude_values, &mut ir_slots);
6011        let direct_lowered_value =
6012            lowered_program.evaluate_into(&amplitude_values, &mut lowered_slots);
6013        let ir_value = evaluator
6014            .expression_ir()
6015            .evaluate_into(&amplitude_values, &mut ir_slots);
6016        assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
6017        assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
6018        assert_relative_eq!(lowered_value.re, ir_value.re);
6019        assert_relative_eq!(lowered_value.im, ir_value.im);
6020    }
6021    #[test]
6022    fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
6023        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
6024            .unwrap()
6025            .norm_sqr();
6026        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6027        let evaluator = expr.load(&dataset).unwrap();
6028        let lowered_runtime = evaluator.lowered_runtime();
6029        assert_eq!(
6030            lowered_runtime.value_program().kind(),
6031            lowered::LoweredProgramKind::Value
6032        );
6033        assert_eq!(
6034            lowered_runtime.gradient_program().kind(),
6035            lowered::LoweredProgramKind::Gradient
6036        );
6037        assert_eq!(
6038            lowered_runtime.value_gradient_program().kind(),
6039            lowered::LoweredProgramKind::ValueGradient
6040        );
6041    }
6042    #[test]
6043    fn test_expression_ir_gradient_matches_lowered_runtime() {
6044        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
6045            * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
6046        .norm_sqr();
6047        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6048        let evaluator = expr.load(&dataset).unwrap();
6049        let resources = evaluator.resources.read();
6050        let parameters = resources
6051            .parameter_map
6052            .assemble(&[1.0, 0.25, -0.8, 0.5])
6053            .expect("parameters should assemble");
6054        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6055        evaluator.fill_amplitude_values(
6056            &mut amplitude_values,
6057            resources.active_indices(),
6058            &parameters,
6059            &resources.caches[0],
6060        );
6061        let mut active_mask = vec![false; evaluator.amplitudes.len()];
6062        for &index in resources.active_indices() {
6063            active_mask[index] = true;
6064        }
6065        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6066            .map(|_| DVector::zeros(parameters.len()))
6067            .collect::<Vec<_>>();
6068        evaluator.fill_amplitude_gradients(
6069            &mut amplitude_gradients,
6070            &active_mask,
6071            &parameters,
6072            &resources.caches[0],
6073        );
6074        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
6075        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
6076            (0..evaluator.expression_ir().node_count())
6077                .map(|_| DVector::zeros(parameters.len()))
6078                .collect();
6079        let lowered_runtime = evaluator.lowered_runtime();
6080        let lowered_program = lowered_runtime.gradient_program();
6081        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
6082        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
6083            .scratch_slots())
6084            .map(|_| DVector::zeros(parameters.len()))
6085            .collect();
6086        let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
6087            &amplitude_values,
6088            &amplitude_gradients,
6089            &mut ir_value_slots,
6090            &mut ir_gradient_slots,
6091        );
6092        let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
6093            &amplitude_values,
6094            &amplitude_gradients,
6095            &mut ir_value_slots,
6096            &mut ir_gradient_slots,
6097        );
6098        let lowered_gradient = lowered_program.evaluate_gradient_into(
6099            &amplitude_values,
6100            &amplitude_gradients,
6101            &mut lowered_value_slots,
6102            &mut lowered_gradient_slots,
6103        );
6104        for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
6105            assert_relative_eq!(active.re, lowered.re);
6106            assert_relative_eq!(active.im, lowered.im);
6107        }
6108        for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
6109            assert_relative_eq!(lowered.re, ir.re);
6110            assert_relative_eq!(lowered.im, ir.im);
6111        }
6112    }
6113    #[test]
6114    fn test_expression_ir_value_gradient_matches_lowered_runtime() {
6115        let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
6116            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
6117            * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
6118        .norm_sqr();
6119        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6120        let evaluator = expr.load(&dataset).unwrap();
6121        let resources = evaluator.resources.read();
6122        let parameters = resources
6123            .parameter_map
6124            .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
6125            .expect("parameters should assemble");
6126        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6127        evaluator.fill_amplitude_values(
6128            &mut amplitude_values,
6129            resources.active_indices(),
6130            &parameters,
6131            &resources.caches[0],
6132        );
6133        let mut active_mask = vec![false; evaluator.amplitudes.len()];
6134        for &index in resources.active_indices() {
6135            active_mask[index] = true;
6136        }
6137        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6138            .map(|_| DVector::zeros(parameters.len()))
6139            .collect::<Vec<_>>();
6140        evaluator.fill_amplitude_gradients(
6141            &mut amplitude_gradients,
6142            &active_mask,
6143            &parameters,
6144            &resources.caches[0],
6145        );
6146        let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
6147        let mut ir_gradient_slots: Vec<DVector<Complex64>> =
6148            (0..evaluator.expression_ir().node_count())
6149                .map(|_| DVector::zeros(parameters.len()))
6150                .collect();
6151        let lowered_runtime = evaluator.lowered_runtime();
6152        let lowered_program = lowered_runtime.value_gradient_program();
6153        let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
6154        let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
6155            .scratch_slots())
6156            .map(|_| DVector::zeros(parameters.len()))
6157            .collect();
6158
6159        let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
6160            &amplitude_values,
6161            &amplitude_gradients,
6162            &mut ir_value_slots,
6163            &mut ir_gradient_slots,
6164        );
6165        let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
6166            &amplitude_values,
6167            &amplitude_gradients,
6168            &mut ir_value_slots,
6169            &mut ir_gradient_slots,
6170        );
6171        let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
6172            &amplitude_values,
6173            &amplitude_gradients,
6174            &mut lowered_value_slots,
6175            &mut lowered_gradient_slots,
6176        );
6177
6178        assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
6179        assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
6180        for (active, lowered) in active_value_gradient
6181            .1
6182            .iter()
6183            .zip(lowered_value_gradient.1.iter())
6184        {
6185            assert_relative_eq!(active.re, lowered.re);
6186            assert_relative_eq!(active.im, lowered.im);
6187        }
6188        assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
6189        assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
6190        for (lowered, ir) in lowered_value_gradient
6191            .1
6192            .iter()
6193            .zip(ir_value_gradient.1.iter())
6194        {
6195            assert_relative_eq!(lowered.re, ir.re);
6196            assert_relative_eq!(lowered.im, ir.im);
6197        }
6198    }
6199    #[test]
6200    fn test_expression_runtime_diagnostics_reports_lowered_programs() {
6201        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
6202        let evaluator = fixture
6203            .expression
6204            .load(&fixture.dataset)
6205            .expect("fixture evaluator should load");
6206
6207        let diagnostics = evaluator.expression_runtime_diagnostics();
6208        assert!(diagnostics.ir_planning_enabled);
6209        assert!(diagnostics.lowered_value_program_present);
6210        assert!(diagnostics.lowered_gradient_program_present);
6211        assert!(diagnostics.lowered_value_gradient_program_present);
6212        assert!(diagnostics.residual_runtime_present);
6213        assert_eq!(
6214            diagnostics.specialization_status,
6215            Some(ExpressionSpecializationStatus {
6216                origin: ExpressionSpecializationOrigin::InitialLoad,
6217            })
6218        );
6219    }
6220    #[test]
6221    fn test_expression_runtime_diagnostics_reports_specialization_origin() {
6222        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
6223        let evaluator = fixture
6224            .expression
6225            .load(&fixture.dataset)
6226            .expect("fixture evaluator should load");
6227
6228        assert_eq!(
6229            evaluator
6230                .expression_runtime_diagnostics()
6231                .specialization_status,
6232            Some(ExpressionSpecializationStatus {
6233                origin: ExpressionSpecializationOrigin::InitialLoad,
6234            })
6235        );
6236
6237        evaluator.isolate_many(&["p"]);
6238        assert_eq!(
6239            evaluator
6240                .expression_runtime_diagnostics()
6241                .specialization_status,
6242            Some(ExpressionSpecializationStatus {
6243                origin: ExpressionSpecializationOrigin::CacheMissRebuild,
6244            })
6245        );
6246    }
6247    #[test]
6248    fn test_compiled_expression_display_reports_dag_refs() {
6249        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
6250        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
6251        let term = &a * &b;
6252        let expr = &term + &term;
6253        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6254        let evaluator = expr.load(&dataset).unwrap();
6255
6256        let compiled = evaluator.compiled_expression();
6257        let display = compiled.to_string();
6258
6259        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
6260        assert!(display.contains("#"));
6261        assert!(display.contains("+"));
6262        assert!(display.contains("×"));
6263        assert!(display.contains("a(id=0)"));
6264        assert!(display.contains("b(id=1)"));
6265        assert!(display.contains("(ref)"));
6266    }
6267
6268    #[test]
6269    fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
6270        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
6271        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
6272        let term = &a * &b;
6273        let expr = &term + &term;
6274
6275        let compiled = expr.compiled_expression();
6276        let display = compiled.to_string();
6277
6278        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
6279        assert!(display.contains("#"));
6280        assert!(display.contains("+"));
6281        assert!(display.contains("×"));
6282        assert!(display.contains("a(id=0)"));
6283        assert!(display.contains("b(id=1)"));
6284        assert!(display.contains("(ref)"));
6285    }
6286
6287    #[test]
6288    fn test_compiled_expression_display_uses_current_active_mask() {
6289        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
6290            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
6291        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6292        let evaluator = expr.load(&dataset).unwrap();
6293        evaluator.deactivate("b");
6294
6295        let compiled = evaluator.compiled_expression().to_string();
6296
6297        assert!(compiled.contains("a(id=0)"));
6298        assert!(!compiled.contains("b(id=1)"));
6299        assert!(compiled.contains("const 0"));
6300    }
6301
6302    #[test]
6303    fn test_evaluator_expression_reconstructs_expression() {
6304        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
6305        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6306        let evaluator = expr.load(&dataset).unwrap();
6307
6308        assert_eq!(
6309            evaluator.expression().compiled_expression(),
6310            expr.compiled_expression()
6311        );
6312    }
6313
6314    #[test]
6315    fn test_active_mask_override_ignores_current_ir_specialization() {
6316        let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
6317            .unwrap()
6318            .norm_sqr();
6319        let dataset = Arc::new(test_dataset());
6320        let evaluator = expr.load(&dataset).unwrap();
6321        let params = vec![2.0];
6322
6323        evaluator.deactivate("amp");
6324        assert_eq!(
6325            evaluator
6326                .evaluate(&params)
6327                .expect("evaluation should succeed")[0],
6328            Complex64::new(0.0, 0.0)
6329        );
6330
6331        let overridden = evaluator
6332            .evaluate_local_with_active_mask(&params, &[true])
6333            .unwrap();
6334        assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
6335
6336        let overridden_fused = evaluator
6337            .evaluate_with_gradient_local_with_active_mask(&params, &[true])
6338            .unwrap();
6339        assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
6340        assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
6341    }
6342    #[test]
6343    fn test_expression_ir_dependence_diagnostics_surface() {
6344        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
6345            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
6346        .norm_sqr();
6347        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6348        let evaluator = expr.load(&dataset).unwrap();
6349        let annotations = evaluator
6350            .expression_node_dependence_annotations()
6351            .expect("annotations should exist");
6352        assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
6353        assert!(annotations
6354            .iter()
6355            .all(|dependence| *dependence == ExpressionDependence::Mixed));
6356        assert_eq!(
6357            evaluator
6358                .expression_root_dependence()
6359                .expect("root dependence should exist"),
6360            ExpressionDependence::Mixed
6361        );
6362    }
6363    #[test]
6364    fn test_expression_ir_default_dependence_hint_is_mixed() {
6365        let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
6366        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6367        let evaluator = expr.load(&dataset).unwrap();
6368        assert_eq!(
6369            evaluator
6370                .expression_root_dependence()
6371                .expect("root dependence should exist"),
6372            ExpressionDependence::Mixed
6373        );
6374    }
6375    #[test]
6376    fn test_expression_ir_parameter_only_dependence_hint_propagates() {
6377        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
6378        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6379        let evaluator = expr.load(&dataset).unwrap();
6380        assert_eq!(
6381            evaluator
6382                .expression_root_dependence()
6383                .expect("root dependence should exist"),
6384            ExpressionDependence::ParameterOnly
6385        );
6386    }
6387    #[test]
6388    fn test_expression_ir_cache_only_dependence_hint_propagates() {
6389        let expr = CacheOnlyScalar::new("k").unwrap();
6390        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6391        let evaluator = expr.load(&dataset).unwrap();
6392        assert_eq!(
6393            evaluator
6394                .expression_root_dependence()
6395                .expect("root dependence should exist"),
6396            ExpressionDependence::CacheOnly
6397        );
6398    }
6399    #[test]
6400    fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
6401        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
6402            .unwrap()
6403            .imag();
6404        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6405        let evaluator = expr.load(&dataset).unwrap();
6406        let ir = evaluator.expression_ir();
6407
6408        assert!(matches!(
6409            ir.nodes()[ir.root()],
6410            ir::IrNode::Constant(value) if value == Complex64::ZERO
6411        ));
6412        assert_eq!(
6413            evaluator
6414                .evaluate(&[2.5])
6415                .expect("evaluation should succeed")[0],
6416            Complex64::ZERO
6417        );
6418    }
6419    #[test]
6420    fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
6421        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
6422            .unwrap()
6423            .conj();
6424        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6425        let evaluator = expr.load(&dataset).unwrap();
6426        let ir = evaluator.expression_ir();
6427
6428        assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
6429        assert_eq!(
6430            evaluator
6431                .evaluate(&[2.5])
6432                .expect("evaluation should succeed")[0],
6433            Complex64::new(2.5, 0.0)
6434        );
6435    }
6436    #[test]
6437    fn test_expression_ir_dependence_warnings_surface() {
6438        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6439            + &CacheOnlyScalar::new("k").unwrap();
6440        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6441        let evaluator = expr.load(&dataset).unwrap();
6442        assert!(evaluator
6443            .expression_dependence_warnings()
6444            .expect("warnings should exist")
6445            .iter()
6446            .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
6447    }
6448    #[test]
6449    fn test_expression_ir_normalization_plan_explain_surface() {
6450        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6451            * &CacheOnlyScalar::new("k").unwrap();
6452        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6453        let evaluator = expr.load(&dataset).unwrap();
6454        let explain = evaluator
6455            .expression_normalization_plan_explain()
6456            .expect("plan should exist");
6457        assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
6458        assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
6459        assert_eq!(
6460            explain.cached_separable_nodes,
6461            explain.separable_mul_candidate_nodes
6462        );
6463        assert!(explain.residual_terms.iter().all(|index| {
6464            !explain
6465                .separable_mul_candidate_nodes
6466                .iter()
6467                .any(|candidate| candidate == index)
6468        }));
6469    }
6470    #[test]
6471    fn test_expression_ir_normalization_execution_sets_surface() {
6472        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6473            * &CacheOnlyScalar::new("k").unwrap();
6474        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6475        let evaluator = expr.load(&dataset).unwrap();
6476        let sets = evaluator
6477            .expression_normalization_execution_sets()
6478            .expect("sets should exist");
6479        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
6480        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
6481        assert!(sets.residual_amplitudes.is_empty());
6482    }
6483    #[test]
6484    fn test_expression_ir_normalization_execution_sets_partial_surface() {
6485        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6486            * &CacheOnlyScalar::new("k").unwrap())
6487            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6488        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6489        let evaluator = expr.load(&dataset).unwrap();
6490        let sets = evaluator
6491            .expression_normalization_execution_sets()
6492            .expect("sets should exist");
6493        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
6494        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
6495        assert_eq!(sets.residual_amplitudes, vec![2]);
6496    }
6497    #[test]
6498    fn test_expression_ir_precomputed_cached_integrals_at_load() {
6499        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6500            * &CacheOnlyScalar::new("k").unwrap();
6501        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6502        let evaluator = expr.load(&dataset).unwrap();
6503        let precomputed = evaluator
6504            .expression_precomputed_cached_integrals()
6505            .expect("integrals should exist");
6506        assert_eq!(precomputed.len(), 1);
6507        let cache_reference = CacheOnlyScalar::new("k_ref")
6508            .unwrap()
6509            .load(&dataset)
6510            .unwrap();
6511        let cache_values = cache_reference
6512            .evaluate_local(&[])
6513            .expect("evaluation should succeed");
6514        let expected_weighted_sum = cache_values
6515            .iter()
6516            .zip(dataset.events_local().iter())
6517            .fold(Complex64::ZERO, |acc, (value, event)| {
6518                acc + (*value * event.weight())
6519            });
6520        assert_relative_eq!(
6521            precomputed[0].weighted_cache_sum.re,
6522            expected_weighted_sum.re
6523        );
6524        assert_relative_eq!(
6525            precomputed[0].weighted_cache_sum.im,
6526            expected_weighted_sum.im
6527        );
6528    }
6529    #[test]
6530    fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
6531        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
6532            * &CacheOnlyScalar::new("k").unwrap();
6533        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6534        let evaluator = expr.load(&dataset).unwrap();
6535        assert!(evaluator
6536            .expression_precomputed_cached_integrals()
6537            .expect("integrals should exist")
6538            .is_empty());
6539    }
6540    #[test]
6541    fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
6542        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6543            * &CacheOnlyScalar::new("k").unwrap();
6544        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6545        let evaluator = expr.load(&dataset).unwrap();
6546        assert_eq!(
6547            evaluator
6548                .expression_precomputed_cached_integrals()
6549                .expect("integrals should exist")
6550                .len(),
6551            1
6552        );
6553
6554        evaluator.isolate_many(&["p"]);
6555        assert!(evaluator
6556            .expression_precomputed_cached_integrals()
6557            .expect("integrals should exist")
6558            .is_empty());
6559    }
6560    #[test]
6561    fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
6562        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6563            * &CacheOnlyScalar::new("k").unwrap();
6564        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6565        let mut evaluator = expr.load(&dataset).unwrap();
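        // Drop the local handle so the evaluator owns the only `Arc` to the
        // dataset and `Arc::get_mut` below can obtain mutable access.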
6566        drop(dataset);
6567        let before = evaluator
6568            .expression_precomputed_cached_integrals()
6569            .expect("integrals should exist");
6570        assert_eq!(before.len(), 1);
6571
6572        Arc::get_mut(&mut evaluator.dataset)
6573            .expect("evaluator should own dataset Arc in this test")
6574            .clear_events_local();
6575        let after = evaluator
6576            .expression_precomputed_cached_integrals()
6577            .expect("integrals should exist");
6578        assert_eq!(after.len(), 1);
6579        assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
6580        assert_ne!(before[0].weighted_cache_sum, after[0].weighted_cache_sum);
6581    }
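    // For the separable `p * k` term the parameter-side gradient only rescales the cache
    // integral; here each gradient component should equal the cache-side weighted sum exactly.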
6582    #[test]
6583    fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
6584        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6585            * &CacheOnlyScalar::new("k").unwrap();
6586        let dataset = Arc::new(Dataset::new(vec![
6587            Arc::new(test_event()),
6588            Arc::new(test_event()),
6589        ]));
6590        let evaluator = expr.load(&dataset).unwrap();
6591        let cached_integrals = evaluator
6592            .expression_precomputed_cached_integrals()
6593            .expect("integrals should exist");
6594        assert_eq!(cached_integrals.len(), 1);
6595        let gradient_terms = evaluator
6596            .expression_precomputed_cached_integral_gradient_terms(&[1.25])
6597            .expect("evaluation should succeed");
6598        assert_eq!(gradient_terms.len(), 1);
6599        assert_eq!(gradient_terms[0].weighted_gradient.len(), 1);
6600        assert_relative_eq!(
6601            gradient_terms[0].weighted_gradient[0].re,
6602            cached_integrals[0].weighted_cache_sum.re,
6603            epsilon = 1e-6
6604        );
6605        assert_relative_eq!(
6606            gradient_terms[0].weighted_gradient[0].im,
6607            cached_integrals[0].weighted_cache_sum.im,
6608            epsilon = 1e-6
6609        );
6610    }
6611    #[test]
6612    fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
6613        let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
6614            * &CacheOnlyScalar::new("k").unwrap();
6615        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6616        let evaluator = expr.load(&dataset).unwrap();
6617        assert!(evaluator
6618            .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
6619            .expect("evaluation should succeed")
6620            .is_empty());
6621    }
6622    #[test]
6623    fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
6624        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6625            * &CacheOnlyScalar::new("k").unwrap())
6626            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6627        let dataset = Arc::new(test_dataset());
6628        let evaluator = expr.load(&dataset).unwrap();
6629        let resources = evaluator.resources.read();
6630        let state = evaluator
6631            .ensure_cached_integral_cache_state(&resources)
6632            .expect("state should be available");
6633        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
6634        let parameters = resources
6635            .parameter_map
6636            .assemble(&[0.55, 0.2, -0.15])
6637            .expect("parameters should assemble");
6638
6639        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6640        evaluator.fill_amplitude_values(
6641            &mut amplitude_values,
6642            &state.execution_sets.cached_parameter_amplitudes,
6643            &parameters,
6644            &resources.caches[0],
6645        );
6646        let cached_value_ir =
6647            evaluator.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values);
6648        let cached_value_lowered = evaluator
6649            .evaluate_cached_weighted_value_sum_lowered(
6650                &state,
6651                lowered_artifacts.as_ref(),
6652                &amplitude_values,
6653            )
6654            .expect("cached value lowering should succeed");
6655        assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);
6656
6657        let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
6658        for &index in &state.execution_sets.cached_parameter_amplitudes {
6659            cached_parameter_mask[index] = true;
6660        }
6661        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6662            .map(|_| DVector::zeros(parameters.len()))
6663            .collect::<Vec<_>>();
6664        evaluator.fill_amplitude_gradients(
6665            &mut amplitude_gradients,
6666            &cached_parameter_mask,
6667            &parameters,
6668            &resources.caches[0],
6669        );
6670        let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
6671            &state,
6672            &amplitude_values,
6673            &amplitude_gradients,
6674            parameters.len(),
6675        );
6676        let cached_gradient_lowered = evaluator
6677            .evaluate_cached_weighted_gradient_sum_lowered(
6678                &state,
6679                lowered_artifacts.as_ref(),
6680                &amplitude_values,
6681                &amplitude_gradients,
6682                parameters.len(),
6683            )
6684            .expect("cached gradient lowering should succeed");
6685        for (lowered, ir) in cached_gradient_lowered
6686            .iter()
6687            .zip(cached_gradient_ir.iter())
6688        {
6689            assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
6690        }
6691    }
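    // The lowered residual runtime's value and gradient programs must agree with the IR
    // residual path when fed the same amplitude values and gradients.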
6692    #[test]
6693    fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
6694        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6695            * &CacheOnlyScalar::new("k").unwrap())
6696            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6697        let dataset = Arc::new(test_dataset());
6698        let evaluator = expr.load(&dataset).unwrap();
6699        let resources = evaluator.resources.read();
6700        let state = evaluator
6701            .ensure_cached_integral_cache_state(&resources)
6702            .expect("state should be available");
6703        let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
6704        let parameters = resources
6705            .parameter_map
6706            .assemble(&[0.55, 0.2, -0.15])
6707            .expect("parameters should assemble");
6708
6709        let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6710        evaluator.fill_amplitude_values(
6711            &mut amplitude_values,
6712            &state.execution_sets.residual_amplitudes,
6713            &parameters,
6714            &resources.caches[0],
6715        );
6716        let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &amplitude_values);
6717        let residual_program = lowered_artifacts
6718            .residual_runtime
6719            .as_ref()
6720            .map(|runtime| runtime.value_program())
6721            .expect("residual value lowering should succeed");
6722        let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
6723        let residual_value_lowered =
6724            residual_program.evaluate_into(&amplitude_values, &mut value_slots);
6725        assert_relative_eq!(
6726            residual_value_lowered.re,
6727            residual_value_ir.re,
6728            epsilon = 1e-12
6729        );
6730        assert_relative_eq!(
6731            residual_value_lowered.im,
6732            residual_value_ir.im,
6733            epsilon = 1e-12
6734        );
6735
6736        let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
6737        for &index in &state.execution_sets.residual_amplitudes {
6738            residual_active_mask[index] = true;
6739        }
6740        let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6741            .map(|_| DVector::zeros(parameters.len()))
6742            .collect::<Vec<_>>();
6743        evaluator.fill_amplitude_gradients(
6744            &mut amplitude_gradients,
6745            &residual_active_mask,
6746            &parameters,
6747            &resources.caches[0],
6748        );
6749        let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
6750            &state,
6751            &amplitude_values,
6752            &amplitude_gradients,
6753            parameters.len(),
6754        );
6755
6756        let program = lowered_artifacts
6757            .residual_runtime
6758            .as_ref()
6759            .map(|runtime| runtime.gradient_program())
6760            .expect("gradient lowering should succeed");
6761        let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
6762        let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
6763        let residual_gradient_lowered = program.evaluate_gradient_into_flat(
6764            &amplitude_values,
6765            &amplitude_gradients,
6766            &mut value_slots,
6767            &mut gradient_slots,
6768            parameters.len(),
6769        );
6770
6771        for (lowered, ir) in residual_gradient_lowered
6772            .iter()
6773            .zip(residual_gradient_ir.iter())
6774        {
6775            assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
6776            assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
6777        }
6778    }
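    // Mutating the dataset invalidates the activation specialization (one new cache miss),
    // but the lowered artifacts are keyed independently and must be reused (lowering cache
    // hit, no additional lowering time recorded).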
6779    #[test]
6780    fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
6781        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6782            * &CacheOnlyScalar::new("k").unwrap())
6783            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6784        let dataset = Arc::new(test_dataset());
6785        let mut evaluator = expr.load(&dataset).unwrap();
6786        drop(dataset);
6787
6788        assert_eq!(evaluator.specialization_cache_len(), 1);
6789        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
6790
6791        evaluator.reset_expression_compile_metrics();
6792        evaluator.reset_expression_specialization_metrics();
6793
6794        Arc::get_mut(&mut evaluator.dataset)
6795            .expect("evaluator should own dataset Arc in this test")
6796            .clear_events_local();
6797
6798        let cached_integrals = evaluator
6799            .expression_precomputed_cached_integrals()
6800            .expect("integrals should exist");
6801        assert_eq!(cached_integrals.len(), 1);
6802        assert_eq!(cached_integrals[0].weighted_cache_sum, Complex64::ZERO);
6803
6804        assert_eq!(evaluator.specialization_cache_len(), 2);
6805        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
6806        assert_eq!(
6807            evaluator.expression_specialization_metrics(),
6808            ExpressionSpecializationMetrics {
6809                cache_hits: 0,
6810                cache_misses: 1,
6811            }
6812        );
6813
6814        let compile_metrics = evaluator.expression_compile_metrics();
6815        assert_eq!(compile_metrics.specialization_cache_hits, 0);
6816        assert_eq!(compile_metrics.specialization_cache_misses, 1);
6817        assert_eq!(compile_metrics.specialization_lowering_cache_hits, 1);
6818        assert_eq!(compile_metrics.specialization_lowering_cache_misses, 0);
6819        assert!(compile_metrics.specialization_ir_compile_nanos > 0);
6820        assert!(compile_metrics.specialization_cached_integrals_nanos > 0);
6821        assert_eq!(compile_metrics.specialization_lowering_nanos, 0);
6822    }
6823
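    // The weighted-gradient-sum fast path (two cached separable terms plus a residual term)
    // must match a brute-force eventwise fold of per-event gradients and weights.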
6824    #[test]
6825    fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
6826        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6827        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6828        let c1 = CacheOnlyScalar::new("c1").unwrap();
6829        let c2 = CacheOnlyScalar::new("c2").unwrap();
6830        let c3 = CacheOnlyScalar::new("c3").unwrap();
6831        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6832        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6833        let dataset = Arc::new(test_dataset());
6834        let evaluator = expr.load(&dataset).unwrap();
6835        assert_eq!(
6836            evaluator
6837                .expression_precomputed_cached_integrals()
6838                .expect("integrals should exist")
6839                .len(),
6840            2
6841        );
6842        let params = vec![0.2, -0.3, 1.1, -0.7];
6843        let expected = evaluator
6844            .evaluate_gradient_local(&params)
6845            .expect("evaluation should succeed")
6846            .iter()
6847            .zip(dataset.events_local().iter())
6848            .fold(
6849                DVector::zeros(params.len()),
6850                |mut accum, (gradient, event)| {
6851                    accum += gradient.map(|value| value.re).scale(event.weight());
6852                    accum
6853                },
6854            );
6855        let actual = evaluator
6856            .evaluate_weighted_gradient_sum_local(&params)
6857            .expect("evaluation should succeed");
6858        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6859            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6860        }
6861    }
6862
6863    #[test]
6864    fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
6865        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6866        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6867        let c1 = CacheOnlyScalar::new("c1").unwrap();
6868        let c2 = CacheOnlyScalar::new("c2").unwrap();
6869        let c3 = CacheOnlyScalar::new("c3").unwrap();
6870        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6871        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6872        let dataset = Arc::new(test_dataset());
6873        let evaluator = expr.load(&dataset).unwrap();
6874        assert_eq!(
6875            evaluator
6876                .expression_precomputed_cached_integrals()
6877                .expect("integrals should exist")
6878                .len(),
6879            2
6880        );
6881        let params = vec![0.2, -0.3, 1.1, -0.7];
6882        let expected = evaluator
6883            .evaluate_local(&params)
6884            .expect("evaluation should succeed")
6885            .iter()
6886            .zip(dataset.events_local().iter())
6887            .fold(0.0, |accum, (value, event)| {
6888                accum + event.weight() * value.re
6889            });
6890        let actual = evaluator
6891            .evaluate_weighted_value_sum_local(&params)
6892            .expect("evaluation should succeed");
6893        assert_relative_eq!(actual, expected, epsilon = 1e-10);
6894    }
6895
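    // Pins the weighted value and gradient sums to hardcoded reference values for a
    // three-event dataset with mixed-sign weights.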
6896    #[test]
6897    fn test_weighted_sums_match_hardcoded_reference_values() {
6898        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6899        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6900        let c1 = CacheOnlyScalar::new("c1").unwrap();
6901        let c2 = CacheOnlyScalar::new("c2").unwrap();
6902        let c3 = CacheOnlyScalar::new("c3").unwrap();
6903        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6904        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6905
6906        let metadata = Arc::new(DatasetMetadata::default());
6907        let dataset = Arc::new(Dataset::new_with_metadata(
6908            vec![
6909                Arc::new(EventData {
6910                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6911                    aux: vec![],
6912                    weight: 0.5,
6913                }),
6914                Arc::new(EventData {
6915                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6916                    aux: vec![],
6917                    weight: -1.25,
6918                }),
6919                Arc::new(EventData {
6920                    p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6921                    aux: vec![],
6922                    weight: 2.0,
6923                }),
6924            ],
6925            metadata,
6926        ));
6927        let evaluator = expr.load(&dataset).unwrap();
6928        let params = vec![0.7, -1.1, 0.9, -0.4];
6929
6930        let weighted_value_sum = evaluator
6931            .evaluate_weighted_value_sum_local(&params)
6932            .expect("evaluation should succeed");
6933        assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);
6934
6935        let weighted_gradient_sum = evaluator
6936            .evaluate_weighted_gradient_sum_local(&params)
6937            .expect("evaluation should succeed");
6938        let free_parameters = evaluator
6939            .free_parameters()
6940            .into_iter()
6941            .map(|name| name.to_string())
6942            .collect::<Vec<_>>();
6943        assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
6944        let expected_gradient = [43.925, 7.25, 28.525, 0.0];
6945        assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
6946        for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
6947            assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
6948        }
6949    }
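    // Expression::one() - (p * k): the subtraction gives the cached separable term a -1
    // coefficient, which the weighted gradient sum must respect.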
6950    #[test]
6951    fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
6952        let expr = Expression::one()
6953            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6954                * &CacheOnlyScalar::new("k").unwrap());
6955        let dataset = Arc::new(test_dataset());
6956        let evaluator = expr.load(&dataset).unwrap();
6957        assert_eq!(
6958            evaluator
6959                .expression_precomputed_cached_integrals()
6960                .expect("integrals should exist")
6961                .len(),
6962            1
6963        );
6964        assert_eq!(
6965            evaluator
6966                .expression_precomputed_cached_integrals()
6967                .expect("integrals should exist")[0]
6968                .coefficient,
6969            -1
6970        );
6971        let params = vec![0.75];
6972        let expected = evaluator
6973            .evaluate_gradient_local(&params)
6974            .expect("evaluation should succeed")
6975            .iter()
6976            .zip(dataset.events_local().iter())
6977            .fold(
6978                DVector::zeros(params.len()),
6979                |mut accum, (gradient, event)| {
6980                    accum += gradient.map(|value| value.re).scale(event.weight());
6981                    accum
6982                },
6983            );
6984        let actual = evaluator
6985            .evaluate_weighted_gradient_sum_local(&params)
6986            .expect("evaluation should succeed");
6987        for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6988            assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6989        }
6990    }
6991    #[test]
6992    fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
6993        let expr = Expression::one()
6994            - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6995                * &CacheOnlyScalar::new("k").unwrap());
6996        let dataset = Arc::new(test_dataset());
6997        let evaluator = expr.load(&dataset).unwrap();
6998        assert_eq!(
6999            evaluator
7000                .expression_precomputed_cached_integrals()
7001                .expect("integrals should exist")
7002                .len(),
7003            1
7004        );
7005        assert_eq!(
7006            evaluator
7007                .expression_precomputed_cached_integrals()
7008                .expect("integrals should exist")[0]
7009                .coefficient,
7010            -1
7011        );
7012        let params = vec![0.75];
7013        let expected = evaluator
7014            .evaluate_local(&params)
7015            .expect("evaluation should succeed")
7016            .iter()
7017            .zip(dataset.events_local().iter())
7018            .fold(0.0, |accum, (value, event)| {
7019                accum + event.weight() * value.re
7020            });
7021        let actual = evaluator
7022            .evaluate_weighted_value_sum_local(&params)
7023            .expect("evaluation should succeed");
7024        assert_relative_eq!(actual, expected, epsilon = 1e-10);
7025    }
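    // The normalization-plan explain view and root dependence must track activation changes:
    // isolating the parameter-only amplitude leaves a ParameterOnly root with no cached
    // separable nodes.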
7026    #[test]
7027    fn test_expression_ir_diagnostics_follow_activation_changes() {
7028        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
7029            * &CacheOnlyScalar::new("k").unwrap())
7030            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
7031        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
7032        let evaluator = expr.load(&dataset).unwrap();
7033
7034        let all_active = evaluator
7035            .expression_normalization_plan_explain()
7036            .expect("plan should exist");
7037        assert_eq!(all_active.cached_separable_nodes.len(), 1);
7038        assert_eq!(
7039            evaluator
7040                .expression_root_dependence()
7041                .expect("root dependence should exist"),
7042            ExpressionDependence::Mixed
7043        );
7044
7045        evaluator.isolate_many(&["p"]);
7046        let param_only = evaluator
7047            .expression_normalization_plan_explain()
7048            .expect("plan should exist");
7049        assert!(param_only.cached_separable_nodes.is_empty());
7050        assert_eq!(
7051            evaluator
7052                .expression_root_dependence()
7053                .expect("root dependence should exist"),
7054            ExpressionDependence::ParameterOnly
7055        );
7056    }
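    // Revisiting a previously seen activation mask must hit the specialization cache
    // (restore path) instead of recompiling the IR or re-lowering.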
7057    #[test]
7058    fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
7059        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
7060            * &CacheOnlyScalar::new("k").unwrap())
7061            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
7062        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
7063        let evaluator = expr.load(&dataset).unwrap();
7064
7065        let initial_compile_metrics = evaluator.expression_compile_metrics();
7066        assert!(initial_compile_metrics.initial_ir_compile_nanos > 0);
7067        assert!(initial_compile_metrics.initial_cached_integrals_nanos > 0);
7068        assert!(initial_compile_metrics.initial_lowering_nanos > 0);
7069        assert_eq!(initial_compile_metrics.specialization_cache_hits, 0);
7070        assert_eq!(initial_compile_metrics.specialization_cache_misses, 0);
7071        assert_eq!(
7072            initial_compile_metrics.specialization_lowering_cache_hits,
7073            0
7074        );
7075        assert_eq!(
7076            initial_compile_metrics.specialization_lowering_cache_misses,
7077            1
7078        );
7079
7080        assert_eq!(evaluator.specialization_cache_len(), 1);
7081        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
7082        assert_eq!(
7083            evaluator.expression_specialization_metrics(),
7084            ExpressionSpecializationMetrics {
7085                cache_hits: 0,
7086                cache_misses: 1,
7087            }
7088        );
7089        let all_active_cached_integrals = evaluator
7090            .expression_precomputed_cached_integrals()
7091            .expect("integrals should exist");
7092
7093        evaluator.isolate_many(&["p"]);
7094        assert_eq!(evaluator.specialization_cache_len(), 2);
7095        assert_eq!(
7096            evaluator.expression_specialization_metrics(),
7097            ExpressionSpecializationMetrics {
7098                cache_hits: 0,
7099                cache_misses: 2,
7100            }
7101        );
7102        let after_cache_miss_metrics = evaluator.expression_compile_metrics();
7103        assert_eq!(after_cache_miss_metrics.specialization_cache_hits, 0);
7104        assert_eq!(after_cache_miss_metrics.specialization_cache_misses, 1);
7105        assert_eq!(
7106            after_cache_miss_metrics.specialization_lowering_cache_hits,
7107            0
7108        );
7109        assert_eq!(
7110            after_cache_miss_metrics.specialization_lowering_cache_misses,
7111            2
7112        );
7113        assert!(after_cache_miss_metrics.specialization_ir_compile_nanos > 0);
7114        assert!(after_cache_miss_metrics.specialization_cached_integrals_nanos > 0);
7115        assert!(after_cache_miss_metrics.specialization_lowering_nanos > 0);
7116        assert!(evaluator
7117            .expression_precomputed_cached_integrals()
7118            .expect("integrals should exist")
7119            .is_empty());
7120
7121        evaluator.activate_many(&["k", "m"]);
7122        assert_eq!(evaluator.specialization_cache_len(), 2);
7123        assert_eq!(
7124            evaluator.expression_specialization_metrics(),
7125            ExpressionSpecializationMetrics {
7126                cache_hits: 1,
7127                cache_misses: 2,
7128            }
7129        );
7130        assert_eq!(
7131            evaluator
7132                .expression_precomputed_cached_integrals()
7133                .expect("integrals should exist"),
7134            all_active_cached_integrals
7135        );
7136        let after_cache_hit_metrics = evaluator.expression_compile_metrics();
7137        assert_eq!(after_cache_hit_metrics.specialization_cache_hits, 1);
7138        assert_eq!(after_cache_hit_metrics.specialization_cache_misses, 1);
7139        assert_eq!(
7140            after_cache_hit_metrics.specialization_lowering_cache_hits,
7141            0
7142        );
7143        assert_eq!(
7144            after_cache_hit_metrics.specialization_lowering_cache_misses,
7145            2
7146        );
7147        assert!(after_cache_hit_metrics.specialization_cache_restore_nanos > 0);
7148    }
7149
7150    #[test]
7151    fn test_weighted_sums_match_baseline_after_activation_changes() {
7152        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
7153        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
7154        let c1 = CacheOnlyScalar::new("c1").unwrap();
7155        let c2 = CacheOnlyScalar::new("c2").unwrap();
7156        let c3 = CacheOnlyScalar::new("c3").unwrap();
7157        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
7158        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
7159        let dataset = Arc::new(test_dataset());
7160        let evaluator = expr.load(&dataset).unwrap();
7161        let params = vec![0.2, -0.3, 1.1, -0.7];
7162
7163        evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);
7164
7165        let expected_value = evaluator
7166            .evaluate_local(&params)
7167            .expect("evaluation should succeed")
7168            .iter()
7169            .zip(dataset.events_local().iter())
7170            .fold(0.0, |accum, (value, event)| {
7171                accum + event.weight() * value.re
7172            });
7173        assert_relative_eq!(
7174            evaluator
7175                .evaluate_weighted_value_sum_local(&params)
7176                .expect("evaluation should succeed"),
7177            expected_value,
7178            epsilon = 1e-10
7179        );
7180    }
7181
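    // After load, evaluation should rely only on the precomputed per-event caches, so
    // dropping the raw event rows must not change the number of evaluated results.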
7182    #[test]
7183    fn test_evaluate_local_does_not_depend_on_dataset_rows() {
7184        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7185            .unwrap()
7186            .norm_sqr();
7187        let mut event1 = test_event();
7188        event1.p4s[0].t = 7.5;
7189        let mut event2 = test_event();
7190        event2.p4s[0].t = 8.25;
7191        let mut event3 = test_event();
7192        event3.p4s[0].t = 9.0;
7193        let dataset = Arc::new(Dataset::new_with_metadata(
7194            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
7195            Arc::new(DatasetMetadata::default()),
7196        ));
7197        let mut evaluator = expr.load(&dataset).unwrap();
7198        drop(dataset);
7199        let expected_len = evaluator.resources.read().caches.len();
7200        Arc::get_mut(&mut evaluator.dataset)
7201            .expect("evaluator should own dataset Arc in this test")
7202            .clear_events_local();
7203        let cached = evaluator
7204            .evaluate_local(&[1.25, -0.75])
7205            .expect("evaluation should succeed");
7206        assert_eq!(cached.len(), expected_len);
7207    }
7208
7209    #[test]
7210    fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
7211        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7212            .unwrap()
7213            .norm_sqr();
7214        let mut event1 = test_event();
7215        event1.p4s[0].t = 7.5;
7216        let mut event2 = test_event();
7217        event2.p4s[0].t = 8.25;
7218        let mut event3 = test_event();
7219        event3.p4s[0].t = 9.0;
7220        let dataset = Arc::new(Dataset::new_with_metadata(
7221            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
7222            Arc::new(DatasetMetadata::default()),
7223        ));
7224        let mut evaluator = expr.load(&dataset).unwrap();
7225        drop(dataset);
7226        let expected_len = evaluator.resources.read().caches.len();
7227        Arc::get_mut(&mut evaluator.dataset)
7228            .expect("evaluator should own dataset Arc in this test")
7229            .clear_events_local();
7230        let cached = evaluator
7231            .evaluate_gradient_local(&[1.25, -0.75])
7232            .expect("evaluation should succeed");
7233        assert_eq!(cached.len(), expected_len);
7234    }
7235
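    // The fused value+gradient path must agree elementwise with the separate
    // evaluate_local and evaluate_gradient_local calls.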
7236    #[test]
7237    fn test_evaluate_with_gradient_local_matches_separate_paths() {
7238        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7239            .unwrap()
7240            .norm_sqr();
7241        let dataset = Arc::new(Dataset::new(vec![
7242            Arc::new(test_event()),
7243            Arc::new(test_event()),
7244            Arc::new(test_event()),
7245        ]));
7246        let evaluator = expr.load(&dataset).unwrap();
7247        let params = [1.25, -0.75];
7248        let values = evaluator
7249            .evaluate_local(&params)
7250            .expect("evaluation should succeed");
7251        let gradients = evaluator
7252            .evaluate_gradient_local(&params)
7253            .expect("evaluation should succeed");
7254        let fused = evaluator
7255            .evaluate_with_gradient_local(&params)
7256            .expect("evaluation should succeed");
7257        assert_eq!(fused.len(), values.len());
7258        assert_eq!(fused.len(), gradients.len());
7259        for ((value_gradient, value), gradient) in
7260            fused.iter().zip(values.iter()).zip(gradients.iter())
7261        {
7262            let (fused_value, fused_gradient) = value_gradient;
7263            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
7264            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
7265            assert_eq!(fused_gradient.len(), gradient.len());
7266            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
7267                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
7268                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
7269            }
7270        }
7271    }
7272
7273    #[test]
7274    fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
7275        let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7276            .unwrap()
7277            .norm_sqr();
7278        let dataset = Arc::new(Dataset::new(vec![
7279            Arc::new(test_event()),
7280            Arc::new(test_event()),
7281            Arc::new(test_event()),
7282            Arc::new(test_event()),
7283        ]));
7284        let evaluator = expr.load(&dataset).unwrap();
7285        let params = [0.5, -1.25];
7286        let indices = vec![0, 2, 3];
7287        let values = evaluator
7288            .evaluate_batch_local(&params, &indices)
7289            .expect("evaluation should succeed");
7290        let gradients = evaluator
7291            .evaluate_gradient_batch_local(&params, &indices)
7292            .expect("evaluation should succeed");
7293        let fused = evaluator
7294            .evaluate_with_gradient_batch_local(&params, &indices)
7295            .expect("evaluation should succeed");
7296        assert_eq!(fused.len(), values.len());
7297        assert_eq!(fused.len(), gradients.len());
7298        for ((value_gradient, value), gradient) in
7299            fused.iter().zip(values.iter()).zip(gradients.iter())
7300        {
7301            let (fused_value, fused_gradient) = value_gradient;
7302            assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
7303            assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
7304            assert_eq!(fused_gradient.len(), gradient.len());
7305            for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
7306                assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
7307                assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
7308            }
7309        }
7310    }
7311
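    // precompute_all must populate the per-event scalar cache (here the beam energy)
    // for every reserved cache slot.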
7312    #[test]
7313    fn test_precompute_all_columnar_populates_cache() {
7314        let mut event1 = test_event();
7315        event1.p4s[0].t = 7.5;
7316        let mut event2 = test_event();
7317        event2.p4s[0].t = 8.25;
7318        let mut event3 = test_event();
7319        event3.p4s[0].t = 9.0;
7320        let dataset = Dataset::new_with_metadata(
7321            vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
7322            Arc::new(DatasetMetadata::default()),
7323        );
7324        let mut amplitude = TestAmplitude {
7325            name: "test".to_string(),
7326            re: parameter!("real"),
7327            pid_re: ParameterID::default(),
7328            im: parameter!("imag"),
7329            pid_im: ParameterID::default(),
7330            beam_energy: Default::default(),
7331        };
7332        let mut resources = Resources::default();
7333        amplitude
7334            .register(&mut resources)
7335            .expect("test amplitude should register");
7336        resources.reserve_cache(dataset.n_events());
7337        amplitude.precompute_all(&dataset, &mut resources);
7338        for cache in &resources.caches {
7339            assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
7340        }
7341    }
7342
7343    #[cfg(feature = "mpi")]
7344    #[mpi_test(np = [2])]
7345    fn test_load_reserves_local_cache_size_in_mpi() {
7346        use crate::mpi::{finalize_mpi, get_world, use_mpi};
7347
7348        use_mpi(true);
7349        assert!(get_world().is_some(), "MPI world should be initialized");
7350
7351        let expr = ComplexScalar::new(
7352            "constant",
7353            parameter!("const_re", 2.0),
7354            parameter!("const_im", 3.0),
7355        )
7356        .expect("constant amplitude should construct");
7357        let events = vec![
7358            Arc::new(test_event()),
7359            Arc::new(test_event()),
7360            Arc::new(test_event()),
7361            Arc::new(test_event()),
7362        ];
7363        let dataset = Arc::new(Dataset::new_with_metadata(
7364            events,
7365            Arc::new(DatasetMetadata::default()),
7366        ));
7367        let evaluator = expr.load(&dataset).expect("evaluator should load");
7368        let local_events = dataset.n_events_local();
7369        let cache_len = evaluator.resources.read().caches.len();
7370
7371        assert_eq!(
7372            cache_len, local_events,
7373            "cache length must match local event count under MPI"
7374        );
7375        finalize_mpi();
7376    }
7377
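    // Cached integrals are accumulated over rank-local events only; the global reduction
    // is left to the caller, verified here by comparing against an all-reduced sum.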
7378    #[cfg(feature = "mpi")]
7379    #[mpi_test(np = [2])]
7380    fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
7381        use crate::mpi::{finalize_mpi, get_world, use_mpi};
7382        use mpi::{collective::SystemOperation, topology::Communicator, traits::*};
7383
7384        use_mpi(true);
7385        let world = get_world().expect("MPI world should be initialized");
7386
7387        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
7388            * &CacheOnlyScalar::new("k").unwrap();
7389        let events = vec![
7390            Arc::new(EventData {
7391                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
7392                aux: vec![],
7393                weight: 0.5,
7394            }),
7395            Arc::new(EventData {
7396                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
7397                aux: vec![],
7398                weight: 1.0,
7399            }),
7400            Arc::new(EventData {
7401                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
7402                aux: vec![],
7403                weight: 1.5,
7404            }),
7405            Arc::new(EventData {
7406                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
7407                aux: vec![],
7408                weight: 2.0,
7409            }),
7410            Arc::new(EventData {
7411                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
7412                aux: vec![],
7413                weight: 2.5,
7414            }),
7415            Arc::new(EventData {
7416                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
7417                aux: vec![],
7418                weight: 3.0,
7419            }),
7420            Arc::new(EventData {
7421                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
7422                aux: vec![],
7423                weight: 3.5,
7424            }),
7425            Arc::new(EventData {
7426                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
7427                aux: vec![],
7428                weight: 4.0,
7429            }),
7430        ];
7431        let dataset = Arc::new(Dataset::new_with_metadata(
7432            events,
7433            Arc::new(DatasetMetadata::default()),
7434        ));
7435        let evaluator = expr.load(&dataset).expect("evaluator should load");
7436        let cached_integrals = evaluator
7437            .expression_precomputed_cached_integrals()
7438            .expect("integrals should exist");
7439        assert_eq!(cached_integrals.len(), 1);
7440
7441        let local_expected = dataset.events_local().iter().fold(0.0, |acc, event| {
7442            acc + event.weight() * event.data().p4s[0].e()
7443        });
7444        let cached_local = cached_integrals[0].weighted_cache_sum;
7445        assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
7446        assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);
7447
7448        let weighted_value_sum = evaluator
7449            .evaluate_weighted_value_sum_local(&[2.0])
7450            .expect("evaluate should succeed");
7451        assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);
7452
7453        let mut global_expected = 0.0;
7454        world.all_reduce_into(
7455            &local_expected,
7456            &mut global_expected,
7457            SystemOperation::sum(),
7458        );
7459        if world.size() > 1 {
7460            assert!(
7461                (cached_local.re - global_expected).abs() > 1e-12,
7462                "cached integral should remain rank-local before MPI reduction"
7463            );
7464        }
7465        finalize_mpi();
7466    }
7467
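    // The MPI weighted value and gradient sums must equal the all-reduced totals of the
    // per-rank eventwise baselines.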
7468    #[cfg(feature = "mpi")]
7469    #[mpi_test(np = [2])]
7470    fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
7471        use crate::mpi::{finalize_mpi, get_world, use_mpi};
7472        use mpi::{collective::SystemOperation, traits::*};
7473
7474        use_mpi(true);
7475        let world = get_world().expect("MPI world should be initialized");
7476
7477        let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
7478        let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
7479        let c1 = CacheOnlyScalar::new("c1").unwrap();
7480        let c2 = CacheOnlyScalar::new("c2").unwrap();
7481        let c3 = CacheOnlyScalar::new("c3").unwrap();
7482        let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
7483        let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
7484        let events = vec![
7485            Arc::new(EventData {
7486                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
7487                aux: vec![],
7488                weight: 0.5,
7489            }),
7490            Arc::new(EventData {
7491                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
7492                aux: vec![],
7493                weight: -1.25,
7494            }),
7495            Arc::new(EventData {
7496                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
7497                aux: vec![],
7498                weight: 0.75,
7499            }),
7500            Arc::new(EventData {
7501                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
7502                aux: vec![],
7503                weight: 1.5,
7504            }),
7505            Arc::new(EventData {
7506                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
7507                aux: vec![],
7508                weight: 2.25,
7509            }),
7510            Arc::new(EventData {
7511                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
7512                aux: vec![],
7513                weight: -0.5,
7514            }),
7515            Arc::new(EventData {
7516                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
7517                aux: vec![],
7518                weight: 3.5,
7519            }),
7520            Arc::new(EventData {
7521                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
7522                aux: vec![],
7523                weight: 1.25,
7524            }),
7525        ];
7526        let dataset = Arc::new(Dataset::new_with_metadata(
7527            events,
7528            Arc::new(DatasetMetadata::default()),
7529        ));
7530        let evaluator = expr.load(&dataset).expect("evaluator should load");
7531        let params = vec![0.2, -0.3, 1.1, -0.7];
7532
7533        let local_expected_value = evaluator
7534            .evaluate_local(&params)
7535            .expect("evaluate should succeed")
7536            .iter()
7537            .zip(dataset.events_local().iter())
7538            .fold(0.0, |accum, (value, event)| {
7539                accum + event.weight() * value.re
7540            });
7541        let mut global_expected_value = 0.0;
7542        world.all_reduce_into(
7543            &local_expected_value,
7544            &mut global_expected_value,
7545            SystemOperation::sum(),
7546        );
7547        let mpi_value = evaluator
7548            .evaluate_weighted_value_sum_mpi(&params, &world)
7549            .expect("evaluate should succeed");
7550        assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);
7551
7552        let local_expected_gradient = evaluator
7553            .evaluate_gradient_local(&params)
7554            .expect("evaluate should succeed")
7555            .iter()
7556            .zip(dataset.events_local().iter())
7557            .fold(
7558                DVector::zeros(params.len()),
7559                |mut accum, (gradient, event)| {
7560                    accum += gradient.map(|value| value.re).scale(event.weight());
7561                    accum
7562                },
7563            );
7564        let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
7565        world.all_reduce_into(
7566            local_expected_gradient.as_slice(),
7567            &mut global_expected_gradient,
7568            SystemOperation::sum(),
7569        );
7570        let mpi_gradient = evaluator
7571            .evaluate_weighted_gradient_sum_mpi(&params, &world)
7572            .expect("evaluate should succeed");
7573        for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
7574            assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
7575        }
7576
7577        finalize_mpi();
7578    }
7579
7580    #[test]
7581    fn test_evaluate_local_succeeds_for_constant_amplitude() {
7582        let expr = ComplexScalar::new(
7583            "constant",
7584            parameter!("const_re", 2.0),
7585            parameter!("const_im", 3.0),
7586        )
7587        .unwrap();
7588        let dataset = Arc::new(Dataset::new_with_metadata(
7589            vec![Arc::new(test_event())],
7590            Arc::new(DatasetMetadata::default()),
7591        ));
7592        let evaluator = expr.load(&dataset).unwrap();
7593        let values = evaluator
7594            .evaluate_local(&[])
7595            .expect("evaluation should succeed");
7596        assert_eq!(values.len(), 1);
7597        let gradients = evaluator
7598            .evaluate_gradient_local(&[])
7599            .expect("evaluation should succeed");
7600        assert_eq!(gradients.len(), 1);
7601    }
7602
7603    #[test]
7604    fn test_constant_amplitude() {
7605        let expr = ComplexScalar::new(
7606            "constant",
7607            parameter!("const_re", 2.0),
7608            parameter!("const_im", 3.0),
7609        )
7610        .unwrap();
7611        let dataset = Arc::new(Dataset::new_with_metadata(
7612            vec![Arc::new(test_event())],
7613            Arc::new(DatasetMetadata::default()),
7614        ));
7615        let evaluator = expr.load(&dataset).unwrap();
7616        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7617        assert_eq!(result[0], Complex64::new(2.0, 3.0));
7618    }
7619
7620    #[test]
7621    fn test_parametric_amplitude() {
7622        let expr = ComplexScalar::new(
7623            "parametric",
7624            parameter!("test_param_re"),
7625            parameter!("test_param_im"),
7626        )
7627        .unwrap();
7628        let dataset = Arc::new(test_dataset());
7629        let evaluator = expr.load(&dataset).unwrap();
7630        let result = evaluator
7631            .evaluate(&[2.0, 3.0])
7632            .expect("evaluation should succeed");
7633        assert_eq!(result[0], Complex64::new(2.0, 3.0));
7634    }
7635
7636    #[test]
7637    fn test_expression_operations() {
7638        let expr1 = ComplexScalar::new(
7639            "const1",
7640            parameter!("const1_re", 2.0),
7641            parameter!("const1_im", 0.0),
7642        )
7643        .unwrap();
7644        let expr2 = ComplexScalar::new(
7645            "const2",
7646            parameter!("const2_re", 0.0),
7647            parameter!("const2_im", 1.0),
7648        )
7649        .unwrap();
7650        let expr3 = ComplexScalar::new(
7651            "const3",
7652            parameter!("const3_re", 3.0),
7653            parameter!("const3_im", 4.0),
7654        )
7655        .unwrap();
7656
7657        let dataset = Arc::new(test_dataset());
7658
7659        // Test (amp) addition
7660        let expr_add = &expr1 + &expr2;
7661        let result_add = expr_add
7662            .load(&dataset)
7663            .unwrap()
7664            .evaluate(&[])
7665            .expect("evaluation should succeed");
7666        assert_eq!(result_add[0], Complex64::new(2.0, 1.0));
7667
7668        // Test (amp) subtraction
7669        let expr_sub = &expr1 - &expr2;
7670        let result_sub = expr_sub
7671            .load(&dataset)
7672            .unwrap()
7673            .evaluate(&[])
7674            .expect("evaluation should succeed");
7675        assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));
7676
7677        // Test (amp) multiplication
7678        let expr_mul = &expr1 * &expr2;
7679        let result_mul = expr_mul
7680            .load(&dataset)
7681            .unwrap()
7682            .evaluate(&[])
7683            .expect("evaluation should succeed");
7684        assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));
7685
7686        // Test (amp) division
7687        let expr_div = &expr1 / &expr3;
7688        let result_div = expr_div
7689            .load(&dataset)
7690            .unwrap()
7691            .evaluate(&[])
7692            .expect("evaluation should succeed");
7693        assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));
7694
7695        // Test (amp) neg
7696        let expr_neg = -&expr3;
7697        let result_neg = expr_neg
7698            .load(&dataset)
7699            .unwrap()
7700            .evaluate(&[])
7701            .expect("evaluation should succeed");
7702        assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));
7703
7704        // Test (expr) addition
7705        let expr_add2 = &expr_add + &expr_mul;
7706        let result_add2 = expr_add2
7707            .load(&dataset)
7708            .unwrap()
7709            .evaluate(&[])
7710            .expect("evaluation should succeed");
7711        assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));
7712
7713        // Test (expr) subtraction
7714        let expr_sub2 = &expr_add - &expr_mul;
7715        let result_sub2 = expr_sub2
7716            .load(&dataset)
7717            .unwrap()
7718            .evaluate(&[])
7719            .expect("evaluation should succeed");
7720        assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));
7721
7722        // Test (expr) multiplication
7723        let expr_mul2 = &expr_add * &expr_mul;
7724        let result_mul2 = expr_mul2
7725            .load(&dataset)
7726            .unwrap()
7727            .evaluate(&[])
7728            .expect("evaluation should succeed");
7729        assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));
7730
7731        // Test (expr) division
7732        let expr_div2 = &expr_add / &expr_add2;
7733        let result_div2 = expr_div2
7734            .load(&dataset)
7735            .unwrap()
7736            .evaluate(&[])
7737            .expect("evaluation should succeed");
7738        assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));
7739
7740        // Test (expr) neg
7741        let expr_neg2 = -&expr_mul2;
7742        let result_neg2 = expr_neg2
7743            .load(&dataset)
7744            .unwrap()
7745            .evaluate(&[])
7746            .expect("evaluation should succeed");
7747        assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));
7748
7749        // Test (amp) real
7750        let expr_real = expr3.real();
7751        let result_real = expr_real
7752            .load(&dataset)
7753            .unwrap()
7754            .evaluate(&[])
7755            .expect("evaluation should succeed");
7756        assert_eq!(result_real[0], Complex64::new(3.0, 0.0));
7757
7758        // Test (expr) real
7759        let expr_mul2_real = expr_mul2.real();
7760        let result_mul2_real = expr_mul2_real
7761            .load(&dataset)
7762            .unwrap()
7763            .evaluate(&[])
7764            .expect("evaluation should succeed");
7765        assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));
7766
7767        // Test (amp) imag
7768        let expr_imag = expr3.imag();
7769        let result_imag = expr_imag
7770            .load(&dataset)
7771            .unwrap()
7772            .evaluate(&[])
7773            .expect("evaluation should succeed");
7774        assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));
7775
7776        // Test (expr) imag
7777        let expr_mul2_imag = expr_mul2.imag();
7778        let result_mul2_imag = expr_mul2_imag
7779            .load(&dataset)
7780            .unwrap()
7781            .evaluate(&[])
7782            .expect("evaluation should succeed");
7783        assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));
7784
7785        // Test (amp) conj
7786        let expr_conj = expr3.conj();
7787        let result_conj = expr_conj
7788            .load(&dataset)
7789            .unwrap()
7790            .evaluate(&[])
7791            .expect("evaluation should succeed");
7792        assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));
7793
7794        // Test (expr) conj
7795        let expr_mul2_conj = expr_mul2.conj();
7796        let result_mul2_conj = expr_mul2_conj
7797            .load(&dataset)
7798            .unwrap()
7799            .evaluate(&[])
7800            .expect("evaluation should succeed");
7801        assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));
7802
7803        // Test (amp) norm_sqr
7804        let expr_norm = expr1.norm_sqr();
7805        let result_norm = expr_norm
7806            .load(&dataset)
7807            .unwrap()
7808            .evaluate(&[])
7809            .expect("evaluation should succeed");
7810        assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));
7811
7812        // Test (expr) norm_sqr
7813        let expr_mul2_norm = expr_mul2.norm_sqr();
7814        let result_mul2_norm = expr_mul2_norm
7815            .load(&dataset)
7816            .unwrap()
7817            .evaluate(&[])
7818            .expect("evaluation should succeed");
7819        assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
7820    }
7821
7822    #[test]
7823    fn test_amplitude_activation() {
7824        let expr1 = ComplexScalar::new(
7825            "const1",
7826            parameter!("const1_re_act", 1.0),
7827            parameter!("const1_im_act", 0.0),
7828        )
7829        .unwrap();
7830        let expr2 = ComplexScalar::new(
7831            "const2",
7832            parameter!("const2_re_act", 2.0),
7833            parameter!("const2_im_act", 0.0),
7834        )
7835        .unwrap();
7836
7837        let dataset = Arc::new(test_dataset());
7838        let expr = &expr1 + &expr2;
7839        let evaluator = expr.load(&dataset).unwrap();
7840
7841        // Test initial state (all active)
7842        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7843        assert_eq!(result[0], Complex64::new(3.0, 0.0));
7844
7845        // Test deactivation
7846        evaluator.deactivate_strict("const1").unwrap();
7847        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7848        assert_eq!(result[0], Complex64::new(2.0, 0.0));
7849
7850        // Test isolation
7851        evaluator.isolate_strict("const1").unwrap();
7852        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7853        assert_eq!(result[0], Complex64::new(1.0, 0.0));
7854
7855        // Test reactivation
7856        evaluator.activate_all();
7857        let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7858        assert_eq!(result[0], Complex64::new(3.0, 0.0));
7859    }
7860
7861    #[test]
7862    fn test_gradient() {
7863        let expr1 = ComplexScalar::new(
7864            "parametric_1",
7865            parameter!("test_param_re_1"),
7866            parameter!("test_param_im_1"),
7867        )
7868        .unwrap();
7869        let expr2 = ComplexScalar::new(
7870            "parametric_2",
7871            parameter!("test_param_re_2"),
7872            parameter!("test_param_im_2"),
7873        )
7874        .unwrap();
7875
7876        let dataset = Arc::new(test_dataset());
7877        let params = vec![2.0, 3.0, 4.0, 5.0];
7878
7879        let expr = &expr1 + &expr2;
7880        let evaluator = expr.load(&dataset).unwrap();
7881
7882        let gradient = evaluator
7883            .evaluate_gradient(&params)
7884            .expect("evaluation should succeed");
7885
7886        assert_relative_eq!(gradient[0][0].re, 1.0);
7887        assert_relative_eq!(gradient[0][0].im, 0.0);
7888        assert_relative_eq!(gradient[0][1].re, 0.0);
7889        assert_relative_eq!(gradient[0][1].im, 1.0);
7890        assert_relative_eq!(gradient[0][2].re, 1.0);
7891        assert_relative_eq!(gradient[0][2].im, 0.0);
7892        assert_relative_eq!(gradient[0][3].re, 0.0);
7893        assert_relative_eq!(gradient[0][3].im, 1.0);
7894
7895        let expr = &expr1 - &expr2;
7896        let evaluator = expr.load(&dataset).unwrap();
7897
7898        let gradient = evaluator
7899            .evaluate_gradient(&params)
7900            .expect("evaluation should succeed");
7901
7902        assert_relative_eq!(gradient[0][0].re, 1.0);
7903        assert_relative_eq!(gradient[0][0].im, 0.0);
7904        assert_relative_eq!(gradient[0][1].re, 0.0);
7905        assert_relative_eq!(gradient[0][1].im, 1.0);
7906        assert_relative_eq!(gradient[0][2].re, -1.0);
7907        assert_relative_eq!(gradient[0][2].im, 0.0);
7908        assert_relative_eq!(gradient[0][3].re, 0.0);
7909        assert_relative_eq!(gradient[0][3].im, -1.0);
7910
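        // Product rule: d(z1*z2)/dRe(z1) = z2 = 4+5i, d/dIm(z1) = i*z2 = -5+4i,
        // d/dRe(z2) = z1 = 2+3i, and d/dIm(z2) = i*z1 = -3+2i.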
7911        let expr = &expr1 * &expr2;
7912        let evaluator = expr.load(&dataset).unwrap();
7913
7914        let gradient = evaluator
7915            .evaluate_gradient(&params)
7916            .expect("evaluation should succeed");
7917
7918        assert_relative_eq!(gradient[0][0].re, 4.0);
7919        assert_relative_eq!(gradient[0][0].im, 5.0);
7920        assert_relative_eq!(gradient[0][1].re, -5.0);
7921        assert_relative_eq!(gradient[0][1].im, 4.0);
7922        assert_relative_eq!(gradient[0][2].re, 2.0);
7923        assert_relative_eq!(gradient[0][2].im, 3.0);
7924        assert_relative_eq!(gradient[0][3].re, -3.0);
7925        assert_relative_eq!(gradient[0][3].im, 2.0);
7926
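        // Quotient rule: d(z1/z2)/dRe(z1) = 1/z2 = (4-5i)/41 and
        // d/dRe(z2) = -z1/z2^2 = (-102+107i)/1681; the derivatives for the
        // imaginary parameters are the same values multiplied by i.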
7927        let expr = &expr1 / &expr2;
7928        let evaluator = expr.load(&dataset).unwrap();
7929
7930        let gradient = evaluator
7931            .evaluate_gradient(&params)
7932            .expect("evaluation should succeed");
7933
7934        assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
7935        assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
7936        assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
7937        assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
7938        assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
7939        assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
7940        assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
7941        assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
7942
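        // Negating the product flips the sign of every gradient component.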
7943        let expr = -(&expr1 * &expr2);
7944        let evaluator = expr.load(&dataset).unwrap();
7945
7946        let gradient = evaluator
7947            .evaluate_gradient(&params)
7948            .expect("evaluation should succeed");
7949
7950        assert_relative_eq!(gradient[0][0].re, -4.0);
7951        assert_relative_eq!(gradient[0][0].im, -5.0);
7952        assert_relative_eq!(gradient[0][1].re, 5.0);
7953        assert_relative_eq!(gradient[0][1].im, -4.0);
7954        assert_relative_eq!(gradient[0][2].re, -2.0);
7955        assert_relative_eq!(gradient[0][2].im, -3.0);
7956        assert_relative_eq!(gradient[0][3].re, 3.0);
7957        assert_relative_eq!(gradient[0][3].im, -2.0);
7958
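        // Re, Im, and Conj (the next three blocks) act component-wise on the
        // product gradient: Re keeps real parts, Im keeps imaginary parts
        // (returned as real values), and Conj conjugates each entry.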
7959        let expr = (&expr1 * &expr2).real();
7960        let evaluator = expr.load(&dataset).unwrap();
7961
7962        let gradient = evaluator
7963            .evaluate_gradient(&params)
7964            .expect("evaluation should succeed");
7965
7966        assert_relative_eq!(gradient[0][0].re, 4.0);
7967        assert_relative_eq!(gradient[0][0].im, 0.0);
7968        assert_relative_eq!(gradient[0][1].re, -5.0);
7969        assert_relative_eq!(gradient[0][1].im, 0.0);
7970        assert_relative_eq!(gradient[0][2].re, 2.0);
7971        assert_relative_eq!(gradient[0][2].im, 0.0);
7972        assert_relative_eq!(gradient[0][3].re, -3.0);
7973        assert_relative_eq!(gradient[0][3].im, 0.0);
7974
7975        let expr = (&expr1 * &expr2).imag();
7976        let evaluator = expr.load(&dataset).unwrap();
7977
7978        let gradient = evaluator
7979            .evaluate_gradient(&params)
7980            .expect("evaluation should succeed");
7981
7982        assert_relative_eq!(gradient[0][0].re, 5.0);
7983        assert_relative_eq!(gradient[0][0].im, 0.0);
7984        assert_relative_eq!(gradient[0][1].re, 4.0);
7985        assert_relative_eq!(gradient[0][1].im, 0.0);
7986        assert_relative_eq!(gradient[0][2].re, 3.0);
7987        assert_relative_eq!(gradient[0][2].im, 0.0);
7988        assert_relative_eq!(gradient[0][3].re, 2.0);
7989        assert_relative_eq!(gradient[0][3].im, 0.0);
7990
7991        let expr = (&expr1 * &expr2).conj();
7992        let evaluator = expr.load(&dataset).unwrap();
7993
7994        let gradient = evaluator
7995            .evaluate_gradient(&params)
7996            .expect("evaluation should succeed");
7997
7998        assert_relative_eq!(gradient[0][0].re, 4.0);
7999        assert_relative_eq!(gradient[0][0].im, -5.0);
8000        assert_relative_eq!(gradient[0][1].re, -5.0);
8001        assert_relative_eq!(gradient[0][1].im, -4.0);
8002        assert_relative_eq!(gradient[0][2].re, 2.0);
8003        assert_relative_eq!(gradient[0][2].im, -3.0);
8004        assert_relative_eq!(gradient[0][3].re, -3.0);
8005        assert_relative_eq!(gradient[0][3].im, -2.0);
8006
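        // |z1*z2|^2 = |z1|^2 * |z2|^2 = 13 * 41 = 533, so e.g.
        // d/dRe(z1) = 2 * Re(z1) * |z2|^2 = 2 * 2 * 41 = 164.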
8007        let expr = (&expr1 * &expr2).norm_sqr();
8008        let evaluator = expr.load(&dataset).unwrap();
8009
8010        let gradient = evaluator
8011            .evaluate_gradient(&params)
8012            .expect("evaluation should succeed");
8013
8014        assert_relative_eq!(gradient[0][0].re, 164.0);
8015        assert_relative_eq!(gradient[0][0].im, 0.0);
8016        assert_relative_eq!(gradient[0][1].re, 246.0);
8017        assert_relative_eq!(gradient[0][1].im, 0.0);
8018        assert_relative_eq!(gradient[0][2].re, 104.0);
8019        assert_relative_eq!(gradient[0][2].im, 0.0);
8020        assert_relative_eq!(gradient[0][3].re, 130.0);
8021        assert_relative_eq!(gradient[0][3].im, 0.0);
8022    }
8023
8024    #[test]
8025    fn test_expression_function_gradients() {
8026        let expr1 = ComplexScalar::new(
8027            "function_parametric_1",
8028            parameter!("function_test_param_re_1"),
8029            parameter!("function_test_param_im_1"),
8030        )
8031        .unwrap();
8032        let expr2 = ComplexScalar::new(
8033            "function_parametric_2",
8034            parameter!("function_test_param_re_2"),
8035            parameter!("function_test_param_im_2"),
8036        )
8037        .unwrap();
8038
8039        let sin = expr1.sin();
8040        let cos = expr1.cos();
8041        let trig = &sin * &cos;
8042        let pow = expr1.pow(&expr2);
8043        let mut expr = expr1.sqrt();
8044        expr = &expr + &expr1.exp();
8045        expr = &expr + &expr1.powi(2);
8046        expr = &expr + &expr1.powf(1.7);
8047        expr = &expr + &trig;
8048        expr = &expr + &expr1.log();
8049        expr = &expr + &expr1.cis();
8050        expr = &expr + &pow;
8051
8052        let dataset = Arc::new(test_dataset());
8053        let evaluator = expr.load(&dataset).unwrap();
8054        let params = vec![2.0, 0.5, 1.2, -0.3];
8055        let gradient = evaluator
8056            .evaluate_gradient(&params)
8057            .expect("evaluation should succeed");
8058        let eps = 1e-6;
8059
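        // Compare each analytic gradient component against a central finite
        // difference, (f(p + eps) - f(p - eps)) / (2 * eps).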
8060        for param_index in 0..params.len() {
8061            let mut plus = params.clone();
8062            plus[param_index] += eps;
8063            let mut minus = params.clone();
8064            minus[param_index] -= eps;
8065            let finite_diff = (evaluator
8066                .evaluate(&plus)
8067                .expect("evaluation should succeed")[0]
8068                - evaluator
8069                    .evaluate(&minus)
8070                    .expect("evaluation should succeed")[0])
8071                / Complex64::new(2.0 * eps, 0.0);
8072
8073            assert_relative_eq!(
8074                gradient[0][param_index].re,
8075                finite_diff.re,
8076                epsilon = 1e-6,
8077                max_relative = 1e-6
8078            );
8079            assert_relative_eq!(
8080                gradient[0][param_index].im,
8081                finite_diff.im,
8082                epsilon = 1e-6,
8083                max_relative = 1e-6
8084            );
8085        }
8086    }
8087
8088    #[test]
8089    fn test_zeros_and_ones() {
8090        let amp = ComplexScalar::new(
8091            "parametric",
8092            parameter!("test_param_re"),
8093            parameter!("fixed_two", 2.0),
8094        )
8095        .unwrap();
8096        let dataset = Arc::new(test_dataset());
8097        let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
8098        let evaluator = expr.load(&dataset).unwrap();
8099
8100        let params = vec![2.0];
8101        let value = evaluator
8102            .evaluate(&params)
8103            .expect("evaluation should succeed");
8104        let gradient = evaluator
8105            .evaluate_gradient(&params)
8106            .expect("evaluation should succeed");
8107
8108        // For |f(x) * 1 + 0|^2 where f(x) = x + 2i, the value should be x^2 + 4 = 8 at x = 2
8109        assert_relative_eq!(value[0].re, 8.0);
8110        assert_relative_eq!(value[0].im, 0.0);
8111
8112        // For |f(x) * 1 + 0|^2 where f(x) = x + 2i, the derivative should be 2x = 4 at x = 2
8113        assert_relative_eq!(gradient[0][0].re, 4.0);
8114        assert_relative_eq!(gradient[0][0].im, 0.0);
8115    }
8116    #[test]
8117    fn test_default_build_uses_lowered_expression_runtime() {
8118        let expr = ComplexScalar::new(
8119            "opt_in_gate",
8120            parameter!("opt_in_gate_re", 2.0),
8121            parameter!("opt_in_gate_im", 0.0),
8122        )
8123        .unwrap()
8124        .norm_sqr();
8125        let dataset = Arc::new(test_dataset());
8126        let evaluator = expr.load(&dataset).unwrap();
8127
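        // The default load path should enable IR planning and produce lowered
        // value, gradient, and value+gradient programs; |2 + 0i|^2 = 4 checks
        // the evaluated result.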
8128        let diagnostics = evaluator.expression_runtime_diagnostics();
8129        assert!(diagnostics.ir_planning_enabled);
8130        assert!(diagnostics.lowered_value_program_present);
8131        assert!(diagnostics.lowered_gradient_program_present);
8132        assert!(diagnostics.lowered_value_gradient_program_present);
8133        assert_eq!(
8134            evaluator.evaluate(&[]).expect("evaluation should succeed")[0],
8135            Complex64::new(4.0, 0.0)
8136        );
8137    }
8138
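    // The parameter! macro forms exercised below: a bare name creates a free
    // parameter, a name plus value creates a fixed one, and keyword arguments
    // (initial, fixed, bounds, unit, latex, description) can be combined,
    // with trailing commas accepted.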
8139    #[test]
8140    fn parameter_name_only_creates_free_parameter() {
8141        let p = parameter!("mass");
8142
8143        assert_eq!(p.name(), "mass");
8144        assert_eq!(p.fixed(), None);
8145        assert_eq!(p.initial(), None);
8146        assert_eq!(p.bounds(), (None, None));
8147        assert_eq!(p.unit(), None);
8148        assert_eq!(p.latex(), None);
8149        assert_eq!(p.description(), None);
8150        assert!(p.is_free());
8151        assert!(!p.is_fixed());
8152    }
8153
8154    #[test]
8155    fn parameter_name_and_value_creates_fixed_parameter() {
8156        let p = parameter!("width", 0.15);
8157
8158        assert_eq!(p.name(), "width");
8159        assert_eq!(p.fixed(), Some(0.15));
8160        assert_eq!(p.initial(), Some(0.15));
8161        assert!(p.is_fixed());
8162        assert!(!p.is_free());
8163    }
8164
8165    #[test]
8166    fn keyword_initial_sets_initial_only() {
8167        let p = parameter!("alpha", initial: 1.25);
8168
8169        assert_eq!(p.name(), "alpha");
8170        assert_eq!(p.fixed(), None);
8171        assert_eq!(p.initial(), Some(1.25));
8172        assert_eq!(p.bounds(), (None, None));
8173        assert!(p.is_free());
8174    }
8175
8176    #[test]
8177    fn keyword_fixed_sets_fixed_and_initial() {
8178        let p = parameter!("beta", fixed: 2.5);
8179
8180        assert_eq!(p.name(), "beta");
8181        assert_eq!(p.fixed(), Some(2.5));
8182        assert_eq!(p.initial(), Some(2.5));
8183        assert!(p.is_fixed());
8184    }
8185
8186    #[test]
8187    fn bounds_accept_plain_numbers() {
8188        let p = parameter!("x", bounds: (0.0, 10.0));
8189
8190        assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
8191    }
8192
8193    #[test]
8194    fn bounds_accept_none_and_number() {
8195        let p = parameter!("x", bounds: (None, 10.0));
8196
8197        assert_eq!(p.bounds(), (None, Some(10.0)));
8198    }
8199
8200    #[test]
8201    fn bounds_accept_number_and_none() {
8202        let p = parameter!("x", bounds: (-1.0, None));
8203
8204        assert_eq!(p.bounds(), (Some(-1.0), None));
8205    }
8206
8207    #[test]
8208    fn bounds_accept_both_none() {
8209        let p = parameter!("x", bounds: (None, None));
8210
8211        assert_eq!(p.bounds(), (None, None));
8212    }
8213
8214    #[test]
8215    fn bounds_accept_arbitrary_expressions() {
8216        let lo = 1.0;
8217        let hi = 2.0 * 3.0;
8218        let p = parameter!("x", bounds: (lo - 0.5, hi));
8219
8220        assert_eq!(p.bounds(), (Some(0.5), Some(6.0)));
8221    }
8222
8223    #[test]
8224    fn multiple_keyword_arguments_work_together() {
8225        let p = parameter!(
8226            "gamma",
8227            initial: 1.0,
8228            bounds: (0.0, 5.0),
8229            unit: "GeV",
8230            latex: r"\gamma",
8231            description: "test parameter",
8232        );
8233
8234        assert_eq!(p.name(), "gamma");
8235        assert_eq!(p.fixed(), None);
8236        assert_eq!(p.initial(), Some(1.0));
8237        assert_eq!(p.bounds(), (Some(0.0), Some(5.0)));
8238        assert_eq!(p.unit().as_deref(), Some("GeV"));
8239        assert_eq!(p.latex().as_deref(), Some(r"\gamma"));
8240        assert_eq!(p.description().as_deref(), Some("test parameter"));
8241    }
8242
8243    #[test]
8244    fn fixed_can_be_combined_with_other_fields() {
8245        let p = parameter!(
8246            "delta",
8247            fixed: 3.0,
8248            bounds: (0.0, 10.0),
8249            unit: "rad",
8250        );
8251
8252        assert_eq!(p.name(), "delta");
8253        assert_eq!(p.fixed(), Some(3.0));
8254        assert_eq!(p.initial(), Some(3.0));
8255        assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
8256        assert_eq!(p.unit().as_deref(), Some("rad"));
8257    }
8258
8259    #[test]
8260    fn trailing_comma_is_accepted() {
8261        let p = parameter!(
8262            "eps",
8263            initial: 0.5,
8264            bounds: (None, 1.0),
8265            unit: "arb",
8266        );
8267
8268        assert_eq!(p.initial(), Some(0.5));
8269        assert_eq!(p.bounds(), (None, Some(1.0)));
8270        assert_eq!(p.unit().as_deref(), Some("arb"));
8271    }
8272
8273    #[test]
8274    fn test_parameter_registration() {
8275        let expr = ComplexScalar::new(
8276            "parametric",
8277            parameter!("test_param_re"),
8278            parameter!("fixed_two", 2.0),
8279        )
8280        .unwrap();
8281        let parameters = expr.free_parameters();
8282        assert_eq!(parameters.len(), 1);
8283        assert_eq!(parameters[0], "test_param_re");
8284    }
8285
8286    #[test]
8287    #[should_panic(expected = "refers to different underlying amplitudes")]
8288    fn test_duplicate_amplitude_registration() {
8289        let amp1 = ComplexScalar::new(
8290            "same_name",
8291            parameter!("dup_re1", 1.0),
8292            parameter!("dup_im1", 0.0),
8293        )
8294        .unwrap();
8295        let amp2 = ComplexScalar::new(
8296            "same_name",
8297            parameter!("dup_re2", 2.0),
8298            parameter!("dup_im2", 0.0),
8299        )
8300        .unwrap();
8301        let _expr = amp1 + amp2;
8302    }
8303
8304    #[test]
8305    fn test_tree_printing() {
8306        let amp1 = ComplexScalar::new(
8307            "parametric_1",
8308            parameter!("test_param_re_1"),
8309            parameter!("test_param_im_1"),
8310        )
8311        .unwrap();
8312        let amp2 = ComplexScalar::new(
8313            "parametric_2",
8314            parameter!("test_param_re_2"),
8315            parameter!("test_param_im_2"),
8316        )
8317        .unwrap();
8318        let expr =
8319            &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
8320                - Expression::zero() / 1.0
8321                + (&amp1 * &amp2).norm_sqr();
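        // Display renders the expression tree one node per line: binary
        // operators, unary wrappers (Re, Im, NormSqr), exact zero/one
        // constants, and each amplitude tagged with its id.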
8322        assert_eq!(
8323            expr.to_string(),
8324            "+
8325├─ -
8326│  ├─ +
8327│  │  ├─ +
8328│  │  │  ├─ Re
8329│  │  │  │  └─ parametric_1(id=0)
8330│  │  │  └─ Im
8331│  │  │     └─ *
8332│  │  │        └─ parametric_2(id=1)
8333│  │  └─ ×
8334│  │     ├─ 1 (exact)
8335│  │     └─ -1.4+2i
8336│  └─ ÷
8337│     ├─ 0 (exact)
8338│     └─ 1 (exact)
8339└─ NormSqr
8340   └─ ×
8341      ├─ parametric_1(id=0)
8342      └─ parametric_2(id=1)
8343"
8344        );
8345    }
8346}