1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 sync::{
5 atomic::{AtomicU64, Ordering},
6 Arc,
7 },
8 time::Instant,
9};
10
11use auto_ops::*;
12use nalgebra::DVector;
13use num::complex::Complex64;
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
19use super::{ir, lowered};
20
21static COMPUTE_AMPLITUDE_COUNTER: AtomicU64 = AtomicU64::new(0);
22static AMPLITUDE_USE_SITE_COUNTER: AtomicU64 = AtomicU64::new(0);
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
25struct ComputeAmplitudeId(u64);
26
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
28struct AmplitudeUseSiteId(u64);
29
30fn next_compute_amplitude_id() -> ComputeAmplitudeId {
31 ComputeAmplitudeId(COMPUTE_AMPLITUDE_COUNTER.fetch_add(1, Ordering::Relaxed))
32}
33
34fn next_amplitude_use_site_id() -> AmplitudeUseSiteId {
35 AmplitudeUseSiteId(AMPLITUDE_USE_SITE_COUNTER.fetch_add(1, Ordering::Relaxed))
36}
37#[derive(Clone, Copy, Debug, PartialEq, Eq)]
38pub enum ExpressionDependence {
40 ParameterOnly,
42 CacheOnly,
44 Mixed,
46}
47impl From<ir::DependenceClass> for ExpressionDependence {
48 fn from(value: ir::DependenceClass) -> Self {
49 match value {
50 ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
51 ir::DependenceClass::CacheOnly => Self::CacheOnly,
52 ir::DependenceClass::Mixed => Self::Mixed,
53 }
54 }
55}
/// Human-inspectable summary of the normalization plan computed for an expression's IR.
///
/// Built from `ir::NormalizationPlanExplain`; IR-specific candidate records are reduced to
/// their node indices.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationPlanExplain {
    /// Dependence class of the expression's root node.
    pub root_dependence: ExpressionDependence,
    /// Warnings reported by the planner.
    pub warnings: Vec<String>,
    /// IR node indices of the multiplication nodes the planner flagged as separable candidates.
    pub separable_mul_candidate_nodes: Vec<usize>,
    /// Separable nodes whose integrals are cached (presumably IR node indices — TODO confirm).
    pub cached_separable_nodes: Vec<usize>,
    /// Terms left outside the cached plan (index space not visible here — TODO confirm).
    pub residual_terms: Vec<usize>,
}
/// Public view of `ir::NormalizationExecutionSets`: the amplitude partitions used when
/// evaluating normalization integrals.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationExecutionSetsExplain {
    /// Amplitudes on the parameter side of cached products (presumably use-site indices —
    /// TODO confirm).
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitudes on the cache side of cached products.
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitudes that must still be evaluated per event (not covered by cached integrals).
    pub residual_amplitudes: Vec<usize>,
}
/// A precomputed integral for one separable multiplication node of the expression IR.
#[derive(Clone, Debug, PartialEq)]
pub struct PrecomputedCachedIntegral {
    /// IR index of the multiplication node this integral belongs to.
    pub mul_node_index: usize,
    /// IR index of the parameter-dependent factor of that product.
    pub parameter_node_index: usize,
    /// IR index of the cache-dependent factor of that product.
    pub cache_node_index: usize,
    /// Integer coefficient applied to this term (NOTE(review): sign/multiplicity semantics
    /// are defined by the planner, not visible here).
    pub coefficient: i32,
    /// Weighted sum of the cache-side factor (presumably over the local events — confirm
    /// against `Evaluator::precompute_cached_integrals_at_load`).
    pub weighted_cache_sum: Complex64,
}
/// Gradient counterpart of [`PrecomputedCachedIntegral`] for one separable multiplication node.
#[derive(Clone, Debug, PartialEq)]
pub struct PrecomputedCachedIntegralGradientTerm {
    /// IR index of the multiplication node this term belongs to.
    pub mul_node_index: usize,
    /// IR index of the parameter-dependent factor of that product.
    pub parameter_node_index: usize,
    /// IR index of the cache-dependent factor of that product.
    pub cache_node_index: usize,
    /// Integer coefficient applied to this term.
    pub coefficient: i32,
    /// Weighted gradient vector (presumably one entry per free parameter — TODO confirm).
    pub weighted_gradient: DVector<Complex64>,
}
/// Key identifying the inputs a set of cached integrals was computed for.
///
/// Produced by `Evaluator::cached_integral_cache_key` from the active-amplitude mask and the
/// dataset.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CachedIntegralCacheKey {
    /// Which amplitude use sites were active.
    active_mask: Vec<bool>,
    /// Number of events held locally.
    n_events_local: usize,
    /// Length of the local weight buffer.
    weights_local_len: usize,
    /// Bit pattern of the weighted event sum (presumably `f64::to_bits` — confirm).
    weighted_sum_bits: u64,
    /// Address of the weight buffer, used as an identity check.
    /// NOTE(review): addresses can be reused after deallocation; the other fields must guard
    /// against stale hits.
    weights_ptr: usize,
}
/// Everything computed for one [`CachedIntegralCacheKey`]: the IR it was compiled from, the
/// precomputed integrals, and the normalization execution sets of that IR.
#[derive(Clone, Debug)]
struct CachedIntegralCacheState {
    /// Key this state was computed for.
    key: CachedIntegralCacheKey,
    /// Expression IR compiled for the key's active mask.
    expression_ir: ir::ExpressionIR,
    /// Precomputed cached integrals for `expression_ir`.
    values: Vec<PrecomputedCachedIntegral>,
    /// Clone of `expression_ir.normalization_execution_sets()`.
    execution_sets: ir::NormalizationExecutionSets,
}
/// Lowered runtime artifacts derived from an expression IR and its cached integrals
/// (produced by `Evaluator::lower_expression_runtime_artifacts`).
#[derive(Clone, Debug)]
struct LoweredArtifactCacheState {
    /// IR indices of the parameter-side factors of the cached products.
    parameter_node_indices: Vec<usize>,
    /// IR indices of the multiplication nodes of the cached products.
    mul_node_indices: Vec<usize>,
    /// Lowered form of each parameter factor, `None` where no lowering was produced.
    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
    /// Lowered runtime for the residual (non-cached) part, if any.
    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
    /// Lowered runtime for the full expression.
    lowered_runtime: lowered::LoweredExpressionRuntime,
}
/// One entry of the evaluator's specialization cache: shared handles to the cached integrals
/// and lowered artifacts computed for a particular [`CachedIntegralCacheKey`].
#[derive(Clone)]
struct ExpressionSpecializationState {
    /// Cached integrals (and the IR they were computed from).
    cached_integrals: Arc<CachedIntegralCacheState>,
    /// Lowered runtime artifacts for the same IR.
    lowered_artifacts: Arc<LoweredArtifactCacheState>,
}
/// Hit/miss counters for the evaluator's specialization cache.
///
/// `Expression::load` starts these at zero hits and one miss (the initial specialization).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionSpecializationMetrics {
    /// Number of specialization-cache lookups that reused an existing entry.
    pub cache_hits: usize,
    /// Number of specialization-cache lookups that had to build a new entry.
    pub cache_misses: usize,
}
/// Wall-clock timings (nanoseconds) and cache counters for expression compilation.
///
/// The `initial_*` fields are measured once by `Expression::load`; the `specialization_*`
/// fields cover later re-specializations (updated outside this section).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionCompileMetrics {
    /// Time spent compiling the expression IR during load.
    pub initial_ir_compile_nanos: u64,
    /// Time spent precomputing cached integrals during load.
    pub initial_cached_integrals_nanos: u64,
    /// Time spent lowering runtime artifacts during load.
    pub initial_lowering_nanos: u64,
    /// Specialization-cache hits after load.
    pub specialization_cache_hits: usize,
    /// Specialization-cache misses after load.
    pub specialization_cache_misses: usize,
    /// Cumulative IR compile time for re-specializations.
    pub specialization_ir_compile_nanos: u64,
    /// Cumulative cached-integral time for re-specializations.
    pub specialization_cached_integrals_nanos: u64,
    /// Cumulative lowering time for re-specializations.
    pub specialization_lowering_nanos: u64,
    /// Lowered-artifact cache hits.
    pub specialization_lowering_cache_hits: usize,
    /// Lowered-artifact cache misses (load records the first one).
    pub specialization_lowering_cache_misses: usize,
    /// Cumulative time spent restoring cached specializations.
    pub specialization_cache_restore_nanos: u64,
}
170impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
171 fn from(value: ir::NormalizationPlanExplain) -> Self {
172 Self {
173 root_dependence: value.root_dependence.into(),
174 warnings: value.warnings,
175 separable_mul_candidate_nodes: value
176 .separable_mul_candidates
177 .into_iter()
178 .map(|candidate| candidate.node_index)
179 .collect(),
180 cached_separable_nodes: value.cached_separable_nodes,
181 residual_terms: value.residual_terms,
182 }
183 }
184}
185impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
186 fn from(value: ir::NormalizationExecutionSets) -> Self {
187 Self {
188 cached_parameter_amplitudes: value.cached_parameter_amplitudes,
189 cached_cache_amplitudes: value.cached_cache_amplitudes,
190 residual_amplitudes: value.residual_amplitudes,
191 }
192 }
193}
194impl From<ExpressionDependence> for ir::DependenceClass {
195 fn from(value: ExpressionDependence) -> Self {
196 match value {
197 ExpressionDependence::ParameterOnly => Self::ParameterOnly,
198 ExpressionDependence::CacheOnly => Self::CacheOnly,
199 ExpressionDependence::Mixed => Self::Mixed,
200 }
201 }
202}
203
204#[cfg(feature = "mpi")]
205use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
206
207#[cfg(feature = "mpi")]
208use crate::mpi::LadduMPI;
209#[cfg(feature = "execution-context-prototype")]
210use crate::ExecutionContext;
211#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
212use crate::ThreadPolicy;
213use crate::{
214 amplitude::{Amplitude, AmplitudeUseSite},
215 data::Dataset,
216 parameters::ParameterMap,
217 resources::{Cache, Parameters, Resources},
218 LadduError, LadduResult,
219};
220
/// A symbolic expression over registered amplitudes, constants, arithmetic, and elementary
/// functions.
///
/// `registry` owns the amplitudes and their use sites; each `ExpressionNode::Amp(i)` leaf in
/// `tree` refers to use site `i` of the registry. Combining expressions merges their
/// registries and remaps the leaves accordingly.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct Expression {
    registry: ExpressionRegistry,
    tree: ExpressionNode,
}
228
/// Owns the amplitudes referenced by an [`Expression`] together with their use sites.
///
/// `amplitudes` and `amplitude_ids` are parallel and indexed by amplitude index.
/// `amplitude_use_sites`, `amplitude_use_site_ids`, and `amplitude_names` are parallel and
/// indexed by use-site index. Several use sites may point at the same amplitude.
#[derive(Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub struct ExpressionRegistry {
    // Distinct amplitude implementations, registered against `resources`.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Display label of each use site, derived from its tags.
    amplitude_names: Vec<String>,
    // Id of each amplitude; used to deduplicate when merging registries.
    amplitude_ids: Vec<ComputeAmplitudeId>,
    // Each use site: the index of the amplitude it evaluates plus its tags.
    amplitude_use_sites: Vec<AmplitudeUseSite>,
    // Id of each use site; used to deduplicate when merging registries.
    amplitude_use_site_ids: Vec<AmplitudeUseSiteId>,
    // Parameters and cache layout produced by registering `amplitudes`.
    resources: Resources,
}
240
241impl ExpressionRegistry {
242 fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
243 let mut resources = Resources::default();
244 let aid = amplitude.register(&mut resources)?;
245 let compute_id = next_compute_amplitude_id();
246 let use_site_id = next_amplitude_use_site_id();
247 resources.configure_amplitude_tags(std::slice::from_ref(&aid.0));
248 Ok(Self {
249 amplitudes: vec![amplitude],
250 amplitude_names: vec![aid.0.display_label()],
251 amplitude_ids: vec![compute_id],
252 amplitude_use_sites: vec![AmplitudeUseSite {
253 amplitude_index: 0,
254 tags: aid.0,
255 }],
256 amplitude_use_site_ids: vec![use_site_id],
257 resources,
258 })
259 }
260
261 fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
262 let mut resources = Resources::default();
263 let mut amplitudes = Vec::new();
264 let mut amplitude_ids = Vec::new();
265 let mut compute_id_to_index = HashMap::new();
266 let mut semantic_key_to_index = HashMap::new();
267
268 let mut left_compute_map = Vec::with_capacity(self.amplitudes.len());
269 for (amp_index, (amp, amp_id)) in
270 self.amplitudes.iter().zip(&self.amplitude_ids).enumerate()
271 {
272 let semantic_key = amp.semantic_key();
273 let mut cloned_amp = dyn_clone::clone_box(&**amp);
274 let aid = cloned_amp.register(&mut resources)?;
275 if let Some(key) = semantic_key.clone() {
276 semantic_key_to_index.insert(key, aid.1);
277 }
278 amplitudes.push(cloned_amp);
279 amplitude_ids.push(*amp_id);
280 compute_id_to_index.insert(*amp_id, aid.1);
281 left_compute_map.push(aid.1);
282 debug_assert_eq!(amp_index, aid.1);
283 }
284
285 let mut right_compute_map = Vec::with_capacity(other.amplitudes.len());
286 for (amp, amp_id) in other.amplitudes.iter().zip(&other.amplitude_ids) {
287 if let Some(existing) = compute_id_to_index.get(amp_id) {
288 right_compute_map.push(*existing);
289 continue;
290 }
291 let incoming_semantic_key = amp.semantic_key();
292 if let Some(existing) = incoming_semantic_key
293 .as_ref()
294 .and_then(|key| semantic_key_to_index.get(key))
295 {
296 right_compute_map.push(*existing);
297 continue;
298 }
299 let mut cloned_amp = dyn_clone::clone_box(&**amp);
300 let aid = cloned_amp.register(&mut resources)?;
301 if let Some(key) = incoming_semantic_key.clone() {
302 semantic_key_to_index.insert(key, aid.1);
303 }
304 amplitudes.push(cloned_amp);
305 amplitude_ids.push(*amp_id);
306 compute_id_to_index.insert(*amp_id, aid.1);
307 right_compute_map.push(aid.1);
308 }
309
310 let mut amplitude_use_sites = Vec::new();
311 let mut amplitude_use_site_ids = Vec::new();
312 let mut amplitude_names = Vec::new();
313 let mut use_site_id_to_index = HashMap::new();
314
315 let mut left_map = Vec::with_capacity(self.amplitude_use_sites.len());
316 for (use_site, use_site_id) in self
317 .amplitude_use_sites
318 .iter()
319 .zip(&self.amplitude_use_site_ids)
320 {
321 let mapped_index = left_compute_map[use_site.amplitude_index];
322 let new_index = amplitude_use_sites.len();
323 left_map.push(new_index);
324 use_site_id_to_index.insert(*use_site_id, new_index);
325 amplitude_use_site_ids.push(*use_site_id);
326 amplitude_names.push(use_site.tags.display_label());
327 amplitude_use_sites.push(AmplitudeUseSite {
328 amplitude_index: mapped_index,
329 tags: use_site.tags.clone(),
330 });
331 }
332
333 let mut right_map = Vec::with_capacity(other.amplitude_use_sites.len());
334 for (use_site, use_site_id) in other
335 .amplitude_use_sites
336 .iter()
337 .zip(&other.amplitude_use_site_ids)
338 {
339 if let Some(existing) = use_site_id_to_index.get(use_site_id) {
340 right_map.push(*existing);
341 continue;
342 }
343 let mapped_index = right_compute_map[use_site.amplitude_index];
344 let new_index = amplitude_use_sites.len();
345 right_map.push(new_index);
346 use_site_id_to_index.insert(*use_site_id, new_index);
347 amplitude_use_site_ids.push(*use_site_id);
348 amplitude_names.push(use_site.tags.display_label());
349 amplitude_use_sites.push(AmplitudeUseSite {
350 amplitude_index: mapped_index,
351 tags: use_site.tags.clone(),
352 });
353 }
354 let use_site_tags = amplitude_use_sites
355 .iter()
356 .map(|use_site| use_site.tags.clone())
357 .collect::<Vec<_>>();
358 resources.configure_amplitude_tags(&use_site_tags);
359
360 Ok((
361 Self {
362 amplitudes,
363 amplitude_names,
364 amplitude_ids,
365 amplitude_use_sites,
366 amplitude_use_site_ids,
367 resources,
368 },
369 left_map,
370 right_map,
371 ))
372 }
373}
374
/// A node of the symbolic expression tree stored in an [`Expression`].
///
/// `Amp` leaves index into the use sites of the owning expression's registry; all other
/// nodes are constants or operators over boxed sub-trees.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default, Debug)]
pub enum ExpressionNode {
    /// Exact zero (also the default node).
    #[default]
    Zero,
    /// Exact one.
    One,
    /// A real constant.
    Constant(f64),
    /// A complex constant.
    ComplexConstant(Complex64),
    /// The amplitude use site with this index in the owning registry.
    Amp(usize),
    /// Sum of the two operands.
    Add(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Left operand minus right operand.
    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Product of the two operands.
    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Left operand divided by right operand.
    Div(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Negation.
    Neg(Box<ExpressionNode>),
    /// Real part.
    Real(Box<ExpressionNode>),
    /// Imaginary part.
    Imag(Box<ExpressionNode>),
    /// Complex conjugate.
    Conj(Box<ExpressionNode>),
    /// Squared modulus.
    NormSqr(Box<ExpressionNode>),
    /// Square root.
    Sqrt(Box<ExpressionNode>),
    /// Left operand raised to the power of the right operand.
    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Operand raised to a fixed integer power.
    PowI(Box<ExpressionNode>, i32),
    /// Operand raised to a fixed real power.
    PowF(Box<ExpressionNode>, f64),
    /// Exponential.
    Exp(Box<ExpressionNode>),
    /// Sine.
    Sin(Box<ExpressionNode>),
    /// Cosine.
    Cos(Box<ExpressionNode>),
    /// Logarithm (NOTE(review): branch/base defined by the IR evaluator — presumably natural log).
    Log(Box<ExpressionNode>),
    /// `cis` of the operand (conventionally exp(i·x) — confirm against the IR evaluator).
    Cis(Box<ExpressionNode>),
}
418
419impl ExpressionNode {
420 fn remap(&self, mapping: &[usize]) -> Self {
421 match self {
422 Self::Amp(idx) => Self::Amp(mapping[*idx]),
423 Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
424 Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
425 Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
426 Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
427 Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
428 Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
429 Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
430 Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
431 Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
432 Self::Zero => Self::Zero,
433 Self::One => Self::One,
434 Self::Constant(v) => Self::Constant(*v),
435 Self::ComplexConstant(v) => Self::ComplexConstant(*v),
436 Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
437 Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
438 Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
439 Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
440 Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
441 Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
442 Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
443 Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
444 Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
445 }
446 }
447}
448
449impl From<f64> for Expression {
450 fn from(value: f64) -> Self {
451 if value == 0.0 {
452 Self {
453 registry: ExpressionRegistry::default(),
454 tree: ExpressionNode::Zero,
455 }
456 } else if value == 1.0 {
457 Self {
458 registry: ExpressionRegistry::default(),
459 tree: ExpressionNode::One,
460 }
461 } else {
462 Self {
463 registry: ExpressionRegistry::default(),
464 tree: ExpressionNode::Constant(value),
465 }
466 }
467 }
468}
469impl From<&f64> for Expression {
470 fn from(value: &f64) -> Self {
471 (*value).into()
472 }
473}
474impl From<Complex64> for Expression {
475 fn from(value: Complex64) -> Self {
476 if value == Complex64::ZERO {
477 Self {
478 registry: ExpressionRegistry::default(),
479 tree: ExpressionNode::Zero,
480 }
481 } else if value == Complex64::ONE {
482 Self {
483 registry: ExpressionRegistry::default(),
484 tree: ExpressionNode::One,
485 }
486 } else {
487 Self {
488 registry: ExpressionRegistry::default(),
489 tree: ExpressionNode::ComplexConstant(value),
490 }
491 }
492 }
493}
494impl From<&Complex64> for Expression {
495 fn from(value: &Complex64) -> Self {
496 (*value).into()
497 }
498}
499
500impl Expression {
501 pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
503 let registry = ExpressionRegistry::singleton(amplitude)?;
504 Ok(Self {
505 tree: ExpressionNode::Amp(0),
506 registry,
507 })
508 }
509
510 pub fn zero() -> Self {
512 Self {
513 registry: ExpressionRegistry::default(),
514 tree: ExpressionNode::Zero,
515 }
516 }
517
518 pub fn one() -> Self {
520 Self {
521 registry: ExpressionRegistry::default(),
522 tree: ExpressionNode::One,
523 }
524 }
525
526 fn binary_op(
527 a: &Expression,
528 b: &Expression,
529 build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
530 ) -> Expression {
531 let (registry, left_map, right_map) = a
532 .registry
533 .merge(&b.registry)
534 .expect("merging expression registries should not fail");
535 let left_tree = a.tree.remap(&left_map);
536 let right_tree = b.tree.remap(&right_map);
537 Expression {
538 registry,
539 tree: build(Box::new(left_tree), Box::new(right_tree)),
540 }
541 }
542
543 fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
544 Expression {
545 registry: a.registry.clone(),
546 tree: build(Box::new(a.tree.clone())),
547 }
548 }
549
550 pub fn parameters(&self) -> ParameterMap {
552 self.registry.resources.parameters()
553 }
554
555 pub fn n_free(&self) -> usize {
557 self.registry.resources.n_free_parameters()
558 }
559
560 pub fn n_fixed(&self) -> usize {
562 self.registry.resources.n_fixed_parameters()
563 }
564
565 pub fn n_parameters(&self) -> usize {
567 self.registry.resources.n_parameters()
568 }
569
570 pub fn compiled_expression(&self) -> CompiledExpression {
576 let active_amplitudes = vec![true; self.registry.amplitude_use_sites.len()];
577 let amplitude_dependencies = self
578 .registry
579 .amplitude_use_sites
580 .iter()
581 .map(|use_site| {
582 ir::DependenceClass::from(
583 self.registry.amplitudes[use_site.amplitude_index].dependence_hint(),
584 )
585 })
586 .collect::<Vec<_>>();
587 let amplitude_realness = self
588 .registry
589 .amplitude_use_sites
590 .iter()
591 .map(|use_site| self.registry.amplitudes[use_site.amplitude_index].real_valued_hint())
592 .collect::<Vec<_>>();
593 let expression_ir = ir::compile_expression_ir_with_real_hints(
594 &self.tree,
595 &active_amplitudes,
596 &litude_dependencies,
597 &litude_realness,
598 );
599 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
600 }
601
602 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
604 self.registry.resources.fix_parameter(name, value)
605 }
606
607 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
609 self.registry.resources.free_parameter(name)
610 }
611
612 pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
614 self.registry.resources.rename_parameter(old, new)
615 }
616
617 pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
619 self.registry.resources.rename_parameters(mapping)
620 }
621
    /// Binds this expression to `dataset` and builds an [`Evaluator`] for it.
    ///
    /// In order, this:
    /// 1. clones the registry's resources, reserves cache space for the local events, and
    ///    refreshes the active amplitude indices;
    /// 2. clones every amplitude, then for each one binds it to the dataset metadata and runs
    ///    its `precompute_all`;
    /// 3. compiles the expression IR, marking only the currently active use sites as active;
    /// 4. precomputes cached integrals and lowers runtime artifacts for that IR;
    /// 5. seeds the evaluator's specialization and lowered-artifact caches with the results.
    ///
    /// The wall-clock time of steps 3 and 4 is recorded in [`ExpressionCompileMetrics`]. The
    /// initial load counts as one specialization-cache miss and one lowering-cache miss.
    ///
    /// # Errors
    ///
    /// Returns an error if any amplitude fails to bind, or if precomputing the cached
    /// integrals or lowering fails.
    pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
        // Work on a clone of the registry's resources.
        // NOTE(review): whether any state inside `Resources` is shared between clones is not
        // visible here.
        let mut resources = self.registry.resources.clone();
        let metadata = dataset.metadata();
        resources.reserve_cache(dataset.n_events_local());
        resources.refresh_active_indices();
        // Snapshot taken before precompute; only `free().len()` is used below.
        let parameter_map = resources.parameter_map.clone();
        let mut amplitudes: Vec<Box<dyn Amplitude>> = self
            .registry
            .amplitudes
            .iter()
            .map(|amp| dyn_clone::clone_box(&**amp))
            .collect();
        {
            // Bind and precompute each amplitude in turn; binding errors abort the load.
            for amplitude in amplitudes.iter_mut() {
                amplitude.bind(metadata)?;
                amplitude.precompute_all(dataset, &mut resources);
            }
        }
        let ir_compile_start = Instant::now();
        let expression_ir = {
            // Active indices are use-site indices; an index outside the use-site range would
            // panic here.
            let mut active_amplitudes = vec![false; self.registry.amplitude_use_sites.len()];
            for &index in resources.active_indices() {
                active_amplitudes[index] = true;
            }
            let amplitude_dependencies = self
                .registry
                .amplitude_use_sites
                .iter()
                .map(|use_site| {
                    ir::DependenceClass::from(
                        amplitudes[use_site.amplitude_index].dependence_hint(),
                    )
                })
                .collect::<Vec<_>>();
            let amplitude_realness = self
                .registry
                .amplitude_use_sites
                .iter()
                .map(|use_site| amplitudes[use_site.amplitude_index].real_valued_hint())
                .collect::<Vec<_>>();
            ir::compile_expression_ir_with_real_hints(
                &self.tree,
                &active_amplitudes,
                &amplitude_dependencies,
                &amplitude_realness,
            )
        };
        let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
        let cached_integrals_start = Instant::now();
        let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
            &expression_ir,
            &amplitudes,
            &self.registry.amplitude_use_sites,
            &resources,
            dataset,
            parameter_map.free().len(),
        )?;
        let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
        let lowering_start = Instant::now();
        let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
            &expression_ir,
            &cached_integrals,
        )?);
        let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
        let execution_sets = expression_ir.normalization_execution_sets().clone();
        let cached_integral_key =
            Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
        let cached_integral_state = Arc::new(CachedIntegralCacheState {
            key: cached_integral_key.clone(),
            expression_ir,
            values: cached_integrals,
            execution_sets,
        });
        // The same Arc'd state is shared by the current-state slot and both caches.
        let specialization_state = ExpressionSpecializationState {
            cached_integrals: cached_integral_state.clone(),
            lowered_artifacts: lowered_artifacts.clone(),
        };
        let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
        let lowered_artifact_cache =
            HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
        Ok(Evaluator {
            amplitudes,
            amplitude_use_sites: self.registry.amplitude_use_sites.clone(),
            resources: Arc::new(RwLock::new(resources)),
            dataset: dataset.clone(),
            expression: self.tree.clone(),
            ir_planning: ExpressionIrPlanningState {
                cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
                specialization_cache: Arc::new(RwLock::new(specialization_cache)),
                // The initial specialization above counts as the first miss.
                specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
                    cache_hits: 0,
                    cache_misses: 1,
                })),
                lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
                active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
                specialization_status: Arc::new(RwLock::new(Some(
                    ExpressionSpecializationStatus {
                        origin: ExpressionSpecializationOrigin::InitialLoad,
                    },
                ))),
                compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
                    initial_ir_compile_nanos,
                    initial_cached_integrals_nanos,
                    initial_lowering_nanos,
                    specialization_lowering_cache_misses: 1,
                    ..Default::default()
                })),
            },
            registry: self.registry.clone(),
        })
    }
734
735 pub fn real(&self) -> Self {
737 Self::unary_op(self, ExpressionNode::Real)
738 }
739 pub fn imag(&self) -> Self {
741 Self::unary_op(self, ExpressionNode::Imag)
742 }
743 pub fn conj(&self) -> Self {
745 Self::unary_op(self, ExpressionNode::Conj)
746 }
747 pub fn norm_sqr(&self) -> Self {
749 Self::unary_op(self, ExpressionNode::NormSqr)
750 }
751 pub fn sqrt(&self) -> Self {
753 Self::unary_op(self, ExpressionNode::Sqrt)
754 }
755 pub fn pow(&self, power: &Expression) -> Self {
757 Self::binary_op(self, power, ExpressionNode::Pow)
758 }
759 pub fn powi(&self, power: i32) -> Self {
761 Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
762 }
763 pub fn powf(&self, power: f64) -> Self {
765 Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
766 }
767 pub fn exp(&self) -> Self {
769 Self::unary_op(self, ExpressionNode::Exp)
770 }
771 pub fn sin(&self) -> Self {
773 Self::unary_op(self, ExpressionNode::Sin)
774 }
775 pub fn cos(&self) -> Self {
777 Self::unary_op(self, ExpressionNode::Cos)
778 }
779 pub fn log(&self) -> Self {
781 Self::unary_op(self, ExpressionNode::Log)
782 }
783 pub fn cis(&self) -> Self {
785 Self::unary_op(self, ExpressionNode::Cis)
786 }
787
788 fn write_tree(
790 &self,
791 t: &ExpressionNode,
792 f: &mut std::fmt::Formatter<'_>,
793 parent_prefix: &str,
794 immediate_prefix: &str,
795 parent_suffix: &str,
796 ) -> std::fmt::Result {
797 let display_string = match t {
798 ExpressionNode::Amp(idx) => {
799 let name = self
800 .registry
801 .amplitude_names
802 .get(*idx)
803 .cloned()
804 .unwrap_or_else(|| "<unregistered>".to_string());
805 format!("{name}(id={idx})")
806 }
807 ExpressionNode::Add(_, _) => "+".to_string(),
808 ExpressionNode::Sub(_, _) => "-".to_string(),
809 ExpressionNode::Mul(_, _) => "×".to_string(),
810 ExpressionNode::Div(_, _) => "÷".to_string(),
811 ExpressionNode::Neg(_) => "-".to_string(),
812 ExpressionNode::Real(_) => "Re".to_string(),
813 ExpressionNode::Imag(_) => "Im".to_string(),
814 ExpressionNode::Conj(_) => "*".to_string(),
815 ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
816 ExpressionNode::Zero => "0 (exact)".to_string(),
817 ExpressionNode::One => "1 (exact)".to_string(),
818 ExpressionNode::Constant(v) => v.to_string(),
819 ExpressionNode::ComplexConstant(v) => v.to_string(),
820 ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
821 ExpressionNode::Pow(_, _) => "Pow".to_string(),
822 ExpressionNode::PowI(_, power) => format!("PowI({power})"),
823 ExpressionNode::PowF(_, power) => format!("PowF({power})"),
824 ExpressionNode::Exp(_) => "Exp".to_string(),
825 ExpressionNode::Sin(_) => "Sin".to_string(),
826 ExpressionNode::Cos(_) => "Cos".to_string(),
827 ExpressionNode::Log(_) => "Log".to_string(),
828 ExpressionNode::Cis(_) => "Cis".to_string(),
829 };
830 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
831 match t {
832 ExpressionNode::Amp(_)
833 | ExpressionNode::Zero
834 | ExpressionNode::One
835 | ExpressionNode::Constant(_)
836 | ExpressionNode::ComplexConstant(_) => {}
837 ExpressionNode::Add(a, b)
838 | ExpressionNode::Sub(a, b)
839 | ExpressionNode::Mul(a, b)
840 | ExpressionNode::Div(a, b)
841 | ExpressionNode::Pow(a, b) => {
842 let terms = [a, b];
843 let mut it = terms.iter().peekable();
844 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
845 while let Some(child) = it.next() {
846 match it.peek() {
847 Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│ "),
848 None => self.write_tree(child, f, &child_prefix, "└─ ", " "),
849 }?;
850 }
851 }
852 ExpressionNode::Neg(a)
853 | ExpressionNode::Real(a)
854 | ExpressionNode::Imag(a)
855 | ExpressionNode::Conj(a)
856 | ExpressionNode::NormSqr(a)
857 | ExpressionNode::Sqrt(a)
858 | ExpressionNode::PowI(a, _)
859 | ExpressionNode::PowF(a, _)
860 | ExpressionNode::Exp(a)
861 | ExpressionNode::Sin(a)
862 | ExpressionNode::Cos(a)
863 | ExpressionNode::Log(a)
864 | ExpressionNode::Cis(a) => {
865 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
866 self.write_tree(a, f, &child_prefix, "└─ ", " ")?;
867 }
868 }
869 Ok(())
870 }
871}
872
873impl Debug for Expression {
874 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875 self.write_tree(&self.tree, f, "", "", "")
876 }
877}
878
879impl Display for Expression {
880 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881 self.write_tree(&self.tree, f, "", "", "")
882 }
883}
884
// Arithmetic operators for `Expression`.
//
// `impl_op_ex!` generates the impls for the owned and borrowed forms of each operand pair.
// Every binary operator merges the operand registries through `Expression::binary_op`. Scalar
// operands (`f64`, `Complex64`) are first converted with `Expression::from`, so exact 0 and 1
// become `Zero` / `One` nodes.

// Addition.
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
});

// Subtraction (left minus right).
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
});

// Multiplication.
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Mul)
});
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
});
#[rustfmt::skip]
impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
});
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
});
#[rustfmt::skip]
impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
});

// Division (left divided by right).
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Div)
});
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
});
#[rustfmt::skip]
impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
});
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
});
#[rustfmt::skip]
impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
});

// Unary negation (keeps the operand's registry as-is).
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression| -> Expression {
    Expression::unary_op(a, ExpressionNode::Neg)
});
/// Opaque snapshot wrapping a lowered value program.
///
/// Hidden from rustdoc. NOTE(review): its producer and consumers are outside this section.
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct ExpressionValueProgramSnapshot {
    // The lowered program captured by this snapshot.
    lowered_program: lowered::LoweredProgram,
}
980
/// One node of a [`CompiledExpression`]; operand fields are indices into its node list.
#[derive(Clone, Debug, PartialEq)]
pub enum CompiledExpressionNode {
    /// A complex constant.
    Constant(Complex64),
    /// An amplitude reference.
    Amplitude {
        /// Amplitude index carried by the IR (used to look up `name`).
        index: usize,
        /// Display name, or `"<unregistered>"` if none was found for `index`.
        name: String,
    },
    /// A unary operation.
    Unary {
        /// Display label of the operator (e.g. `"Re"`, `"PowI(2)"`).
        op: String,
        /// Node index of the operand.
        input: usize,
    },
    /// A binary operation.
    Binary {
        /// Display label of the operator.
        op: String,
        /// Node index of the left operand.
        left: usize,
        /// Node index of the right operand.
        right: usize,
    },
}
1010
/// Inspectable form of an expression's compiled IR.
///
/// Nodes may be shared by several parents (a DAG). `Display` prints a tree from `root` and
/// marks repeat visits with ` (ref)` instead of expanding them again.
#[derive(Clone, Debug, PartialEq)]
pub struct CompiledExpression {
    // Nodes in IR order; operand fields index into this vector.
    nodes: Vec<CompiledExpressionNode>,
    // Index of the root node in `nodes`.
    root: usize,
}
1021
1022impl CompiledExpression {
1023 fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
1024 let nodes = ir
1025 .nodes()
1026 .iter()
1027 .map(|node| match node {
1028 ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
1029 ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
1030 index: *index,
1031 name: amplitude_names
1032 .get(*index)
1033 .cloned()
1034 .unwrap_or_else(|| "<unregistered>".to_string()),
1035 },
1036 ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
1037 op: compiled_unary_op_label(*op),
1038 input: *input,
1039 },
1040 ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
1041 op: compiled_binary_op_label(*op),
1042 left: *left,
1043 right: *right,
1044 },
1045 })
1046 .collect();
1047 Self {
1048 nodes,
1049 root: ir.root(),
1050 }
1051 }
1052
1053 pub fn nodes(&self) -> &[CompiledExpressionNode] {
1055 &self.nodes
1056 }
1057
1058 pub fn root(&self) -> usize {
1060 self.root
1061 }
1062
1063 fn node_label(&self, index: usize) -> String {
1064 let Some(node) = self.nodes.get(index) else {
1065 return format!("#{index} <missing>");
1066 };
1067 let label = match node {
1068 CompiledExpressionNode::Constant(value) => format!("const {value}"),
1069 CompiledExpressionNode::Amplitude { index, name } => {
1070 format!("{name}(id={index})")
1071 }
1072 CompiledExpressionNode::Unary { op, .. }
1073 | CompiledExpressionNode::Binary { op, .. } => op.clone(),
1074 };
1075 format!("#{index} {label}")
1076 }
1077
1078 fn write_tree(
1080 &self,
1081 index: usize,
1082 f: &mut std::fmt::Formatter<'_>,
1083 parent_prefix: &str,
1084 immediate_prefix: &str,
1085 parent_suffix: &str,
1086 expanded: &mut [bool],
1087 ) -> std::fmt::Result {
1088 let already_expanded = expanded.get(index).copied().unwrap_or(false);
1089 if let Some(slot) = expanded.get_mut(index) {
1090 *slot = true;
1091 }
1092 let ref_suffix = if already_expanded { " (ref)" } else { "" };
1093 writeln!(
1094 f,
1095 "{}{}{}{}",
1096 parent_prefix,
1097 immediate_prefix,
1098 self.node_label(index),
1099 ref_suffix
1100 )?;
1101 if already_expanded {
1102 return Ok(());
1103 }
1104 let Some(node) = self.nodes.get(index) else {
1105 return Ok(());
1106 };
1107 match node {
1108 CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
1109 CompiledExpressionNode::Unary { input, .. } => {
1110 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1111 self.write_tree(*input, f, &child_prefix, "└─ ", " ", expanded)?;
1112 }
1113 CompiledExpressionNode::Binary { left, right, .. } => {
1114 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1115 self.write_tree(*left, f, &child_prefix, "├─ ", "│ ", expanded)?;
1116 self.write_tree(*right, f, &child_prefix, "└─ ", " ", expanded)?;
1117 }
1118 }
1119 Ok(())
1120 }
1121}
1122
1123impl Display for CompiledExpression {
1124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1125 if self.nodes.is_empty() {
1126 return writeln!(f, "<empty>");
1127 }
1128 let mut expanded = vec![false; self.nodes.len()];
1129 self.write_tree(self.root, f, "", "", "", &mut expanded)
1130 }
1131}
1132
1133fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
1134 match op {
1135 ir::IrUnaryOp::Neg => "-".to_string(),
1136 ir::IrUnaryOp::Real => "Re".to_string(),
1137 ir::IrUnaryOp::Imag => "Im".to_string(),
1138 ir::IrUnaryOp::Conj => "*".to_string(),
1139 ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
1140 ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
1141 ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
1142 ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
1143 ir::IrUnaryOp::Exp => "Exp".to_string(),
1144 ir::IrUnaryOp::Sin => "Sin".to_string(),
1145 ir::IrUnaryOp::Cos => "Cos".to_string(),
1146 ir::IrUnaryOp::Log => "Log".to_string(),
1147 ir::IrUnaryOp::Cis => "Cis".to_string(),
1148 }
1149}
1150
1151fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
1152 match op {
1153 ir::IrBinaryOp::Add => "+".to_string(),
1154 ir::IrBinaryOp::Sub => "-".to_string(),
1155 ir::IrBinaryOp::Mul => "×".to_string(),
1156 ir::IrBinaryOp::Div => "÷".to_string(),
1157 ir::IrBinaryOp::Pow => "Pow".to_string(),
1158 }
1159}
/// Describes how the currently installed expression specialization was obtained.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionSpecializationOrigin {
    /// Built from scratch, and it is the only entry in the specialization cache (first build).
    InitialLoad,
    /// Built from scratch because no cached specialization matched the current key.
    CacheMissRebuild,
    /// Restored from a previously built specialization held in the specialization cache.
    CacheHitRestore,
}
/// Status recorded for the most recently installed expression specialization.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ExpressionSpecializationStatus {
    /// How that specialization was obtained (fresh build or cache restore).
    pub origin: ExpressionSpecializationOrigin,
}
/// Snapshot of an [`Evaluator`]'s expression runtime state, as produced by
/// `Evaluator::expression_runtime_diagnostics`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExpressionRuntimeDiagnostics {
    /// Whether IR-based planning is enabled (currently always reported as `true`).
    pub ir_planning_enabled: bool,
    /// Whether a lowered value program is present (currently always reported as `true`).
    pub lowered_value_program_present: bool,
    /// Whether a lowered gradient program is present (currently always reported as `true`).
    pub lowered_gradient_program_present: bool,
    /// Whether a lowered combined value+gradient program is present (currently always
    /// reported as `true`).
    pub lowered_value_gradient_program_present: bool,
    /// Number of precomputed cached integrals in the installed cached-integral state
    /// (0 when no state is installed).
    pub cached_parameter_factor_count: usize,
    /// Number of cached parameter factors that lowered successfully in the active artifacts.
    pub lowered_cached_parameter_factor_count: usize,
    /// Whether the active lowered artifacts include a residual runtime.
    pub residual_runtime_present: bool,
    /// Number of entries currently held in the specialization cache.
    pub specialization_cache_entries: usize,
    /// Number of entries currently held in the lowered-artifact cache.
    pub lowered_artifact_cache_entries: usize,
    /// Status of the most recently installed specialization, or `None` before the first install.
    pub specialization_status: Option<ExpressionSpecializationStatus>,
}
/// Mutable planning and specialization state shared by an [`Evaluator`].
///
/// Every field is `Arc`-wrapped, so cloning this struct (and therefore an `Evaluator`) shares
/// the same underlying caches and metrics rather than copying them.
#[derive(Clone)]
struct ExpressionIrPlanningState {
    // Currently installed cached-integral state; `None` until a specialization is installed.
    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
    // Previously built specializations, keyed by active mask plus dataset identity.
    specialization_cache:
        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
    // Hit/miss counters for `specialization_cache`.
    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
    // Lowered runtime artifacts, keyed by active mask only.
    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
    // Lowered artifacts belonging to the installed specialization.
    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
    // How the installed specialization was obtained.
    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
    // Accumulated compile/lowering timings and cache counters.
    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
}
/// Evaluates a compiled amplitude expression over a dataset.
///
/// Holds the amplitude implementations, the expression tree, and the shared planning state
/// used to build and cache active-mask specializations of that expression.
#[allow(missing_docs)]
#[derive(Clone)]
pub struct Evaluator {
    // Amplitude implementations, indexed by `AmplitudeUseSite::amplitude_index`.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    // One entry per amplitude occurrence in the expression; several use sites may refer to
    // the same amplitude, which is then computed once per event.
    amplitude_use_sites: Vec<AmplitudeUseSite>,
    // Shared resources (parameter map, per-event caches, active mask).
    pub resources: Arc<RwLock<Resources>>,
    pub dataset: Arc<Dataset>,
    pub expression: ExpressionNode,
    ir_planning: ExpressionIrPlanningState,
    // Expression registry; supplies `amplitude_names` for display.
    registry: ExpressionRegistry,
}
1229
1230#[allow(missing_docs)]
1231impl Evaluator {
1232 pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
1234 *self.ir_planning.specialization_metrics.read()
1235 }
1236 pub fn reset_expression_specialization_metrics(&self) {
1238 *self.ir_planning.specialization_metrics.write() =
1239 ExpressionSpecializationMetrics::default();
1240 }
1241 pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
1243 *self.ir_planning.compile_metrics.read()
1244 }
1245 pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
1247 let active_artifacts = self.active_lowered_artifacts();
1248 let cached_parameter_factor_count = self
1249 .ir_planning
1250 .cached_integrals
1251 .read()
1252 .as_ref()
1253 .map(|state| state.values.len())
1254 .unwrap_or(0);
1255 let lowered_cached_parameter_factor_count = active_artifacts
1256 .as_ref()
1257 .map(|artifacts| {
1258 artifacts
1259 .lowered_parameter_factors
1260 .iter()
1261 .filter(|factor| factor.is_some())
1262 .count()
1263 })
1264 .unwrap_or(0);
1265 let residual_runtime_present = active_artifacts
1266 .as_ref()
1267 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
1268 .is_some();
1269 ExpressionRuntimeDiagnostics {
1270 ir_planning_enabled: true,
1271 lowered_value_program_present: true,
1272 lowered_gradient_program_present: true,
1273 lowered_value_gradient_program_present: true,
1274 cached_parameter_factor_count,
1275 lowered_cached_parameter_factor_count,
1276 residual_runtime_present,
1277 specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
1278 lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
1279 specialization_status: *self.ir_planning.specialization_status.read(),
1280 }
1281 }
1282 pub fn reset_expression_compile_metrics(&self) {
1284 let mut metrics = self.ir_planning.compile_metrics.write();
1285 metrics.specialization_cache_hits = 0;
1286 metrics.specialization_cache_misses = 0;
1287 metrics.specialization_ir_compile_nanos = 0;
1288 metrics.specialization_cached_integrals_nanos = 0;
1289 metrics.specialization_lowering_nanos = 0;
1290 metrics.specialization_lowering_cache_hits = 0;
1291 metrics.specialization_lowering_cache_misses = 0;
1292 metrics.specialization_cache_restore_nanos = 0;
1293 }
1294 #[cfg(test)]
1295 fn expression_ir(&self) -> ir::ExpressionIR {
1296 self.ir_planning
1297 .cached_integrals
1298 .read()
1299 .as_ref()
1300 .map(|state| state.expression_ir.clone())
1301 .expect("cached integral state should exist for evaluator IR access")
1302 }
1303 fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
1304 self.active_lowered_artifacts()
1305 .expect("active lowered artifacts should exist for the current specialization")
1306 .lowered_runtime
1307 .clone()
1308 }
1309 fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
1310 self.ir_planning.active_lowered_artifacts.read().clone()
1311 }
1312 fn lowered_runtime_slot_count(&self) -> usize {
1313 let runtime = self.lowered_runtime();
1314 [
1315 runtime.value_program().scratch_slots(),
1316 runtime.gradient_program().scratch_slots(),
1317 runtime.value_gradient_program().scratch_slots(),
1318 ]
1319 .into_iter()
1320 .max()
1321 .unwrap_or(0)
1322 }
1323 fn lowered_value_runtime_slot_count(&self) -> usize {
1324 self.lowered_runtime().value_program().scratch_slots()
1325 }
1326
1327 #[doc(hidden)]
1328 pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
1329 ExpressionValueProgramSnapshot {
1330 lowered_program: self.lowered_runtime().value_program().clone(),
1331 }
1332 }
1333
1334 #[doc(hidden)]
1335 pub fn expression_value_program_snapshot_for_active_mask(
1336 &self,
1337 active_mask: &[bool],
1338 ) -> LadduResult<ExpressionValueProgramSnapshot> {
1339 let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1340 let lowered_program =
1341 lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
1342 LadduError::Custom(format!(
1343 "Failed to lower value-only active-mask runtime: {error:?}"
1344 ))
1345 })?;
1346 Ok(ExpressionValueProgramSnapshot { lowered_program })
1347 }
1348
1349 #[doc(hidden)]
1350 pub fn expression_value_program_snapshot_slot_count(
1351 &self,
1352 snapshot: &ExpressionValueProgramSnapshot,
1353 ) -> usize {
1354 let _ = self;
1355 snapshot.lowered_program.scratch_slots()
1356 }
1357
1358 pub fn compiled_expression(&self) -> CompiledExpression {
1361 let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
1362 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
1363 }
1364
1365 pub fn expression(&self) -> Expression {
1367 Expression {
1368 tree: self.expression.clone(),
1369 registry: self.registry.clone(),
1370 }
1371 }
1372 fn lowered_gradient_runtime_slot_count(&self) -> usize {
1373 self.lowered_runtime().gradient_program().scratch_slots()
1374 }
1375 fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
1376 self.lowered_runtime()
1377 .value_gradient_program()
1378 .scratch_slots()
1379 }
1380
1381 fn expression_value_slot_count(&self) -> usize {
1382 self.lowered_value_runtime_slot_count()
1383 }
1384 fn expression_gradient_slot_count(&self) -> usize {
1385 self.lowered_gradient_runtime_slot_count()
1386 }
1387 fn expression_value_gradient_slot_count(&self) -> usize {
1388 self.lowered_value_gradient_runtime_slot_count()
1389 }
1390
1391 #[doc(hidden)]
1392 pub fn expression_value_gradient_slot_count_public(&self) -> usize {
1393 self.expression_value_gradient_slot_count()
1394 }
1395 #[cfg(test)]
1396 fn specialization_cache_len(&self) -> usize {
1397 self.ir_planning.specialization_cache.read().len()
1398 }
1399 #[cfg(test)]
1400 fn lowered_artifact_cache_len(&self) -> usize {
1401 self.ir_planning.lowered_artifact_cache.read().len()
1402 }
1403 fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
1404 debug_assert!(Self::lowered_artifact_signature_matches(
1405 &specialization.lowered_artifacts,
1406 &specialization.cached_integrals.values,
1407 ));
1408 *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
1409 *self.ir_planning.active_lowered_artifacts.write() =
1410 Some(specialization.lowered_artifacts.clone());
1411 debug_assert_eq!(
1412 self.active_lowered_artifacts()
1413 .as_ref()
1414 .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
1415 Some(true)
1416 );
1417 debug_assert_eq!(
1418 self.lowered_runtime().value_program().scratch_slots(),
1419 specialization
1420 .lowered_artifacts
1421 .lowered_runtime
1422 .value_program()
1423 .scratch_slots()
1424 );
1425 }
1426 fn lower_expression_runtime_artifacts(
1427 expression_ir: &ir::ExpressionIR,
1428 values: &[PrecomputedCachedIntegral],
1429 ) -> LadduResult<LoweredArtifactCacheState> {
1430 let parameter_node_indices = values
1431 .iter()
1432 .map(|value| value.parameter_node_index)
1433 .collect();
1434 let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
1435 let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
1436 let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
1437 let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
1438 expression_ir,
1439 )
1440 .map_err(|error| {
1441 LadduError::Custom(format!(
1442 "Failed to lower expression runtime for specialized IR: {error:?}"
1443 ))
1444 })?;
1445 Ok(LoweredArtifactCacheState {
1446 parameter_node_indices,
1447 mul_node_indices,
1448 lowered_parameter_factors,
1449 residual_runtime,
1450 lowered_runtime,
1451 })
1452 }
1453 fn lowered_artifact_signature_matches(
1454 artifacts: &LoweredArtifactCacheState,
1455 values: &[PrecomputedCachedIntegral],
1456 ) -> bool {
1457 artifacts.parameter_node_indices.len() == values.len()
1458 && artifacts.mul_node_indices.len() == values.len()
1459 && artifacts
1460 .parameter_node_indices
1461 .iter()
1462 .copied()
1463 .eq(values.iter().map(|value| value.parameter_node_index))
1464 && artifacts
1465 .mul_node_indices
1466 .iter()
1467 .copied()
1468 .eq(values.iter().map(|value| value.mul_node_index))
1469 }
1470 fn build_expression_specialization(
1471 &self,
1472 resources: &Resources,
1473 key: CachedIntegralCacheKey,
1474 ) -> LadduResult<ExpressionSpecializationState> {
1475 let ir_compile_start = Instant::now();
1476 let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
1477 let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
1478 let cached_integrals_start = Instant::now();
1479 let values = Self::precompute_cached_integrals_at_load(
1480 &expression_ir,
1481 &self.amplitudes,
1482 &self.amplitude_use_sites,
1483 resources,
1484 &self.dataset,
1485 self.resources.read().n_free_parameters(),
1486 )?;
1487 let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
1488 let execution_sets = expression_ir.normalization_execution_sets().clone();
1489 let active_mask_key = resources.active.clone();
1490 let cached_lowered_artifacts = {
1491 let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
1492 lowered_artifact_cache
1493 .get(&active_mask_key)
1494 .cloned()
1495 .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
1496 };
1497 let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
1498 self.ir_planning
1499 .compile_metrics
1500 .write()
1501 .specialization_lowering_cache_hits += 1;
1502 artifacts
1503 } else {
1504 let lowering_start = Instant::now();
1505 let artifacts = Arc::new(
1506 Self::lower_expression_runtime_artifacts(&expression_ir, &values)
1507 .expect("specialized lowered runtime should build"),
1508 );
1509 let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
1510 self.ir_planning
1511 .lowered_artifact_cache
1512 .write()
1513 .insert(active_mask_key, artifacts.clone());
1514 let mut compile_metrics = self.ir_planning.compile_metrics.write();
1515 compile_metrics.specialization_lowering_cache_misses += 1;
1516 compile_metrics.specialization_lowering_nanos += lowering_nanos;
1517 artifacts
1518 };
1519 let mut compile_metrics = self.ir_planning.compile_metrics.write();
1520 compile_metrics.specialization_cache_misses += 1;
1521 compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
1522 compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
1523 Ok(ExpressionSpecializationState {
1524 cached_integrals: Arc::new(CachedIntegralCacheState {
1525 key,
1526 expression_ir,
1527 values,
1528 execution_sets,
1529 }),
1530 lowered_artifacts,
1531 })
1532 }
    /// Ensures a specialization matching `resources` (active mask plus dataset identity) is
    /// installed, and returns it.
    ///
    /// Resolution order:
    /// 1. Fast path: the installed cached-integral state already has the same key, so it is
    ///    returned with the active lowered artifacts. Nothing is re-installed.
    /// 2. Cache hit: a previously built specialization is restored from the specialization
    ///    cache, installed, and recorded as `CacheHitRestore`.
    /// 3. Cache miss: a new specialization is built, inserted into the cache, and installed. It
    ///    is recorded as `InitialLoad` if it is the only cache entry, otherwise as
    ///    `CacheMissRebuild`.
    ///
    /// # Errors
    /// Propagates errors from `build_expression_specialization` on a cache miss.
    ///
    /// # Panics
    /// Panics on the fast path if cached integrals are installed without lowered artifacts.
    fn ensure_expression_specialization(
        &self,
        resources: &Resources,
    ) -> LadduResult<ExpressionSpecializationState> {
        let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
        // Fast path: the installed specialization already matches this key.
        if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
            if state.key == key {
                return Ok(ExpressionSpecializationState {
                    cached_integrals: state.clone(),
                    lowered_artifacts: self
                        .active_lowered_artifacts()
                        .expect("active lowered artifacts should exist for cached specialization"),
                });
            }
        }
        // Clone the cached entry out in its own scope so the cache read lock is released
        // before installing (which takes other write locks).
        let cached_specialization = {
            let specialization_cache = self.ir_planning.specialization_cache.read();
            specialization_cache.get(&key).cloned()
        };
        if let Some(specialization) = cached_specialization {
            let restore_start = Instant::now();
            self.ir_planning.specialization_metrics.write().cache_hits += 1;
            self.install_expression_specialization(&specialization);
            *self.ir_planning.specialization_status.write() =
                Some(ExpressionSpecializationStatus {
                    origin: ExpressionSpecializationOrigin::CacheHitRestore,
                });
            let restore_nanos = restore_start.elapsed().as_nanos() as u64;
            let mut compile_metrics = self.ir_planning.compile_metrics.write();
            compile_metrics.specialization_cache_hits += 1;
            compile_metrics.specialization_cache_restore_nanos += restore_nanos;
            return Ok(specialization);
        }
        // Cache miss: build, remember, and install a new specialization.
        let specialization = self.build_expression_specialization(resources, key.clone())?;
        self.ir_planning.specialization_metrics.write().cache_misses += 1;
        self.ir_planning
            .specialization_cache
            .write()
            .insert(key, specialization.clone());
        self.install_expression_specialization(&specialization);
        // A single cache entry right after insertion means this was the first build.
        let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
            ExpressionSpecializationOrigin::InitialLoad
        } else {
            ExpressionSpecializationOrigin::CacheMissRebuild
        };
        *self.ir_planning.specialization_status.write() =
            Some(ExpressionSpecializationStatus { origin });
        Ok(specialization)
    }
1582 fn rebuild_runtime_specializations(&self, resources: &Resources) {
1583 let _ = self.ensure_expression_specialization(resources);
1584 }
1585 fn refresh_runtime_specializations(&self) {
1586 let resources = self.resources.read();
1587 self.rebuild_runtime_specializations(&resources);
1588 }
1589 fn cached_integral_cache_key(
1590 active_mask: Vec<bool>,
1591 dataset: &Dataset,
1592 ) -> CachedIntegralCacheKey {
1593 let (weights_ptr, weights_local_len) = dataset.local_weight_cache_key();
1594 CachedIntegralCacheKey {
1595 active_mask,
1596 n_events_local: dataset.n_events_local(),
1597 weights_local_len,
1598 weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
1599 weights_ptr,
1600 }
1601 }
1602 fn precompute_cached_integrals_at_load(
1603 expression_ir: &ir::ExpressionIR,
1604 amplitudes: &[Box<dyn Amplitude>],
1605 amplitude_use_sites: &[AmplitudeUseSite],
1606 resources: &Resources,
1607 dataset: &Dataset,
1608 n_free_parameters: usize,
1609 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
1610 let descriptors = expression_ir.cached_integral_descriptors();
1611 if descriptors.is_empty() {
1612 return Ok(Vec::new());
1613 }
1614 let execution_sets = expression_ir.normalization_execution_sets();
1615 let seed_parameters = vec![0.0; n_free_parameters];
1616 let parameters = resources.parameter_map.assemble(&seed_parameters)?;
1617 let mut amplitude_values = vec![Complex64::ZERO; amplitude_use_sites.len()];
1618 let mut compute_values = vec![Complex64::ZERO; amplitudes.len()];
1619 let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
1620 let active_set = resources.active_indices();
1621 let cache_active_indices = execution_sets
1622 .cached_cache_amplitudes
1623 .iter()
1624 .copied()
1625 .filter(|index| active_set.binary_search(index).is_ok())
1626 .collect::<Vec<_>>();
1627 let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
1628 for (cache, event) in resources.caches.iter().zip(dataset.weights_local().iter()) {
1629 amplitude_values.fill(Complex64::ZERO);
1630 compute_values.fill(Complex64::ZERO);
1631 let mut computed = vec![false; amplitudes.len()];
1632 for &use_site_idx in &cache_active_indices {
1633 let amp_idx = amplitude_use_sites[use_site_idx].amplitude_index;
1634 if !computed[amp_idx] {
1635 compute_values[amp_idx] = amplitudes[amp_idx].compute(¶meters, cache);
1636 computed[amp_idx] = true;
1637 }
1638 amplitude_values[use_site_idx] = compute_values[amp_idx];
1639 }
1640 expression_ir.evaluate_into(&litude_values, &mut value_slots);
1641 let weight = *event;
1642 for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
1643 weighted_cache_sums[descriptor_index] +=
1644 value_slots[descriptor.cache_node_index] * weight;
1645 }
1646 }
1647 Ok(descriptors
1648 .iter()
1649 .zip(weighted_cache_sums)
1650 .map(
1651 |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
1652 mul_node_index: descriptor.mul_node_index,
1653 parameter_node_index: descriptor.parameter_node_index,
1654 cache_node_index: descriptor.cache_node_index,
1655 coefficient: descriptor.coefficient,
1656 weighted_cache_sum,
1657 },
1658 )
1659 .collect())
1660 }
1661 fn lower_cached_parameter_factors(
1662 expression_ir: &ir::ExpressionIR,
1663 ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
1664 expression_ir
1665 .cached_integral_descriptors()
1666 .iter()
1667 .map(|descriptor| {
1668 lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
1669 expression_ir,
1670 descriptor.parameter_node_index,
1671 )
1672 .ok()
1673 })
1674 .collect()
1675 }
1676 fn lower_residual_runtime(
1677 expression_ir: &ir::ExpressionIR,
1678 descriptors: &[PrecomputedCachedIntegral],
1679 ) -> Option<lowered::LoweredExpressionRuntime> {
1680 let mut zeroed_nodes = vec![false; expression_ir.node_count()];
1681 for descriptor in descriptors {
1682 if descriptor.mul_node_index < zeroed_nodes.len() {
1683 zeroed_nodes[descriptor.mul_node_index] = true;
1684 }
1685 }
1686 lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
1687 expression_ir,
1688 &zeroed_nodes,
1689 )
1690 .ok()
1691 }
1692
1693 #[inline]
1694 fn fill_amplitude_values(
1695 &self,
1696 amplitude_values: &mut [Complex64],
1697 active_indices: &[usize],
1698 parameters: &Parameters,
1699 cache: &Cache,
1700 ) {
1701 amplitude_values.fill(Complex64::ZERO);
1702 let mut compute_values = vec![Complex64::ZERO; self.amplitudes.len()];
1703 let mut computed = vec![false; self.amplitudes.len()];
1704 for &use_site_idx in active_indices {
1705 let amp_idx = self.amplitude_use_sites[use_site_idx].amplitude_index;
1706 if !computed[amp_idx] {
1707 compute_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
1708 computed[amp_idx] = true;
1709 }
1710 amplitude_values[use_site_idx] = compute_values[amp_idx];
1711 }
1712 }
1713
1714 #[inline]
1715 fn fill_amplitude_gradients(
1716 &self,
1717 gradient_values: &mut [DVector<Complex64>],
1718 active_mask: &[bool],
1719 parameters: &Parameters,
1720 cache: &Cache,
1721 ) {
1722 let mut compute_gradients = vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1723 let mut computed = vec![false; self.amplitudes.len()];
1724 for ((use_site, active), grad) in self
1725 .amplitude_use_sites
1726 .iter()
1727 .zip(active_mask.iter())
1728 .zip(gradient_values.iter_mut())
1729 {
1730 grad.fill(Complex64::ZERO);
1731 if *active {
1732 let amp_idx = use_site.amplitude_index;
1733 if !computed[amp_idx] {
1734 self.amplitudes[amp_idx].compute_gradient(
1735 parameters,
1736 cache,
1737 &mut compute_gradients[amp_idx],
1738 );
1739 computed[amp_idx] = true;
1740 }
1741 grad.copy_from(&compute_gradients[amp_idx]);
1742 }
1743 }
1744 }
1745
1746 #[inline]
1747 fn fill_amplitude_values_and_gradients(
1748 &self,
1749 amplitude_values: &mut [Complex64],
1750 gradient_values: &mut [DVector<Complex64>],
1751 active_indices: &[usize],
1752 active_mask: &[bool],
1753 parameters: &Parameters,
1754 cache: &Cache,
1755 ) {
1756 self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
1757 self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
1758 }
1759
1760 #[doc(hidden)]
1761 pub fn fill_amplitude_values_and_gradients_public(
1762 &self,
1763 amplitude_values: &mut [Complex64],
1764 gradient_values: &mut [DVector<Complex64>],
1765 active_indices: &[usize],
1766 active_mask: &[bool],
1767 parameters: &Parameters,
1768 cache: &Cache,
1769 ) {
1770 self.fill_amplitude_values_and_gradients(
1771 amplitude_values,
1772 gradient_values,
1773 active_indices,
1774 active_mask,
1775 parameters,
1776 cache,
1777 );
1778 }
1779
1780 #[cfg(feature = "execution-context-prototype")]
1781 #[inline]
1782 fn evaluate_cache_gradient_with_scratch(
1783 &self,
1784 amplitude_values: &mut [Complex64],
1785 gradient_values: &mut [DVector<Complex64>],
1786 value_slots: &mut [Complex64],
1787 gradient_slots: &mut [DVector<Complex64>],
1788 active_indices: &[usize],
1789 active_mask: &[bool],
1790 parameters: &Parameters,
1791 cache: &Cache,
1792 ) -> DVector<Complex64> {
1793 self.fill_amplitude_values_and_gradients(
1794 amplitude_values,
1795 gradient_values,
1796 active_indices,
1797 active_mask,
1798 parameters,
1799 cache,
1800 );
1801 self.evaluate_expression_gradient_with_scratch(
1802 amplitude_values,
1803 gradient_values,
1804 value_slots,
1805 gradient_slots,
1806 )
1807 }
1808
1809 #[cfg(feature = "execution-context-prototype")]
1810 #[allow(dead_code)]
1811 #[inline]
1812 fn evaluate_cache_value_gradient_with_scratch(
1813 &self,
1814 amplitude_values: &mut [Complex64],
1815 gradient_values: &mut [DVector<Complex64>],
1816 value_slots: &mut [Complex64],
1817 gradient_slots: &mut [DVector<Complex64>],
1818 active_indices: &[usize],
1819 active_mask: &[bool],
1820 parameters: &Parameters,
1821 cache: &Cache,
1822 ) -> (Complex64, DVector<Complex64>) {
1823 self.fill_amplitude_values_and_gradients(
1824 amplitude_values,
1825 gradient_values,
1826 active_indices,
1827 active_mask,
1828 parameters,
1829 cache,
1830 );
1831 self.evaluate_expression_value_gradient_with_scratch(
1832 amplitude_values,
1833 gradient_values,
1834 value_slots,
1835 gradient_slots,
1836 )
1837 }
1838
1839 pub fn expression_slot_count(&self) -> usize {
1840 self.lowered_runtime_slot_count()
1841 }
1842 fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
1843 let amplitude_dependencies = self
1844 .amplitude_use_sites
1845 .iter()
1846 .map(|use_site| {
1847 ir::DependenceClass::from(
1848 self.amplitudes[use_site.amplitude_index].dependence_hint(),
1849 )
1850 })
1851 .collect::<Vec<_>>();
1852 let amplitude_realness = self
1853 .amplitude_use_sites
1854 .iter()
1855 .map(|use_site| self.amplitudes[use_site.amplitude_index].real_valued_hint())
1856 .collect::<Vec<_>>();
1857 ir::compile_expression_ir_with_real_hints(
1858 &self.expression,
1859 active_mask,
1860 &litude_dependencies,
1861 &litude_realness,
1862 )
1863 }
1864 fn lower_expression_runtime_for_active_mask(
1865 &self,
1866 active_mask: &[bool],
1867 ) -> LadduResult<lowered::LoweredExpressionRuntime> {
1868 let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1869 lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
1870 LadduError::Custom(format!(
1871 "Failed to lower active-mask runtime specialization: {error:?}"
1872 ))
1873 })
1874 }
1875 fn ensure_cached_integral_cache_state(
1876 &self,
1877 resources: &Resources,
1878 ) -> LadduResult<Arc<CachedIntegralCacheState>> {
1879 Ok(self
1880 .ensure_expression_specialization(resources)?
1881 .cached_integrals)
1882 }
1883
1884 fn evaluate_expression_runtime_value_with_scratch(
1885 &self,
1886 amplitude_values: &[Complex64],
1887 scratch: &mut [Complex64],
1888 ) -> Complex64 {
1889 let lowered_runtime = self.lowered_runtime();
1890 lowered_runtime
1891 .value_program()
1892 .evaluate_into(amplitude_values, scratch)
1893 }
1894
1895 #[doc(hidden)]
1896 pub fn evaluate_expression_value_with_program_snapshot(
1897 &self,
1898 program_snapshot: &ExpressionValueProgramSnapshot,
1899 amplitude_values: &[Complex64],
1900 scratch: &mut [Complex64],
1901 ) -> Complex64 {
1902 program_snapshot
1903 .lowered_program
1904 .evaluate_into(amplitude_values, scratch)
1905 }
1906
1907 fn evaluate_expression_runtime_gradient_with_scratch(
1908 &self,
1909 amplitude_values: &[Complex64],
1910 gradient_values: &[DVector<Complex64>],
1911 value_scratch: &mut [Complex64],
1912 gradient_scratch: &mut [DVector<Complex64>],
1913 ) -> DVector<Complex64> {
1914 let lowered_runtime = self.lowered_runtime();
1915 lowered_runtime.gradient_program().evaluate_gradient_into(
1916 amplitude_values,
1917 gradient_values,
1918 value_scratch,
1919 gradient_scratch,
1920 )
1921 }
1922
1923 fn evaluate_expression_runtime_value_gradient_with_scratch(
1924 &self,
1925 amplitude_values: &[Complex64],
1926 gradient_values: &[DVector<Complex64>],
1927 value_scratch: &mut [Complex64],
1928 gradient_scratch: &mut [DVector<Complex64>],
1929 ) -> (Complex64, DVector<Complex64>) {
1930 let lowered_runtime = self.lowered_runtime();
1931 lowered_runtime
1932 .value_gradient_program()
1933 .evaluate_value_gradient_into(
1934 amplitude_values,
1935 gradient_values,
1936 value_scratch,
1937 gradient_scratch,
1938 )
1939 }
1940
1941 fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
1942 let lowered_runtime = self.lowered_runtime();
1943 let program = lowered_runtime.value_program();
1944 let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
1945 program.evaluate_into(amplitude_values, &mut scratch)
1946 }
1947
1948 fn evaluate_expression_runtime_gradient(
1949 &self,
1950 amplitude_values: &[Complex64],
1951 gradient_values: &[DVector<Complex64>],
1952 ) -> DVector<Complex64> {
1953 let lowered_runtime = self.lowered_runtime();
1954 let program = lowered_runtime.gradient_program();
1955 let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
1956 let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
1957 let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
1958 program.evaluate_gradient_into_flat(
1959 amplitude_values,
1960 gradient_values,
1961 &mut value_scratch,
1962 &mut gradient_scratch,
1963 grad_dim,
1964 )
1965 }
1966 pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
1968 let resources = self.resources.read();
1969 Ok(self
1970 .ensure_cached_integral_cache_state(&resources)?
1971 .expression_ir
1972 .root_dependence()
1973 .into())
1974 }
1975 pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
1977 let resources = self.resources.read();
1978 Ok(self
1979 .ensure_cached_integral_cache_state(&resources)?
1980 .expression_ir
1981 .node_dependence_annotations()
1982 .iter()
1983 .copied()
1984 .map(Into::into)
1985 .collect())
1986 }
1987 pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
1989 let resources = self.resources.read();
1990 Ok(self
1991 .ensure_cached_integral_cache_state(&resources)?
1992 .expression_ir
1993 .dependence_warnings()
1994 .to_vec())
1995 }
1996 pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
1998 let resources = self.resources.read();
1999 Ok(self
2000 .ensure_cached_integral_cache_state(&resources)?
2001 .expression_ir
2002 .normalization_plan_explain()
2003 .into())
2004 }
2005 pub fn expression_normalization_execution_sets(
2007 &self,
2008 ) -> LadduResult<NormalizationExecutionSetsExplain> {
2009 let resources = self.resources.read();
2010 Ok(self
2011 .ensure_cached_integral_cache_state(&resources)?
2012 .execution_sets
2013 .clone()
2014 .into())
2015 }
2016 pub fn expression_precomputed_cached_integrals(
2018 &self,
2019 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2020 let resources = self.resources.read();
2021 Ok(self
2022 .ensure_cached_integral_cache_state(&resources)?
2023 .values
2024 .clone())
2025 }
2026 pub fn expression_precomputed_cached_integral_gradient_terms(
2031 &self,
2032 parameters: &[f64],
2033 ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
2034 let resources = self.resources.read();
2035 let state = self.ensure_cached_integral_cache_state(&resources)?;
2036 if state.values.is_empty() {
2037 return Ok(Vec::new());
2038 }
2039
2040 let Some(cache) = resources.caches.first() else {
2041 return Ok(state
2042 .values
2043 .iter()
2044 .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
2045 mul_node_index: descriptor.mul_node_index,
2046 parameter_node_index: descriptor.parameter_node_index,
2047 cache_node_index: descriptor.cache_node_index,
2048 coefficient: descriptor.coefficient,
2049 weighted_gradient: DVector::zeros(parameters.len()),
2050 })
2051 .collect());
2052 };
2053
2054 let parameter_values = resources.parameter_map.assemble(parameters)?;
2055 let mut amplitude_values = vec![Complex64::ZERO; self.amplitude_use_sites.len()];
2056 self.fill_amplitude_values(
2057 &mut amplitude_values,
2058 resources.active_indices(),
2059 ¶meter_values,
2060 cache,
2061 );
2062 let mut amplitude_gradients = (0..self.amplitude_use_sites.len())
2063 .map(|_| DVector::zeros(parameters.len()))
2064 .collect::<Vec<_>>();
2065 self.fill_amplitude_gradients(
2066 &mut amplitude_gradients,
2067 &resources.active,
2068 ¶meter_values,
2069 cache,
2070 );
2071 let lowered_artifacts = self.active_lowered_artifacts();
2072 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2073 let mut gradient_slots = (0..state.expression_ir.node_count())
2074 .map(|_| DVector::zeros(parameters.len()))
2075 .collect::<Vec<_>>();
2076 let max_lowered_slots = lowered_artifacts
2077 .as_ref()
2078 .map(|artifacts| {
2079 artifacts
2080 .lowered_parameter_factors
2081 .iter()
2082 .filter_map(|runtime| {
2083 runtime
2084 .as_ref()
2085 .and_then(|runtime| runtime.gradient_program())
2086 .map(|program| program.scratch_slots())
2087 })
2088 .max()
2089 .unwrap_or(0)
2090 })
2091 .unwrap_or(0);
2092 let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
2093 let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
2094 let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
2095 artifacts.lowered_parameter_factors.len() == state.values.len()
2096 && artifacts.lowered_parameter_factors.iter().all(|runtime| {
2097 runtime
2098 .as_ref()
2099 .and_then(|runtime| runtime.gradient_program())
2100 .is_some()
2101 })
2102 });
2103
2104 if !use_lowered {
2105 let _ = state.expression_ir.evaluate_gradient_into(
2106 &litude_values,
2107 &litude_gradients,
2108 &mut value_slots,
2109 &mut gradient_slots,
2110 );
2111 }
2112
2113 if use_lowered {
2114 let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
2115 Ok(state
2116 .values
2117 .iter()
2118 .cloned()
2119 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2120 .map(|(descriptor, runtime)| {
2121 let parameter_gradient = runtime
2122 .as_ref()
2123 .and_then(|runtime| runtime.gradient_program())
2124 .map(|program| {
2125 program.evaluate_gradient_into(
2126 &litude_values,
2127 &litude_gradients,
2128 &mut lowered_value_slots[..program.scratch_slots()],
2129 &mut lowered_gradient_slots[..program.scratch_slots()],
2130 )
2131 })
2132 .unwrap_or_else(|| DVector::zeros(parameters.len()));
2133 let weighted_gradient = parameter_gradient.map(|value| {
2134 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2135 });
2136 PrecomputedCachedIntegralGradientTerm {
2137 mul_node_index: descriptor.mul_node_index,
2138 parameter_node_index: descriptor.parameter_node_index,
2139 cache_node_index: descriptor.cache_node_index,
2140 coefficient: descriptor.coefficient,
2141 weighted_gradient,
2142 }
2143 })
2144 .collect())
2145 } else {
2146 Ok(state
2147 .values
2148 .iter()
2149 .map(|descriptor| {
2150 let parameter_gradient = gradient_slots
2151 .get(descriptor.parameter_node_index)
2152 .cloned()
2153 .unwrap_or_else(|| DVector::zeros(parameters.len()));
2154 let weighted_gradient = parameter_gradient.map(|value| {
2155 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2156 });
2157 PrecomputedCachedIntegralGradientTerm {
2158 mul_node_index: descriptor.mul_node_index,
2159 parameter_node_index: descriptor.parameter_node_index,
2160 cache_node_index: descriptor.cache_node_index,
2161 coefficient: descriptor.coefficient,
2162 weighted_gradient,
2163 }
2164 })
2165 .collect())
2166 }
2167 }
2168 fn evaluate_cached_weighted_value_sum_ir(
2169 &self,
2170 state: &CachedIntegralCacheState,
2171 amplitude_values: &[Complex64],
2172 ) -> f64 {
2173 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2174 let _ = state
2175 .expression_ir
2176 .evaluate_into(amplitude_values, &mut value_slots);
2177 state
2178 .values
2179 .iter()
2180 .map(|descriptor| {
2181 let parameter_factor = value_slots[descriptor.parameter_node_index];
2182 (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2183 .re
2184 })
2185 .sum()
2186 }
2187 fn evaluate_cached_weighted_value_sum_lowered(
2188 &self,
2189 state: &CachedIntegralCacheState,
2190 lowered_artifacts: &LoweredArtifactCacheState,
2191 amplitude_values: &[Complex64],
2192 ) -> Option<f64> {
2193 let max_slots = lowered_artifacts
2194 .lowered_parameter_factors
2195 .iter()
2196 .filter_map(|runtime| {
2197 runtime
2198 .as_ref()
2199 .and_then(|runtime| runtime.value_program())
2200 .map(|program| program.scratch_slots())
2201 })
2202 .max()
2203 .unwrap_or(0);
2204 let mut value_slots = vec![Complex64::ZERO; max_slots];
2205 let mut total = 0.0;
2206 for (descriptor, runtime) in state
2207 .values
2208 .iter()
2209 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2210 {
2211 let parameter_factor = runtime
2212 .as_ref()
2213 .and_then(|runtime| runtime.value_program())
2214 .map(|program| {
2215 program.evaluate_into(
2216 amplitude_values,
2217 &mut value_slots[..program.scratch_slots()],
2218 )
2219 })?;
2220 total +=
2221 (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2222 .re;
2223 }
2224 Some(total)
2225 }
2226 fn evaluate_cached_weighted_gradient_sum_ir(
2227 &self,
2228 state: &CachedIntegralCacheState,
2229 amplitude_values: &[Complex64],
2230 amplitude_gradients: &[DVector<Complex64>],
2231 grad_dim: usize,
2232 ) -> DVector<f64> {
2233 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2234 let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2235 let _ = state.expression_ir.evaluate_gradient_into(
2236 amplitude_values,
2237 amplitude_gradients,
2238 &mut value_slots,
2239 &mut gradient_slots,
2240 );
2241 state
2242 .values
2243 .iter()
2244 .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
2245 let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
2246 let coefficient = descriptor.coefficient as f64;
2247 for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
2248 *accum_item +=
2249 (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2250 }
2251 accum
2252 })
2253 }
2254 fn evaluate_cached_weighted_gradient_sum_lowered(
2255 &self,
2256 state: &CachedIntegralCacheState,
2257 lowered_artifacts: &LoweredArtifactCacheState,
2258 amplitude_values: &[Complex64],
2259 amplitude_gradients: &[DVector<Complex64>],
2260 grad_dim: usize,
2261 ) -> Option<DVector<f64>> {
2262 let max_value_slots = lowered_artifacts
2263 .lowered_parameter_factors
2264 .iter()
2265 .filter_map(|runtime| {
2266 runtime
2267 .as_ref()
2268 .and_then(|runtime| runtime.gradient_program())
2269 .map(|program| program.scratch_slots())
2270 })
2271 .max()
2272 .unwrap_or(0);
2273 let mut value_slots = vec![Complex64::ZERO; max_value_slots];
2274 let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
2275 let mut total = DVector::zeros(grad_dim);
2276 for (descriptor, runtime) in state
2277 .values
2278 .iter()
2279 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2280 {
2281 let parameter_gradient = runtime
2282 .as_ref()
2283 .and_then(|runtime| runtime.gradient_program())
2284 .map(|program| {
2285 program.evaluate_gradient_into_flat(
2286 amplitude_values,
2287 amplitude_gradients,
2288 &mut value_slots[..program.scratch_slots()],
2289 &mut gradient_slots[..program.scratch_slots() * grad_dim],
2290 grad_dim,
2291 )
2292 })?;
2293 let coefficient = descriptor.coefficient as f64;
2294 for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
2295 *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2296 }
2297 }
2298 Some(total)
2299 }
2300 fn evaluate_residual_value_ir(
2301 &self,
2302 state: &CachedIntegralCacheState,
2303 amplitude_values: &[Complex64],
2304 ) -> Complex64 {
2305 let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2306 for descriptor in &state.values {
2307 if descriptor.mul_node_index < zeroed_nodes.len() {
2308 zeroed_nodes[descriptor.mul_node_index] = true;
2309 }
2310 }
2311 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2312 state.expression_ir.evaluate_into_with_zeroed_nodes(
2313 amplitude_values,
2314 &mut value_slots,
2315 &zeroed_nodes,
2316 )
2317 }
2318 fn evaluate_residual_gradient_ir(
2319 &self,
2320 state: &CachedIntegralCacheState,
2321 amplitude_values: &[Complex64],
2322 amplitude_gradients: &[DVector<Complex64>],
2323 grad_dim: usize,
2324 ) -> DVector<Complex64> {
2325 let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2326 for descriptor in &state.values {
2327 if descriptor.mul_node_index < zeroed_nodes.len() {
2328 zeroed_nodes[descriptor.mul_node_index] = true;
2329 }
2330 }
2331 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2332 let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2333 state
2334 .expression_ir
2335 .evaluate_gradient_into_with_zeroed_nodes(
2336 amplitude_values,
2337 amplitude_gradients,
2338 &mut value_slots,
2339 &mut gradient_slots,
2340 &zeroed_nodes,
2341 )
2342 }
2343
2344 fn evaluate_weighted_value_sum_local_components(
2345 &self,
2346 parameters: &[f64],
2347 ) -> LadduResult<(f64, f64)> {
2348 let resources = self.resources.read();
2349 let parameters = resources.parameter_map.assemble(parameters)?;
2350 let amplitude_len = self.amplitude_use_sites.len();
2351 let state = self.ensure_cached_integral_cache_state(&resources)?;
2352 let lowered_artifacts = self.active_lowered_artifacts();
2353 let residual_value_slot_count = lowered_artifacts
2354 .as_ref()
2355 .and_then(|artifacts| {
2356 artifacts
2357 .residual_runtime
2358 .as_ref()
2359 .map(|runtime| runtime.value_program())
2360 .map(|program| program.scratch_slots())
2361 })
2362 .unwrap_or_else(|| self.expression_slot_count());
2363 let residual_value_program = lowered_artifacts
2364 .as_ref()
2365 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2366 .map(|runtime| runtime.value_program());
2367 let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
2368 let residual_active_indices = &state.execution_sets.residual_amplitudes;
2369 debug_assert!(cached_parameter_indices.iter().all(|&index| resources
2370 .active
2371 .get(index)
2372 .copied()
2373 .unwrap_or(false)));
2374 debug_assert!(residual_active_indices.iter().all(|&index| resources
2375 .active
2376 .get(index)
2377 .copied()
2378 .unwrap_or(false)));
2379 let cached_value_sum = {
2380 if let Some(cache) = resources.caches.first() {
2381 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2382 self.fill_amplitude_values(
2383 &mut amplitude_values,
2384 cached_parameter_indices,
2385 ¶meters,
2386 cache,
2387 );
2388 lowered_artifacts
2389 .as_ref()
2390 .and_then(|artifacts| {
2391 self.evaluate_cached_weighted_value_sum_lowered(
2392 &state,
2393 artifacts,
2394 &litude_values,
2395 )
2396 })
2397 .unwrap_or_else(|| {
2398 self.evaluate_cached_weighted_value_sum_ir(&state, &litude_values)
2399 })
2400 } else {
2401 0.0
2402 }
2403 };
2404
2405 #[cfg(feature = "rayon")]
2406 let residual_sum: f64 = {
2407 resources
2408 .caches
2409 .par_iter()
2410 .zip(self.dataset.weights_local().par_iter())
2411 .map_init(
2412 || {
2413 (
2414 vec![Complex64::ZERO; amplitude_len],
2415 vec![Complex64::ZERO; residual_value_slot_count],
2416 )
2417 },
2418 |(amplitude_values, value_slots), (cache, event)| {
2419 self.fill_amplitude_values(
2420 amplitude_values,
2421 residual_active_indices,
2422 ¶meters,
2423 cache,
2424 );
2425 {
2426 let value = residual_value_program
2427 .as_ref()
2428 .map(|program| {
2429 program.evaluate_into(
2430 amplitude_values,
2431 &mut value_slots[..program.scratch_slots()],
2432 )
2433 })
2434 .unwrap_or_else(|| {
2435 self.evaluate_residual_value_ir(&state, amplitude_values)
2436 });
2437 *event * value.re
2438 }
2439 },
2440 )
2441 .sum()
2442 };
2443
2444 #[cfg(not(feature = "rayon"))]
2445 let residual_sum: f64 = {
2446 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2447 let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
2448 resources
2449 .caches
2450 .iter()
2451 .zip(self.dataset.weights_local().iter())
2452 .map(|(cache, event)| {
2453 self.fill_amplitude_values(
2454 &mut amplitude_values,
2455 &residual_active_indices,
2456 ¶meters,
2457 cache,
2458 );
2459 {
2460 let value = residual_value_program
2461 .as_ref()
2462 .map(|program| {
2463 program.evaluate_into(
2464 &litude_values,
2465 &mut value_slots[..program.scratch_slots()],
2466 )
2467 })
2468 .unwrap_or_else(|| {
2469 self.evaluate_residual_value_ir(&state, &litude_values)
2470 });
2471 *event * value.re
2472 }
2473 })
2474 .sum()
2475 };
2476 Ok((residual_sum, cached_value_sum))
2477 }
2478
2479 pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
2483 let (residual_sum, cached_value_sum) =
2484 self.evaluate_weighted_value_sum_local_components(parameters)?;
2485 Ok(residual_sum + cached_value_sum)
2486 }
2487
/// MPI variant of `evaluate_weighted_value_sum_local`.
///
/// Each component is summed across `world` with a separate all-reduce, residual first and
/// cached second. The two reduced values are then added together.
///
/// # Errors
/// Propagates any error from computing the local components.
#[cfg(feature = "mpi")]
pub fn evaluate_weighted_value_sum_mpi(
    &self,
    parameters: &[f64],
    world: &SimpleCommunicator,
) -> LadduResult<f64> {
    let (local_residual, local_cached) =
        self.evaluate_weighted_value_sum_local_components(parameters)?;
    // Collective calls must be issued in the same order on every rank.
    let all_reduce_sum = |local: &f64| {
        let mut global: f64 = 0.0;
        world.all_reduce_into(
            local,
            &mut global,
            mpi::collective::SystemOperation::sum(),
        );
        global
    };
    let residual_sum = all_reduce_sum(&local_residual);
    let cached_value_sum = all_reduce_sum(&local_cached);
    Ok(residual_sum + cached_value_sum)
}
2513
2514 fn evaluate_weighted_gradient_sum_local_components(
2518 &self,
2519 parameters: &[f64],
2520 ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
2521 let resources = self.resources.read();
2522 let parameters = resources.parameter_map.assemble(parameters)?;
2523 let amplitude_len = self.amplitude_use_sites.len();
2524 let grad_dim = parameters.len();
2525 let state = self.ensure_cached_integral_cache_state(&resources)?;
2526 let lowered_artifacts = self.active_lowered_artifacts();
2527 let active_index_set = resources.active_indices();
2528 let cached_parameter_indices = state
2529 .execution_sets
2530 .cached_parameter_amplitudes
2531 .iter()
2532 .copied()
2533 .filter(|index| active_index_set.binary_search(index).is_ok())
2534 .collect::<Vec<_>>();
2535 let residual_active_indices = state
2536 .execution_sets
2537 .residual_amplitudes
2538 .iter()
2539 .copied()
2540 .filter(|index| active_index_set.binary_search(index).is_ok())
2541 .collect::<Vec<_>>();
2542 let mut cached_parameter_mask = vec![false; amplitude_len];
2543 for &index in &cached_parameter_indices {
2544 cached_parameter_mask[index] = true;
2545 }
2546 let mut residual_active_mask = vec![false; amplitude_len];
2547 for &index in &residual_active_indices {
2548 residual_active_mask[index] = true;
2549 }
2550 let residual_gradient_program = lowered_artifacts
2551 .as_ref()
2552 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2553 .map(|runtime| runtime.gradient_program());
2554 let residual_gradient_slot_count = residual_gradient_program
2555 .as_ref()
2556 .map(|program| program.scratch_slots())
2557 .unwrap_or_else(|| state.expression_ir.node_count());
2558 let cached_term_sum = {
2559 if let Some(cache) = resources.caches.first() {
2560 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2561 self.fill_amplitude_values(
2562 &mut amplitude_values,
2563 &cached_parameter_indices,
2564 ¶meters,
2565 cache,
2566 );
2567 let mut amplitude_gradients = (0..amplitude_len)
2568 .map(|_| DVector::zeros(grad_dim))
2569 .collect::<Vec<_>>();
2570 self.fill_amplitude_gradients(
2571 &mut amplitude_gradients,
2572 &cached_parameter_mask,
2573 ¶meters,
2574 cache,
2575 );
2576 lowered_artifacts
2577 .as_ref()
2578 .and_then(|artifacts| {
2579 self.evaluate_cached_weighted_gradient_sum_lowered(
2580 &state,
2581 artifacts,
2582 &litude_values,
2583 &litude_gradients,
2584 grad_dim,
2585 )
2586 })
2587 .unwrap_or_else(|| {
2588 self.evaluate_cached_weighted_gradient_sum_ir(
2589 &state,
2590 &litude_values,
2591 &litude_gradients,
2592 grad_dim,
2593 )
2594 })
2595 } else {
2596 DVector::zeros(grad_dim)
2597 }
2598 };
2599
2600 #[cfg(feature = "rayon")]
2601 let residual_sum = {
2602 resources
2603 .caches
2604 .par_iter()
2605 .zip(self.dataset.weights_local().par_iter())
2606 .map_init(
2607 || {
2608 (
2609 vec![Complex64::ZERO; amplitude_len],
2610 vec![DVector::zeros(grad_dim); amplitude_len],
2611 vec![Complex64::ZERO; residual_gradient_slot_count],
2612 vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
2613 )
2614 },
2615 |(amplitude_values, gradient_values, value_slots, gradient_slots),
2616 (cache, event)| {
2617 self.fill_amplitude_values_and_gradients(
2618 amplitude_values,
2619 gradient_values,
2620 &residual_active_indices,
2621 &residual_active_mask,
2622 ¶meters,
2623 cache,
2624 );
2625 let gradient = residual_gradient_program
2626 .as_ref()
2627 .map(|program| {
2628 program.evaluate_gradient_into_flat(
2629 amplitude_values,
2630 gradient_values,
2631 value_slots,
2632 gradient_slots,
2633 grad_dim,
2634 )
2635 })
2636 .unwrap_or_else(|| {
2637 self.evaluate_residual_gradient_ir(
2638 &state,
2639 amplitude_values,
2640 gradient_values,
2641 grad_dim,
2642 )
2643 });
2644 gradient.map(|value| value.re).scale(*event)
2645 },
2646 )
2647 .reduce(
2648 || DVector::zeros(grad_dim),
2649 |mut accum, value| {
2650 accum += value;
2651 accum
2652 },
2653 )
2654 };
2655
2656 #[cfg(not(feature = "rayon"))]
2657 let residual_sum = {
2658 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2659 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
2660 let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
2661 let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
2662 resources
2663 .caches
2664 .iter()
2665 .zip(self.dataset.weights_local().iter())
2666 .map(|(cache, event)| {
2667 self.fill_amplitude_values_and_gradients(
2668 &mut amplitude_values,
2669 &mut gradient_values,
2670 &residual_active_indices,
2671 &residual_active_mask,
2672 ¶meters,
2673 cache,
2674 );
2675 let gradient = residual_gradient_program
2676 .as_ref()
2677 .map(|program| {
2678 program.evaluate_gradient_into_flat(
2679 &litude_values,
2680 &gradient_values,
2681 &mut value_slots,
2682 &mut gradient_slots,
2683 grad_dim,
2684 )
2685 })
2686 .unwrap_or_else(|| {
2687 self.evaluate_residual_gradient_ir(
2688 &state,
2689 &litude_values,
2690 &gradient_values,
2691 grad_dim,
2692 )
2693 });
2694 gradient.map(|value| value.re).scale(*event)
2695 })
2696 .sum()
2697 };
2698 Ok((residual_sum, cached_term_sum))
2699 }
2700
2701 pub fn evaluate_weighted_gradient_sum_local(
2705 &self,
2706 parameters: &[f64],
2707 ) -> LadduResult<DVector<f64>> {
2708 let (residual_sum, cached_term_sum) =
2709 self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2710 Ok(residual_sum + cached_term_sum)
2711 }
2712
/// MPI variant of `evaluate_weighted_gradient_sum_local`.
///
/// Each component vector is summed element-wise across `world` with a separate
/// all-reduce, residual first and cached second. The reduced vectors are then added.
///
/// # Errors
/// Propagates any error from computing the local components.
#[cfg(feature = "mpi")]
pub fn evaluate_weighted_gradient_sum_mpi(
    &self,
    parameters: &[f64],
    world: &SimpleCommunicator,
) -> LadduResult<DVector<f64>> {
    let (local_residual, local_cached) =
        self.evaluate_weighted_gradient_sum_local_components(parameters)?;
    // Collective calls must be issued in the same order on every rank.
    let all_reduce_sum = |local: &DVector<f64>| -> DVector<f64> {
        let mut global = vec![0.0_f64; local.len()];
        world.all_reduce_into(
            local.as_slice(),
            &mut global,
            mpi::collective::SystemOperation::sum(),
        );
        DVector::from_vec(global)
    };
    let mut total = all_reduce_sum(&local_residual);
    total += all_reduce_sum(&local_cached);
    Ok(total)
}
2740
2741 pub fn evaluate_expression_value_with_scratch(
2742 &self,
2743 amplitude_values: &[Complex64],
2744 scratch: &mut [Complex64],
2745 ) -> Complex64 {
2746 self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
2747 }
2748
2749 pub fn evaluate_expression_gradient_with_scratch(
2750 &self,
2751 amplitude_values: &[Complex64],
2752 gradient_values: &[DVector<Complex64>],
2753 value_scratch: &mut [Complex64],
2754 gradient_scratch: &mut [DVector<Complex64>],
2755 ) -> DVector<Complex64> {
2756 self.evaluate_expression_runtime_gradient_with_scratch(
2757 amplitude_values,
2758 gradient_values,
2759 value_scratch,
2760 gradient_scratch,
2761 )
2762 }
2763
2764 pub fn evaluate_expression_value_gradient_with_scratch(
2765 &self,
2766 amplitude_values: &[Complex64],
2767 gradient_values: &[DVector<Complex64>],
2768 value_scratch: &mut [Complex64],
2769 gradient_scratch: &mut [DVector<Complex64>],
2770 ) -> (Complex64, DVector<Complex64>) {
2771 self.evaluate_expression_runtime_value_gradient_with_scratch(
2772 amplitude_values,
2773 gradient_values,
2774 value_scratch,
2775 gradient_scratch,
2776 )
2777 }
2778
2779 pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
2780 self.evaluate_expression_runtime_value(amplitude_values)
2781 }
2782
2783 pub fn evaluate_expression_gradient(
2784 &self,
2785 amplitude_values: &[Complex64],
2786 gradient_values: &[DVector<Complex64>],
2787 ) -> DVector<Complex64> {
2788 self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
2789 }
2790
2791 pub fn parameters(&self) -> ParameterMap {
2793 self.resources.read().parameters()
2794 }
2795
2796 pub fn n_free(&self) -> usize {
2798 self.resources.read().n_free_parameters()
2799 }
2800
2801 pub fn n_fixed(&self) -> usize {
2803 self.resources.read().n_fixed_parameters()
2804 }
2805
2806 pub fn n_parameters(&self) -> usize {
2808 self.resources.read().n_parameters()
2809 }
2810
2811 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
2812 self.resources.read().fix_parameter(name, value)
2813 }
2814
2815 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
2816 self.resources.read().free_parameter(name)
2817 }
2818
2819 pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
2820 self.resources.write().rename_parameter(old, new)
2821 }
2822
2823 pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
2824 self.resources.write().rename_parameters(mapping)
2825 }
2826
2827 pub fn activate<T: AsRef<str>>(&self, name: T) {
2829 self.resources.write().activate(name);
2830 self.refresh_runtime_specializations();
2831 }
2832 pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2834 self.resources.write().activate_strict(name)?;
2835 self.refresh_runtime_specializations();
2836 Ok(())
2837 }
2838
2839 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
2841 self.resources.write().activate_many(names);
2842 self.refresh_runtime_specializations();
2843 }
2844 pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2846 self.resources.write().activate_many_strict(names)?;
2847 self.refresh_runtime_specializations();
2848 Ok(())
2849 }
2850
2851 pub fn activate_all(&self) {
2853 self.resources.write().activate_all();
2854 self.refresh_runtime_specializations();
2855 }
2856
2857 pub fn deactivate<T: AsRef<str>>(&self, name: T) {
2859 self.resources.write().deactivate(name);
2860 self.refresh_runtime_specializations();
2861 }
2862
2863 pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2865 self.resources.write().deactivate_strict(name)?;
2866 self.refresh_runtime_specializations();
2867 Ok(())
2868 }
2869
2870 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
2872 self.resources.write().deactivate_many(names);
2873 self.refresh_runtime_specializations();
2874 }
2875 pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2877 self.resources.write().deactivate_many_strict(names)?;
2878 self.refresh_runtime_specializations();
2879 Ok(())
2880 }
2881
2882 pub fn deactivate_all(&self) {
2884 self.resources.write().deactivate_all();
2885 self.refresh_runtime_specializations();
2886 }
2887
2888 pub fn isolate<T: AsRef<str>>(&self, name: T) {
2890 self.resources.write().isolate(name);
2891 self.refresh_runtime_specializations();
2892 }
2893
2894 pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2896 self.resources.write().isolate_strict(name)?;
2897 self.refresh_runtime_specializations();
2898 Ok(())
2899 }
2900
2901 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
2903 self.resources.write().isolate_many(names);
2904 self.refresh_runtime_specializations();
2905 }
2906
2907 pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2909 self.resources.write().isolate_many_strict(names)?;
2910 self.refresh_runtime_specializations();
2911 Ok(())
2912 }
2913
2914 pub fn active_mask(&self) -> Vec<bool> {
2916 self.resources.read().active.clone()
2917 }
2918
2919 pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
2921 let resources = {
2922 let mut resources = self.resources.write();
2923 if mask.len() != resources.active.len() {
2924 return Err(LadduError::LengthMismatch {
2925 context: "active amplitude mask".to_string(),
2926 expected: resources.active.len(),
2927 actual: mask.len(),
2928 });
2929 }
2930 resources.apply_active_mask(mask)?;
2931 resources.clone()
2932 };
2933 self.rebuild_runtime_specializations(&resources);
2934 Ok(())
2935 }
2936
2937 pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
2945 let resources = self.resources.read();
2946 let parameters = resources.parameter_map.assemble(parameters)?;
2947 let amplitude_len = self.amplitude_use_sites.len();
2948 let active_indices = resources.active_indices().to_vec();
2949 let slot_count = self.expression_value_slot_count();
2950 let program_snapshot = self.expression_value_program_snapshot();
2951 #[cfg(feature = "rayon")]
2952 {
2953 Ok(resources
2954 .caches
2955 .par_iter()
2956 .map_init(
2957 || {
2958 (
2959 vec![Complex64::ZERO; amplitude_len],
2960 vec![Complex64::ZERO; slot_count],
2961 )
2962 },
2963 |(amplitude_values, expr_slots), cache| {
2964 self.fill_amplitude_values(
2965 amplitude_values,
2966 &active_indices,
2967 ¶meters,
2968 cache,
2969 );
2970 self.evaluate_expression_value_with_program_snapshot(
2971 &program_snapshot,
2972 amplitude_values,
2973 expr_slots,
2974 )
2975 },
2976 )
2977 .collect())
2978 }
2979 #[cfg(not(feature = "rayon"))]
2980 {
2981 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2982 let mut expr_slots = vec![Complex64::ZERO; slot_count];
2983 Ok(resources
2984 .caches
2985 .iter()
2986 .map(|cache| {
2987 self.fill_amplitude_values(
2988 &mut amplitude_values,
2989 &active_indices,
2990 ¶meters,
2991 cache,
2992 );
2993 self.evaluate_expression_value_with_program_snapshot(
2994 &program_snapshot,
2995 &litude_values,
2996 &mut expr_slots,
2997 )
2998 })
2999 .collect())
3000 }
3001 }
3002
3003 pub fn evaluate_local_with_active_mask(
3005 &self,
3006 parameters: &[f64],
3007 active_mask: &[bool],
3008 ) -> LadduResult<Vec<Complex64>> {
3009 let resources = self.resources.read();
3010 if active_mask.len() != resources.active.len() {
3011 return Err(LadduError::LengthMismatch {
3012 context: "active amplitude mask".to_string(),
3013 expected: resources.active.len(),
3014 actual: active_mask.len(),
3015 });
3016 }
3017 let parameters = resources.parameter_map.assemble(parameters)?;
3018 let amplitude_len = self.amplitude_use_sites.len();
3019 let active_indices = active_mask
3020 .iter()
3021 .enumerate()
3022 .filter_map(|(index, &active)| if active { Some(index) } else { None })
3023 .collect::<Vec<_>>();
3024 let program_snapshot =
3025 self.expression_value_program_snapshot_for_active_mask(active_mask)?;
3026 let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
3027 #[cfg(feature = "rayon")]
3028 {
3029 Ok(resources
3030 .caches
3031 .par_iter()
3032 .map_init(
3033 || {
3034 (
3035 vec![Complex64::ZERO; amplitude_len],
3036 vec![Complex64::ZERO; slot_count],
3037 )
3038 },
3039 |(amplitude_values, expr_slots), cache| {
3040 self.fill_amplitude_values(
3041 amplitude_values,
3042 &active_indices,
3043 ¶meters,
3044 cache,
3045 );
3046 self.evaluate_expression_value_with_program_snapshot(
3047 &program_snapshot,
3048 amplitude_values,
3049 expr_slots,
3050 )
3051 },
3052 )
3053 .collect())
3054 }
3055 #[cfg(not(feature = "rayon"))]
3056 {
3057 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3058 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3059 Ok(resources
3060 .caches
3061 .iter()
3062 .map(|cache| {
3063 self.fill_amplitude_values(
3064 &mut amplitude_values,
3065 &active_indices,
3066 ¶meters,
3067 cache,
3068 );
3069 self.evaluate_expression_value_with_program_snapshot(
3070 &program_snapshot,
3071 &litude_values,
3072 &mut expr_slots,
3073 )
3074 })
3075 .collect())
3076 }
3077 }
3078
/// Prototype variant of `evaluate_local` driven by an [`ExecutionContext`].
///
/// With `rayon` and a thread policy other than `ThreadPolicy::Single`, evaluation runs in
/// parallel inside `execution_context.install`. Otherwise it runs sequentially, using
/// workspaces reserved from the context's scratch.
///
/// # Panics
/// Panics if `parameters` cannot be assembled against the evaluator's parameter map.
#[cfg(feature = "execution-context-prototype")]
pub fn evaluate_local_with_ctx(
    &self,
    parameters: &[f64],
    execution_context: &ExecutionContext,
) -> Vec<Complex64> {
    let resources = self.resources.read();
    let parameters = resources
        .parameter_map
        .assemble(parameters)
        .expect("parameter slice must match evaluator resources");
    let amplitude_len = self.amplitude_use_sites.len();
    let active_indices = resources.active_indices().to_vec();
    let slot_count = self.expression_value_slot_count();
    let program_snapshot = self.expression_value_program_snapshot();

    #[cfg(feature = "rayon")]
    {
        let run_parallel = !matches!(execution_context.thread_policy(), ThreadPolicy::Single);
        if run_parallel {
            return execution_context.install(|| {
                resources
                    .caches
                    .par_iter()
                    .map_init(
                        || {
                            (
                                vec![Complex64::ZERO; amplitude_len],
                                vec![Complex64::ZERO; slot_count],
                            )
                        },
                        |(amplitude_buffer, slot_buffer), cache| {
                            self.fill_amplitude_values(
                                amplitude_buffer,
                                &active_indices,
                                &parameters,
                                cache,
                            );
                            self.evaluate_expression_value_with_program_snapshot(
                                &program_snapshot,
                                amplitude_buffer,
                                slot_buffer,
                            )
                        },
                    )
                    .collect()
            });
        }
    }

    execution_context.with_scratch(|scratch| {
        let (amplitude_buffer, slot_buffer) =
            scratch.reserve_value_workspaces(amplitude_len, slot_count);
        let mut values = Vec::with_capacity(resources.caches.len());
        for cache in resources.caches.iter() {
            self.fill_amplitude_values(amplitude_buffer, &active_indices, &parameters, cache);
            values.push(self.evaluate_expression_value_with_program_snapshot(
                &program_snapshot,
                amplitude_buffer,
                slot_buffer,
            ));
        }
        values
    })
}
3149
/// Evaluates locally, then all-gathers the per-rank results into one vector of length
/// `self.dataset.n_events()`. Partition counts and displacements come from
/// `world.get_counts_displs`.
///
/// # Errors
/// Propagates any error from `evaluate_local`.
#[cfg(feature = "mpi")]
fn evaluate_mpi(
    &self,
    parameters: &[f64],
    world: &SimpleCommunicator,
) -> LadduResult<Vec<Complex64>> {
    let local_evaluation = self.evaluate_local(parameters)?;
    let n_events = self.dataset.n_events();
    let mut gathered = vec![Complex64::ZERO; n_events];
    let (counts, displs) = world.get_counts_displs(n_events);
    world.all_gather_varcount_into(
        &local_evaluation,
        &mut PartitionMut::new(&mut gathered, counts, displs),
    );
    Ok(gathered)
}
3175
3176 #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3177 fn evaluate_mpi_with_ctx(
3178 &self,
3179 parameters: &[f64],
3180 world: &SimpleCommunicator,
3181 execution_context: &ExecutionContext,
3182 ) -> Vec<Complex64> {
3183 let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
3184 let n_events = self.dataset.n_events();
3185 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3186 let (counts, displs) = world.get_counts_displs(n_events);
3187 {
3188 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3191 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3192 }
3193 buffer
3194 }
3195
3196 pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3199 #[cfg(feature = "mpi")]
3200 {
3201 if let Some(world) = crate::mpi::get_world() {
3202 return self.evaluate_mpi(parameters, &world);
3203 }
3204 }
3205 self.evaluate_local(parameters)
3206 }
3207
3208 #[cfg(feature = "execution-context-prototype")]
3214 pub fn evaluate_with_ctx(
3215 &self,
3216 parameters: &[f64],
3217 execution_context: &ExecutionContext,
3218 ) -> Vec<Complex64> {
3219 #[cfg(feature = "mpi")]
3220 {
3221 if let Some(world) = crate::mpi::get_world() {
3222 return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
3223 }
3224 }
3225 self.evaluate_local_with_ctx(parameters, execution_context)
3226 }
3227
3228 pub fn evaluate_batch_local(
3231 &self,
3232 parameters: &[f64],
3233 indices: &[usize],
3234 ) -> LadduResult<Vec<Complex64>> {
3235 let resources = self.resources.read();
3236 let parameters = resources.parameter_map.assemble(parameters)?;
3237 let amplitude_len = self.amplitude_use_sites.len();
3238 let active_indices = resources.active_indices().to_vec();
3239 let slot_count = self.expression_value_slot_count();
3240 let program_snapshot = self.expression_value_program_snapshot();
3241 #[cfg(feature = "rayon")]
3242 {
3243 Ok(indices
3244 .par_iter()
3245 .map_init(
3246 || {
3247 (
3248 vec![Complex64::ZERO; amplitude_len],
3249 vec![Complex64::ZERO; slot_count],
3250 )
3251 },
3252 |(amplitude_values, expr_slots), &idx| {
3253 let cache = &resources.caches[idx];
3254 self.fill_amplitude_values(
3255 amplitude_values,
3256 &active_indices,
3257 ¶meters,
3258 cache,
3259 );
3260 self.evaluate_expression_value_with_program_snapshot(
3261 &program_snapshot,
3262 amplitude_values,
3263 expr_slots,
3264 )
3265 },
3266 )
3267 .collect())
3268 }
3269 #[cfg(not(feature = "rayon"))]
3270 {
3271 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3272 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3273 Ok(indices
3274 .iter()
3275 .map(|&idx| {
3276 let cache = &resources.caches[idx];
3277 self.fill_amplitude_values(
3278 &mut amplitude_values,
3279 &active_indices,
3280 ¶meters,
3281 cache,
3282 );
3283 self.evaluate_expression_value_with_program_snapshot(
3284 &program_snapshot,
3285 &litude_values,
3286 &mut expr_slots,
3287 )
3288 })
3289 .collect())
3290 }
3291 }
3292
3293 #[cfg(feature = "mpi")]
3296 fn evaluate_batch_mpi(
3297 &self,
3298 parameters: &[f64],
3299 indices: &[usize],
3300 world: &SimpleCommunicator,
3301 ) -> LadduResult<Vec<Complex64>> {
3302 let total = self.dataset.n_events();
3303 let locals = world.locals_from_globals(indices, total);
3304 let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
3305 Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
3306 }
3307
3308 pub fn evaluate_batch(
3311 &self,
3312 parameters: &[f64],
3313 indices: &[usize],
3314 ) -> LadduResult<Vec<Complex64>> {
3315 #[cfg(feature = "mpi")]
3316 {
3317 if let Some(world) = crate::mpi::get_world() {
3318 return self.evaluate_batch_mpi(parameters, indices, &world);
3319 }
3320 }
3321 self.evaluate_batch_local(parameters, indices)
3322 }
3323
3324 pub fn evaluate_gradient_local(
3332 &self,
3333 parameters: &[f64],
3334 ) -> LadduResult<Vec<DVector<Complex64>>> {
3335 let resources = self.resources.read();
3336 let parameters = resources.parameter_map.assemble(parameters)?;
3337 let amplitude_len = self.amplitude_use_sites.len();
3338 let grad_dim = parameters.len();
3339 let active_indices = resources.active_indices().to_vec();
3340 let lowered_runtime = self.lowered_runtime();
3341 let gradient_program = lowered_runtime.gradient_program();
3342 let slot_count = self.expression_gradient_slot_count();
3343 #[cfg(feature = "rayon")]
3344 {
3345 Ok(resources
3346 .caches
3347 .par_iter()
3348 .map_init(
3349 || {
3350 (
3351 vec![Complex64::ZERO; amplitude_len],
3352 vec![DVector::zeros(grad_dim); amplitude_len],
3353 vec![Complex64::ZERO; slot_count],
3354 vec![Complex64::ZERO; slot_count * grad_dim],
3355 )
3356 },
3357 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3358 self.fill_amplitude_values_and_gradients(
3359 amplitude_values,
3360 gradient_values,
3361 &active_indices,
3362 &resources.active,
3363 ¶meters,
3364 cache,
3365 );
3366 gradient_program.evaluate_gradient_into_flat(
3367 amplitude_values,
3368 gradient_values,
3369 value_slots,
3370 gradient_slots,
3371 grad_dim,
3372 )
3373 },
3374 )
3375 .collect())
3376 }
3377 #[cfg(not(feature = "rayon"))]
3378 {
3379 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3380 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3381 let mut value_slots = vec![Complex64::ZERO; slot_count];
3382 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3383 Ok(resources
3384 .caches
3385 .iter()
3386 .map(|cache| {
3387 self.fill_amplitude_values_and_gradients(
3388 &mut amplitude_values,
3389 &mut gradient_values,
3390 &active_indices,
3391 &resources.active,
3392 ¶meters,
3393 cache,
3394 );
3395 gradient_program.evaluate_gradient_into_flat(
3396 &litude_values,
3397 &gradient_values,
3398 &mut value_slots,
3399 &mut gradient_slots,
3400 grad_dim,
3401 )
3402 })
3403 .collect())
3404 }
3405 }
3406
3407 #[cfg(feature = "execution-context-prototype")]
3409 pub fn evaluate_gradient_local_with_ctx(
3410 &self,
3411 parameters: &[f64],
3412 execution_context: &ExecutionContext,
3413 ) -> Vec<DVector<Complex64>> {
3414 let resources = self.resources.read();
3415 let parameters = resources
3416 .parameter_map
3417 .assemble(parameters)
3418 .expect("parameter slice must match evaluator resources");
3419 let amplitude_len = self.amplitude_use_sites.len();
3420 let grad_dim = parameters.len();
3421 let active_indices = resources.active_indices().to_vec();
3422 let slot_count = self.expression_slot_count();
3423 #[cfg(feature = "rayon")]
3424 {
3425 if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3426 return execution_context.install(|| {
3427 resources
3428 .caches
3429 .par_iter()
3430 .map_init(
3431 || {
3432 (
3433 vec![Complex64::ZERO; amplitude_len],
3434 vec![DVector::zeros(grad_dim); amplitude_len],
3435 vec![Complex64::ZERO; slot_count],
3436 vec![DVector::zeros(grad_dim); slot_count],
3437 )
3438 },
3439 |(amplitude_values, gradient_values, value_slots, gradient_slots),
3440 cache| {
3441 self.evaluate_cache_gradient_with_scratch(
3442 amplitude_values,
3443 gradient_values,
3444 value_slots,
3445 gradient_slots,
3446 &active_indices,
3447 &resources.active,
3448 ¶meters,
3449 cache,
3450 )
3451 },
3452 )
3453 .collect()
3454 });
3455 }
3456 }
3457 execution_context.with_scratch(|scratch| {
3458 let (amplitude_values, value_slots, gradient_values, gradient_slots) =
3459 scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
3460 resources
3461 .caches
3462 .iter()
3463 .map(|cache| {
3464 self.evaluate_cache_gradient_with_scratch(
3465 amplitude_values,
3466 gradient_values,
3467 value_slots,
3468 gradient_slots,
3469 &active_indices,
3470 &resources.active,
3471 ¶meters,
3472 cache,
3473 )
3474 })
3475 .collect()
3476 })
3477 }
3478
3479 #[cfg(feature = "mpi")]
3487 fn evaluate_gradient_mpi(
3488 &self,
3489 parameters: &[f64],
3490 world: &SimpleCommunicator,
3491 ) -> LadduResult<Vec<DVector<Complex64>>> {
3492 let local_evaluation = self.evaluate_gradient_local(parameters)?;
3493 let n_events = self.dataset.n_events();
3494 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
3495 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
3496 {
3497 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3500 world.all_gather_varcount_into(
3501 &local_evaluation
3502 .iter()
3503 .flat_map(|v| v.data.as_vec())
3504 .copied()
3505 .collect::<Vec<_>>(),
3506 &mut partitioned_buffer,
3507 );
3508 }
3509 Ok(buffer
3510 .chunks(parameters.len())
3511 .map(DVector::from_row_slice)
3512 .collect())
3513 }
3514
3515 #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3516 fn evaluate_gradient_mpi_with_ctx(
3517 &self,
3518 parameters: &[f64],
3519 world: &SimpleCommunicator,
3520 execution_context: &ExecutionContext,
3521 ) -> Vec<DVector<Complex64>> {
3522 let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
3523 let n_events = self.dataset.n_events();
3524 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
3525 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
3526 {
3527 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3530 world.all_gather_varcount_into(
3531 &local_evaluation
3532 .iter()
3533 .flat_map(|v| v.data.as_vec())
3534 .copied()
3535 .collect::<Vec<_>>(),
3536 &mut partitioned_buffer,
3537 );
3538 }
3539 buffer
3540 .chunks(parameters.len())
3541 .map(DVector::from_row_slice)
3542 .collect()
3543 }
3544
3545 pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
3548 #[cfg(feature = "mpi")]
3549 {
3550 if let Some(world) = crate::mpi::get_world() {
3551 return self.evaluate_gradient_mpi(parameters, &world);
3552 }
3553 }
3554 self.evaluate_gradient_local(parameters)
3555 }
3556
3557 #[cfg(feature = "execution-context-prototype")]
3563 pub fn evaluate_gradient_with_ctx(
3564 &self,
3565 parameters: &[f64],
3566 execution_context: &ExecutionContext,
3567 ) -> Vec<DVector<Complex64>> {
3568 #[cfg(feature = "mpi")]
3569 {
3570 if let Some(world) = crate::mpi::get_world() {
3571 return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
3572 }
3573 }
3574 self.evaluate_gradient_local_with_ctx(parameters, execution_context)
3575 }
3576
3577 pub fn evaluate_gradient_batch_local(
3580 &self,
3581 parameters: &[f64],
3582 indices: &[usize],
3583 ) -> LadduResult<Vec<DVector<Complex64>>> {
3584 let resources = self.resources.read();
3585 let parameters = resources.parameter_map.assemble(parameters)?;
3586 let amplitude_len = self.amplitude_use_sites.len();
3587 let grad_dim = parameters.len();
3588 let active_indices = resources.active_indices().to_vec();
3589 let lowered_runtime = self.lowered_runtime();
3590 let gradient_program = lowered_runtime.gradient_program();
3591 let slot_count = self.expression_gradient_slot_count();
3592 #[cfg(feature = "rayon")]
3593 {
3594 Ok(indices
3595 .par_iter()
3596 .map_init(
3597 || {
3598 (
3599 vec![Complex64::ZERO; amplitude_len],
3600 vec![DVector::zeros(grad_dim); amplitude_len],
3601 vec![Complex64::ZERO; slot_count],
3602 vec![Complex64::ZERO; slot_count * grad_dim],
3603 )
3604 },
3605 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3606 let cache = &resources.caches[idx];
3607 self.fill_amplitude_values_and_gradients(
3608 amplitude_values,
3609 gradient_values,
3610 &active_indices,
3611 &resources.active,
3612 ¶meters,
3613 cache,
3614 );
3615 gradient_program.evaluate_gradient_into_flat(
3616 amplitude_values,
3617 gradient_values,
3618 value_slots,
3619 gradient_slots,
3620 grad_dim,
3621 )
3622 },
3623 )
3624 .collect())
3625 }
3626 #[cfg(not(feature = "rayon"))]
3627 {
3628 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3629 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3630 let mut value_slots = vec![Complex64::ZERO; slot_count];
3631 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3632 Ok(indices
3633 .iter()
3634 .map(|&idx| {
3635 let cache = &resources.caches[idx];
3636 self.fill_amplitude_values_and_gradients(
3637 &mut amplitude_values,
3638 &mut gradient_values,
3639 &active_indices,
3640 &resources.active,
3641 ¶meters,
3642 cache,
3643 );
3644 gradient_program.evaluate_gradient_into_flat(
3645 &litude_values,
3646 &gradient_values,
3647 &mut value_slots,
3648 &mut gradient_slots,
3649 grad_dim,
3650 )
3651 })
3652 .collect())
3653 }
3654 }
3655
3656 #[cfg(feature = "mpi")]
3659 fn evaluate_gradient_batch_mpi(
3660 &self,
3661 parameters: &[f64],
3662 indices: &[usize],
3663 world: &SimpleCommunicator,
3664 ) -> LadduResult<Vec<DVector<Complex64>>> {
3665 let total = self.dataset.n_events();
3666 let locals = world.locals_from_globals(indices, total);
3667 let flattened_local_evaluation = self
3668 .evaluate_gradient_batch_local(parameters, &locals)?
3669 .iter()
3670 .flat_map(|g| g.data.as_vec().to_vec())
3671 .collect::<Vec<Complex64>>();
3672 Ok(world
3673 .all_gather_batched_partitioned(
3674 &flattened_local_evaluation,
3675 indices,
3676 total,
3677 Some(parameters.len()),
3678 )
3679 .chunks(parameters.len())
3680 .map(DVector::from_row_slice)
3681 .collect())
3682 }
3683
3684 pub fn evaluate_gradient_batch(
3688 &self,
3689 parameters: &[f64],
3690 indices: &[usize],
3691 ) -> LadduResult<Vec<DVector<Complex64>>> {
3692 #[cfg(feature = "mpi")]
3693 {
3694 if let Some(world) = crate::mpi::get_world() {
3695 return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
3696 }
3697 }
3698 self.evaluate_gradient_batch_local(parameters, indices)
3699 }
3700
3701 pub fn evaluate_with_gradient_local(
3703 &self,
3704 parameters: &[f64],
3705 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3706 let resources = self.resources.read();
3707 let parameters = resources.parameter_map.assemble(parameters)?;
3708 let amplitude_len = self.amplitude_use_sites.len();
3709 let grad_dim = parameters.len();
3710 let active_indices = resources.active_indices().to_vec();
3711 let lowered_runtime = self.lowered_runtime();
3712 let value_gradient_program = lowered_runtime.value_gradient_program();
3713 let slot_count = self.expression_value_gradient_slot_count();
3714 #[cfg(feature = "rayon")]
3715 {
3716 Ok(resources
3717 .caches
3718 .par_iter()
3719 .map_init(
3720 || {
3721 (
3722 vec![Complex64::ZERO; amplitude_len],
3723 vec![DVector::zeros(grad_dim); amplitude_len],
3724 vec![Complex64::ZERO; slot_count],
3725 vec![Complex64::ZERO; slot_count * grad_dim],
3726 )
3727 },
3728 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3729 self.fill_amplitude_values_and_gradients(
3730 amplitude_values,
3731 gradient_values,
3732 &active_indices,
3733 &resources.active,
3734 ¶meters,
3735 cache,
3736 );
3737 value_gradient_program.evaluate_value_gradient_into_flat(
3738 amplitude_values,
3739 gradient_values,
3740 value_slots,
3741 gradient_slots,
3742 grad_dim,
3743 )
3744 },
3745 )
3746 .collect())
3747 }
3748 #[cfg(not(feature = "rayon"))]
3749 {
3750 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3751 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3752 let mut value_slots = vec![Complex64::ZERO; slot_count];
3753 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3754 Ok(resources
3755 .caches
3756 .iter()
3757 .map(|cache| {
3758 self.fill_amplitude_values_and_gradients(
3759 &mut amplitude_values,
3760 &mut gradient_values,
3761 &active_indices,
3762 &resources.active,
3763 ¶meters,
3764 cache,
3765 );
3766 value_gradient_program.evaluate_value_gradient_into_flat(
3767 &litude_values,
3768 &gradient_values,
3769 &mut value_slots,
3770 &mut gradient_slots,
3771 grad_dim,
3772 )
3773 })
3774 .collect())
3775 }
3776 }
3777
3778 pub fn evaluate_with_gradient_local_with_active_mask(
3780 &self,
3781 parameters: &[f64],
3782 active_mask: &[bool],
3783 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3784 let resources = self.resources.read();
3785 if active_mask.len() != resources.active.len() {
3786 return Err(LadduError::LengthMismatch {
3787 context: "active amplitude mask".to_string(),
3788 expected: resources.active.len(),
3789 actual: active_mask.len(),
3790 });
3791 }
3792 let parameters = resources.parameter_map.assemble(parameters)?;
3793 let amplitude_len = self.amplitude_use_sites.len();
3794 let grad_dim = parameters.len();
3795 let active_indices = active_mask
3796 .iter()
3797 .enumerate()
3798 .filter_map(|(index, &active)| if active { Some(index) } else { None })
3799 .collect::<Vec<_>>();
3800 let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
3801 let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
3802 #[cfg(feature = "rayon")]
3803 {
3804 Ok(resources
3805 .caches
3806 .par_iter()
3807 .map_init(
3808 || {
3809 (
3810 vec![Complex64::ZERO; amplitude_len],
3811 vec![DVector::zeros(grad_dim); amplitude_len],
3812 vec![Complex64::ZERO; slot_count],
3813 vec![Complex64::ZERO; slot_count * grad_dim],
3814 )
3815 },
3816 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3817 self.fill_amplitude_values_and_gradients(
3818 amplitude_values,
3819 gradient_values,
3820 &active_indices,
3821 active_mask,
3822 ¶meters,
3823 cache,
3824 );
3825 lowered_runtime
3826 .value_gradient_program()
3827 .evaluate_value_gradient_into_flat(
3828 amplitude_values,
3829 gradient_values,
3830 value_slots,
3831 gradient_slots,
3832 grad_dim,
3833 )
3834 },
3835 )
3836 .collect())
3837 }
3838 #[cfg(not(feature = "rayon"))]
3839 {
3840 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3841 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3842 let mut value_slots = vec![Complex64::ZERO; slot_count];
3843 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3844 Ok(resources
3845 .caches
3846 .iter()
3847 .map(|cache| {
3848 self.fill_amplitude_values_and_gradients(
3849 &mut amplitude_values,
3850 &mut gradient_values,
3851 &active_indices,
3852 active_mask,
3853 ¶meters,
3854 cache,
3855 );
3856 lowered_runtime
3857 .value_gradient_program()
3858 .evaluate_value_gradient_into_flat(
3859 &litude_values,
3860 &gradient_values,
3861 &mut value_slots,
3862 &mut gradient_slots,
3863 grad_dim,
3864 )
3865 })
3866 .collect())
3867 }
3868 }
3869
3870 pub fn evaluate_with_gradient_batch_local(
3872 &self,
3873 parameters: &[f64],
3874 indices: &[usize],
3875 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3876 let resources = self.resources.read();
3877 let parameters = resources.parameter_map.assemble(parameters)?;
3878 let amplitude_len = self.amplitude_use_sites.len();
3879 let grad_dim = parameters.len();
3880 let active_indices = resources.active_indices().to_vec();
3881 let lowered_runtime = self.lowered_runtime();
3882 let value_gradient_program = lowered_runtime.value_gradient_program();
3883 let slot_count = self.expression_value_gradient_slot_count();
3884 #[cfg(feature = "rayon")]
3885 {
3886 Ok(indices
3887 .par_iter()
3888 .map_init(
3889 || {
3890 (
3891 vec![Complex64::ZERO; amplitude_len],
3892 vec![DVector::zeros(grad_dim); amplitude_len],
3893 vec![Complex64::ZERO; slot_count],
3894 vec![Complex64::ZERO; slot_count * grad_dim],
3895 )
3896 },
3897 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3898 let cache = &resources.caches[idx];
3899 self.fill_amplitude_values_and_gradients(
3900 amplitude_values,
3901 gradient_values,
3902 &active_indices,
3903 &resources.active,
3904 ¶meters,
3905 cache,
3906 );
3907 value_gradient_program.evaluate_value_gradient_into_flat(
3908 amplitude_values,
3909 gradient_values,
3910 value_slots,
3911 gradient_slots,
3912 grad_dim,
3913 )
3914 },
3915 )
3916 .collect())
3917 }
3918 #[cfg(not(feature = "rayon"))]
3919 {
3920 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3921 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3922 let mut value_slots = vec![Complex64::ZERO; slot_count];
3923 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3924 Ok(indices
3925 .iter()
3926 .map(|&idx| {
3927 let cache = &resources.caches[idx];
3928 self.fill_amplitude_values_and_gradients(
3929 &mut amplitude_values,
3930 &mut gradient_values,
3931 &active_indices,
3932 &resources.active,
3933 ¶meters,
3934 cache,
3935 );
3936 value_gradient_program.evaluate_value_gradient_into_flat(
3937 &litude_values,
3938 &gradient_values,
3939 &mut value_slots,
3940 &mut gradient_slots,
3941 grad_dim,
3942 )
3943 })
3944 .collect())
3945 }
3946 }
3947}
3948
3949#[cfg(test)]
3950mod tests {
3951 use approx::assert_relative_eq;
3952 #[cfg(feature = "mpi")]
3953 use mpi_test::mpi_test;
3954 use serde::{Deserialize, Serialize};
3955
3956 use super::*;
3957 use crate::{
3958 amplitude::{AmplitudeID, Tags, TestAmplitude},
3959 data::{test_dataset, test_event, DatasetMetadata, Event, EventData},
3960 parameter,
3961 parameters::Parameter,
3962 resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
3963 vectors::Vec4,
3964 };
3965
3966 #[derive(Clone, Serialize, Deserialize)]
3967 pub struct ComplexScalar {
3968 name: String,
3969 re: Parameter,
3970 pid_re: ParameterID,
3971 im: Parameter,
3972 pid_im: ParameterID,
3973 }
3974
3975 impl ComplexScalar {
3976 #[allow(clippy::new_ret_no_self)]
3977 pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
3978 Self {
3979 name: name.to_string(),
3980 re,
3981 pid_re: Default::default(),
3982 im,
3983 pid_im: Default::default(),
3984 }
3985 .into_expression()
3986 }
3987 }
3988
3989 #[typetag::serde]
3990 impl Amplitude for ComplexScalar {
3991 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
3992 self.pid_re = resources.register_parameter(&self.re)?;
3993 self.pid_im = resources.register_parameter(&self.im)?;
3994 resources.register_amplitude(&self.name)
3995 }
3996
3997 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
3998 Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
3999 }
4000
4001 fn compute_gradient(
4002 &self,
4003 parameters: &Parameters,
4004 _cache: &Cache,
4005 gradient: &mut DVector<Complex64>,
4006 ) {
4007 if let Some(ind) = parameters.free_index(self.pid_re) {
4008 gradient[ind] = Complex64::ONE;
4009 }
4010 if let Some(ind) = parameters.free_index(self.pid_im) {
4011 gradient[ind] = Complex64::I;
4012 }
4013 }
4014 }
4015
4016 #[derive(Clone, Serialize, Deserialize)]
4017 pub struct ParameterOnlyScalar {
4018 name: String,
4019 value: Parameter,
4020 pid: ParameterID,
4021 }
4022
4023 impl ParameterOnlyScalar {
4024 #[allow(clippy::new_ret_no_self)]
4025 pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
4026 Self {
4027 name: name.to_string(),
4028 value,
4029 pid: Default::default(),
4030 }
4031 .into_expression()
4032 }
4033 }
4034
4035 #[typetag::serde]
4036 impl Amplitude for ParameterOnlyScalar {
4037 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4038 self.pid = resources.register_parameter(&self.value)?;
4039 resources.register_amplitude(&self.name)
4040 }
4041
4042 fn dependence_hint(&self) -> ExpressionDependence {
4043 ExpressionDependence::ParameterOnly
4044 }
4045
4046 fn real_valued_hint(&self) -> bool {
4047 true
4048 }
4049
4050 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4051 Complex64::new(parameters.get(self.pid), 0.0)
4052 }
4053 }
4054
4055 #[derive(Clone, Serialize, Deserialize)]
4056 pub struct CacheOnlyScalar {
4057 name: String,
4058 beam_energy: ScalarID,
4059 }
4060
4061 impl CacheOnlyScalar {
4062 #[allow(clippy::new_ret_no_self)]
4063 pub fn new(name: &str) -> LadduResult<Expression> {
4064 Self {
4065 name: name.to_string(),
4066 beam_energy: Default::default(),
4067 }
4068 .into_expression()
4069 }
4070 }
4071
4072 #[typetag::serde]
4073 impl Amplitude for CacheOnlyScalar {
4074 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4075 self.beam_energy =
4076 resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
4077 resources.register_amplitude(&self.name)
4078 }
4079
4080 fn dependence_hint(&self) -> ExpressionDependence {
4081 ExpressionDependence::CacheOnly
4082 }
4083
4084 fn real_valued_hint(&self) -> bool {
4085 true
4086 }
4087
4088 fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
4089 cache.store_scalar(self.beam_energy, event.p4_at(0).e());
4090 }
4091
4092 fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
4093 Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
4094 }
4095 }
4096
4097 #[derive(Clone, Copy)]
4098 enum DeterministicFixtureKind {
4099 Separable,
4100 Partial,
4101 NonSeparable,
4102 }
4103
4104 struct DeterministicFixture {
4105 expression: Expression,
4106 dataset: Arc<Dataset>,
4107 parameters: Vec<f64>,
4108 }
4109
4110 const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
4111 const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
4112
4113 fn deterministic_fixture_dataset() -> Arc<Dataset> {
4114 let metadata = Arc::new(DatasetMetadata::default());
4115 let events = vec![
4116 Arc::new(EventData {
4117 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
4118 aux: vec![],
4119 weight: 0.5,
4120 }),
4121 Arc::new(EventData {
4122 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
4123 aux: vec![],
4124 weight: -1.25,
4125 }),
4126 Arc::new(EventData {
4127 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
4128 aux: vec![],
4129 weight: 2.0,
4130 }),
4131 Arc::new(EventData {
4132 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
4133 aux: vec![],
4134 weight: 0.75,
4135 }),
4136 ];
4137 Arc::new(Dataset::new_with_metadata(events, metadata))
4138 }
4139
4140 fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
4141 let dataset = deterministic_fixture_dataset();
4142 match kind {
4143 DeterministicFixtureKind::Separable => {
4144 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
4145 .expect("separable p1 should build");
4146 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
4147 .expect("separable p2 should build");
4148 let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
4149 let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
4150 DeterministicFixture {
4151 expression: (&p1 * &c1) + &(&p2 * &c2),
4152 dataset,
4153 parameters: vec![0.4, -0.3],
4154 }
4155 }
4156 DeterministicFixtureKind::Partial => {
4157 let p =
4158 ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
4159 let c = CacheOnlyScalar::new("c").expect("partial c should build");
4160 let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
4161 .expect("partial m should build");
4162 DeterministicFixture {
4163 expression: (&p * &c) + &m,
4164 dataset,
4165 parameters: vec![0.55, 0.2, -0.15],
4166 }
4167 }
4168 DeterministicFixtureKind::NonSeparable => {
4169 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
4170 .expect("non-separable m1 should build");
4171 let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
4172 .expect("non-separable m2 should build");
4173 DeterministicFixture {
4174 expression: &m1 * &m2,
4175 dataset,
4176 parameters: vec![0.25, -0.4, 0.6, 0.1],
4177 }
4178 }
4179 }
4180 }
4181
4182 fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
4183 let evaluator = fixture
4184 .expression
4185 .load(&fixture.dataset)
4186 .expect("fixture evaluator should load");
4187 let expected_value = evaluator
4188 .evaluate_local(&fixture.parameters)
4189 .expect("evaluation should succeed")
4190 .iter()
4191 .zip(fixture.dataset.weights_local().iter())
4192 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
4193 let expected_gradient = evaluator
4194 .evaluate_gradient_local(&fixture.parameters)
4195 .expect("evaluation should succeed")
4196 .iter()
4197 .zip(fixture.dataset.weights_local().iter())
4198 .fold(
4199 DVector::zeros(fixture.parameters.len()),
4200 |mut accum, (gradient, event)| {
4201 accum += gradient.map(|value| value.re).scale(*event);
4202 accum
4203 },
4204 );
4205 let actual_value = evaluator
4206 .evaluate_weighted_value_sum_local(&fixture.parameters)
4207 .expect("evaluation should succeed");
4208 let actual_gradient = evaluator
4209 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4210 .expect("evaluation should succeed");
4211 assert_relative_eq!(
4212 actual_value,
4213 expected_value,
4214 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4215 max_relative = DETERMINISTIC_STRICT_REL_TOL
4216 );
4217 for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
4218 assert_relative_eq!(
4219 *actual_item,
4220 *expected_item,
4221 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4222 max_relative = DETERMINISTIC_STRICT_REL_TOL
4223 );
4224 }
4225 }
4226 fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
4227 let evaluator = fixture
4228 .expression
4229 .load(&fixture.dataset)
4230 .expect("fixture evaluator should load");
4231 let state = {
4232 let resources = evaluator.resources.read();
4233 evaluator.ensure_cached_integral_cache_state(&resources)
4234 }
4235 .expect("state should be available");
4236 assert!(
4237 !state.values.is_empty(),
4238 "fixture should exercise cached normalization terms"
4239 );
4240 assert!(
4241 !state.execution_sets.residual_amplitudes.is_empty(),
4242 "fixture should exercise residual normalization amplitudes"
4243 );
4244
4245 let (residual_value_sum, cached_value_sum) = evaluator
4246 .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4247 .expect("evaluation should succeed");
4248 assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4249 assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4250 let combined_value = evaluator
4251 .evaluate_weighted_value_sum_local(&fixture.parameters)
4252 .expect("evaluation should succeed");
4253 assert_relative_eq!(
4254 residual_value_sum + cached_value_sum,
4255 combined_value,
4256 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4257 max_relative = DETERMINISTIC_STRICT_REL_TOL
4258 );
4259
4260 let (residual_gradient_sum, cached_gradient_sum) = evaluator
4261 .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4262 .expect("evaluation should succeed");
4263 let combined_gradient = evaluator
4264 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4265 .expect("evaluation should succeed");
4266 assert!(residual_gradient_sum
4267 .iter()
4268 .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4269 assert!(cached_gradient_sum
4270 .iter()
4271 .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4272 for ((residual_item, cached_item), combined_item) in residual_gradient_sum
4273 .iter()
4274 .zip(cached_gradient_sum.iter())
4275 .zip(combined_gradient.iter())
4276 {
4277 assert_relative_eq!(
4278 residual_item + cached_item,
4279 *combined_item,
4280 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4281 max_relative = DETERMINISTIC_STRICT_REL_TOL
4282 );
4283 }
4284 }
4285
4286 #[test]
4287 fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
4288 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4289 let evaluator = fixture
4290 .expression
4291 .load(&fixture.dataset)
4292 .expect("fixture evaluator should load");
4293 let original_mask = evaluator.active_mask();
4294
4295 let original_value = evaluator
4296 .evaluate_weighted_value_sum_local(&fixture.parameters)
4297 .expect("evaluation should succeed");
4298
4299 evaluator.isolate_many(&["p", "c"]);
4300 assert_ne!(evaluator.active_mask(), original_mask);
4301
4302 evaluator
4303 .set_active_mask(&original_mask)
4304 .expect("original fixture active mask should restore");
4305 assert_eq!(evaluator.active_mask(), original_mask);
4306 let actual_value = evaluator
4307 .evaluate_weighted_value_sum_local(&fixture.parameters)
4308 .expect("evaluation should succeed");
4309 assert_relative_eq!(
4310 actual_value,
4311 original_value,
4312 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4313 max_relative = DETERMINISTIC_STRICT_REL_TOL
4314 );
4315 }
4316
4317 #[test]
4318 fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
4319 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4320 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4321 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4322
4323 assert_weighted_sum_matches_eventwise_baseline(&separable);
4324 assert_weighted_sum_matches_eventwise_baseline(&partial);
4325 assert_weighted_sum_matches_eventwise_baseline(&non_separable);
4326 }
4327 #[test]
4328 fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
4329 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4330 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4331 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4332
4333 let separable_evaluator = separable
4334 .expression
4335 .load(&separable.dataset)
4336 .expect("separable evaluator should load");
4337 let partial_evaluator = partial
4338 .expression
4339 .load(&partial.dataset)
4340 .expect("partial evaluator should load");
4341 let non_separable_evaluator = non_separable
4342 .expression
4343 .load(&non_separable.dataset)
4344 .expect("non-separable evaluator should load");
4345
4346 assert_eq!(
4347 separable_evaluator
4348 .expression_precomputed_cached_integrals()
4349 .expect("integrals should be computed")
4350 .len(),
4351 2
4352 );
4353 assert_eq!(
4354 partial_evaluator
4355 .expression_precomputed_cached_integrals()
4356 .expect("integrals should be computed")
4357 .len(),
4358 1
4359 );
4360 assert!(non_separable_evaluator
4361 .expression_precomputed_cached_integrals()
4362 .expect("integrals should be computed")
4363 .is_empty());
4364 }
4365 #[test]
4366 fn test_partial_fixture_combined_normalization_components_match_total() {
4367 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4368 assert_mixed_normalization_components_match_combined_path(&partial);
4369 }
4370 #[test]
4371 fn test_non_separable_fixture_normalization_components_stay_residual_only() {
4372 let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4373 let evaluator = fixture
4374 .expression
4375 .load(&fixture.dataset)
4376 .expect("fixture evaluator should load");
4377 let resources = evaluator.resources.read();
4378 let state = evaluator
4379 .ensure_cached_integral_cache_state(&resources)
4380 .expect("state should be available");
4381 assert!(state.values.is_empty());
4382
4383 let (residual_value_sum, cached_value_sum) = evaluator
4384 .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4385 .expect("evaluation should succeed");
4386 assert_relative_eq!(
4387 cached_value_sum,
4388 0.0,
4389 epsilon = DETERMINISTIC_STRICT_ABS_TOL
4390 );
4391 assert_relative_eq!(
4392 residual_value_sum,
4393 evaluator
4394 .evaluate_weighted_value_sum_local(&fixture.parameters)
4395 .expect("evaluation should succeed"),
4396 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4397 max_relative = DETERMINISTIC_STRICT_REL_TOL
4398 );
4399
4400 let (residual_gradient_sum, cached_gradient_sum) = evaluator
4401 .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4402 .expect("evaluation should succeed");
4403 assert!(cached_gradient_sum
4404 .iter()
4405 .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
4406 let combined_gradient = evaluator
4407 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4408 .expect("evaluation should succeed");
4409 for (residual_item, combined_item) in
4410 residual_gradient_sum.iter().zip(combined_gradient.iter())
4411 {
4412 assert_relative_eq!(
4413 *residual_item,
4414 *combined_item,
4415 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4416 max_relative = DETERMINISTIC_STRICT_REL_TOL
4417 );
4418 }
4419 }
4420
4421 #[test]
4422 fn test_batch_evaluation() {
4423 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
4424 let mut event1 = test_event();
4425 event1.p4s[0].t = 10.0;
4426 let mut event2 = test_event();
4427 event2.p4s[0].t = 11.0;
4428 let mut event3 = test_event();
4429 event3.p4s[0].t = 12.0;
4430 let dataset = Arc::new(Dataset::new_with_metadata(
4431 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
4432 Arc::new(DatasetMetadata::default()),
4433 ));
4434 let evaluator = expr.load(&dataset).unwrap();
4435 let result = evaluator
4436 .evaluate_batch(&[1.1, 2.2], &[0, 2])
4437 .expect("evaluation should succeed");
4438 assert_eq!(result.len(), 2);
4439 assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
4440 assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
4441 let result_grad = evaluator
4442 .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
4443 .expect("evaluation should succeed");
4444 assert_eq!(result_grad.len(), 2);
4445 assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
4446 assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
4447 assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
4448 assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
4449 }
4450
4451 #[test]
4452 fn test_load_compiles_expression_ir_once() {
4453 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4454 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4455 .norm_sqr();
4456 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4457 let evaluator = expr.load(&dataset).unwrap();
4458 assert!(evaluator.expression_slot_count() > 0);
4459 }
4460 #[test]
4461 fn test_expression_ir_value_matches_lowered_runtime() {
4462 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4463 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4464 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4465 .conj()
4466 .norm_sqr();
4467 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4468 let evaluator = expr.load(&dataset).unwrap();
4469 let resources = evaluator.resources.read();
4470 let parameters = resources
4471 .parameter_map
4472 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4473 .expect("parameters should assemble");
4474 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4475 evaluator.fill_amplitude_values(
4476 &mut amplitude_values,
4477 resources.active_indices(),
4478 ¶meters,
4479 &resources.caches[0],
4480 );
4481 let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4482 let lowered_runtime = evaluator.lowered_runtime();
4483 let lowered_program = lowered_runtime.value_program();
4484 let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4485 let lowered_value =
4486 evaluator.evaluate_expression_value_with_scratch(&litude_values, &mut ir_slots);
4487 let direct_lowered_value =
4488 lowered_program.evaluate_into(&litude_values, &mut lowered_slots);
4489 let ir_value = evaluator
4490 .expression_ir()
4491 .evaluate_into(&litude_values, &mut ir_slots);
4492 assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
4493 assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
4494 assert_relative_eq!(lowered_value.re, ir_value.re);
4495 assert_relative_eq!(lowered_value.im, ir_value.im);
4496 }
4497 #[test]
4498 fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
4499 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
4500 .unwrap()
4501 .norm_sqr();
4502 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4503 let evaluator = expr.load(&dataset).unwrap();
4504 let lowered_runtime = evaluator.lowered_runtime();
4505 assert_eq!(
4506 lowered_runtime.value_program().kind(),
4507 lowered::LoweredProgramKind::Value
4508 );
4509 assert_eq!(
4510 lowered_runtime.gradient_program().kind(),
4511 lowered::LoweredProgramKind::Gradient
4512 );
4513 assert_eq!(
4514 lowered_runtime.value_gradient_program().kind(),
4515 lowered::LoweredProgramKind::ValueGradient
4516 );
4517 }
4518 #[test]
4519 fn test_expression_ir_gradient_matches_lowered_runtime() {
4520 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4521 * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4522 .norm_sqr();
4523 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4524 let evaluator = expr.load(&dataset).unwrap();
4525 let resources = evaluator.resources.read();
4526 let parameters = resources
4527 .parameter_map
4528 .assemble(&[1.0, 0.25, -0.8, 0.5])
4529 .expect("parameters should assemble");
4530 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4531 evaluator.fill_amplitude_values(
4532 &mut amplitude_values,
4533 resources.active_indices(),
4534 ¶meters,
4535 &resources.caches[0],
4536 );
4537 let mut active_mask = vec![false; evaluator.amplitudes.len()];
4538 for &index in resources.active_indices() {
4539 active_mask[index] = true;
4540 }
4541 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4542 .map(|_| DVector::zeros(parameters.len()))
4543 .collect::<Vec<_>>();
4544 evaluator.fill_amplitude_gradients(
4545 &mut amplitude_gradients,
4546 &active_mask,
4547 ¶meters,
4548 &resources.caches[0],
4549 );
4550 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4551 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4552 (0..evaluator.expression_ir().node_count())
4553 .map(|_| DVector::zeros(parameters.len()))
4554 .collect();
4555 let lowered_runtime = evaluator.lowered_runtime();
4556 let lowered_program = lowered_runtime.gradient_program();
4557 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4558 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4559 .scratch_slots())
4560 .map(|_| DVector::zeros(parameters.len()))
4561 .collect();
4562 let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
4563 &litude_values,
4564 &litude_gradients,
4565 &mut ir_value_slots,
4566 &mut ir_gradient_slots,
4567 );
4568 let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
4569 &litude_values,
4570 &litude_gradients,
4571 &mut ir_value_slots,
4572 &mut ir_gradient_slots,
4573 );
4574 let lowered_gradient = lowered_program.evaluate_gradient_into(
4575 &litude_values,
4576 &litude_gradients,
4577 &mut lowered_value_slots,
4578 &mut lowered_gradient_slots,
4579 );
4580 for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
4581 assert_relative_eq!(active.re, lowered.re);
4582 assert_relative_eq!(active.im, lowered.im);
4583 }
4584 for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
4585 assert_relative_eq!(lowered.re, ir.re);
4586 assert_relative_eq!(lowered.im, ir.im);
4587 }
4588 }
4589 #[test]
4590 fn test_expression_ir_value_gradient_matches_lowered_runtime() {
4591 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4592 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4593 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4594 .norm_sqr();
4595 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4596 let evaluator = expr.load(&dataset).unwrap();
4597 let resources = evaluator.resources.read();
4598 let parameters = resources
4599 .parameter_map
4600 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4601 .expect("parameters should assemble");
4602 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4603 evaluator.fill_amplitude_values(
4604 &mut amplitude_values,
4605 resources.active_indices(),
4606 ¶meters,
4607 &resources.caches[0],
4608 );
4609 let mut active_mask = vec![false; evaluator.amplitudes.len()];
4610 for &index in resources.active_indices() {
4611 active_mask[index] = true;
4612 }
4613 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4614 .map(|_| DVector::zeros(parameters.len()))
4615 .collect::<Vec<_>>();
4616 evaluator.fill_amplitude_gradients(
4617 &mut amplitude_gradients,
4618 &active_mask,
4619 ¶meters,
4620 &resources.caches[0],
4621 );
4622 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4623 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4624 (0..evaluator.expression_ir().node_count())
4625 .map(|_| DVector::zeros(parameters.len()))
4626 .collect();
4627 let lowered_runtime = evaluator.lowered_runtime();
4628 let lowered_program = lowered_runtime.value_gradient_program();
4629 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4630 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4631 .scratch_slots())
4632 .map(|_| DVector::zeros(parameters.len()))
4633 .collect();
4634
4635 let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
4636 &litude_values,
4637 &litude_gradients,
4638 &mut ir_value_slots,
4639 &mut ir_gradient_slots,
4640 );
4641 let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
4642 &litude_values,
4643 &litude_gradients,
4644 &mut ir_value_slots,
4645 &mut ir_gradient_slots,
4646 );
4647 let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
4648 &litude_values,
4649 &litude_gradients,
4650 &mut lowered_value_slots,
4651 &mut lowered_gradient_slots,
4652 );
4653
4654 assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
4655 assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
4656 for (active, lowered) in active_value_gradient
4657 .1
4658 .iter()
4659 .zip(lowered_value_gradient.1.iter())
4660 {
4661 assert_relative_eq!(active.re, lowered.re);
4662 assert_relative_eq!(active.im, lowered.im);
4663 }
4664 assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
4665 assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
4666 for (lowered, ir) in lowered_value_gradient
4667 .1
4668 .iter()
4669 .zip(ir_value_gradient.1.iter())
4670 {
4671 assert_relative_eq!(lowered.re, ir.re);
4672 assert_relative_eq!(lowered.im, ir.im);
4673 }
4674 }
4675 #[test]
4676 fn test_expression_runtime_diagnostics_reports_lowered_programs() {
4677 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4678 let evaluator = fixture
4679 .expression
4680 .load(&fixture.dataset)
4681 .expect("fixture evaluator should load");
4682
4683 let diagnostics = evaluator.expression_runtime_diagnostics();
4684 assert!(diagnostics.ir_planning_enabled);
4685 assert!(diagnostics.lowered_value_program_present);
4686 assert!(diagnostics.lowered_gradient_program_present);
4687 assert!(diagnostics.lowered_value_gradient_program_present);
4688 assert!(diagnostics.residual_runtime_present);
4689 assert_eq!(
4690 diagnostics.specialization_status,
4691 Some(ExpressionSpecializationStatus {
4692 origin: ExpressionSpecializationOrigin::InitialLoad,
4693 })
4694 );
4695 }
4696 #[test]
4697 fn test_expression_runtime_diagnostics_reports_specialization_origin() {
4698 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4699 let evaluator = fixture
4700 .expression
4701 .load(&fixture.dataset)
4702 .expect("fixture evaluator should load");
4703
4704 assert_eq!(
4705 evaluator
4706 .expression_runtime_diagnostics()
4707 .specialization_status,
4708 Some(ExpressionSpecializationStatus {
4709 origin: ExpressionSpecializationOrigin::InitialLoad,
4710 })
4711 );
4712
4713 evaluator.isolate_many(&["p"]);
4714 assert_eq!(
4715 evaluator
4716 .expression_runtime_diagnostics()
4717 .specialization_status,
4718 Some(ExpressionSpecializationStatus {
4719 origin: ExpressionSpecializationOrigin::CacheMissRebuild,
4720 })
4721 );
4722 }
4723 #[test]
4724 fn test_compiled_expression_display_reports_dag_refs() {
4725 let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4726 let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4727 let term = &a * &b;
4728 let expr = &term + &term;
4729 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4730 let evaluator = expr.load(&dataset).unwrap();
4731
4732 let compiled = evaluator.compiled_expression();
4733 let display = compiled.to_string();
4734
4735 assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4736 assert!(display.contains("#"));
4737 assert!(display.contains("+"));
4738 assert!(display.contains("×"));
4739 assert!(display.contains("a(id=0)"));
4740 assert!(display.contains("b(id=1)"));
4741 assert!(display.contains("(ref)"));
4742 }
4743
4744 #[test]
4745 fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
4746 let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4747 let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4748 let term = &a * &b;
4749 let expr = &term + &term;
4750
4751 let compiled = expr.compiled_expression();
4752 let display = compiled.to_string();
4753
4754 assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4755 assert!(display.contains("#"));
4756 assert!(display.contains("+"));
4757 assert!(display.contains("×"));
4758 assert!(display.contains("a(id=0)"));
4759 assert!(display.contains("b(id=1)"));
4760 assert!(display.contains("(ref)"));
4761 }
4762
4763 #[test]
4764 fn test_compiled_expression_display_uses_current_active_mask() {
4765 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4766 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4767 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4768 let evaluator = expr.load(&dataset).unwrap();
4769 evaluator.deactivate("b");
4770
4771 let compiled = evaluator.compiled_expression().to_string();
4772
4773 assert!(compiled.contains("a(id=0)"));
4774 assert!(!compiled.contains("b(id=1)"));
4775 assert!(!compiled.contains("const 0"));
4776 assert!(!compiled.contains("+"));
4777 }
4778
4779 fn assert_compiled_single_amplitude(expr: &Expression, expected_label: &str) {
4780 let compiled = expr.compiled_expression();
4781 assert_eq!(compiled.nodes().len(), 1);
4782 assert_eq!(compiled.root(), 0);
4783 match &compiled.nodes()[0] {
4784 CompiledExpressionNode::Amplitude { index, name } => {
4785 assert_eq!(*index, 0);
4786 assert_eq!(name, expected_label);
4787 }
4788 node => panic!("expected one amplitude node, got {node:?}"),
4789 }
4790 }
4791
4792 fn assert_compiled_constant(expr: &Expression, expected: Complex64) {
4793 let compiled = expr.compiled_expression();
4794 assert_eq!(compiled.nodes().len(), 1);
4795 assert_eq!(compiled.root(), 0);
4796 match compiled.nodes()[0] {
4797 CompiledExpressionNode::Constant(value) => {
4798 assert_relative_eq!(value.re, expected.re);
4799 assert_relative_eq!(value.im, expected.im);
4800 }
4801 ref node => panic!("expected one constant node, got {node:?}"),
4802 }
4803 }
4804
4805 #[test]
4806 fn test_compiled_expression_simplifies_arithmetic_identities() {
4807 let amp = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4808 let zero = Expression::zero();
4809 let one = Expression::one();
4810
4811 assert_compiled_single_amplitude(&(& + &zero), "a");
4812 assert_compiled_single_amplitude(&(&zero + &), "a");
4813 assert_compiled_single_amplitude(&(& - &zero), "a");
4814 assert_compiled_single_amplitude(&(& * &one), "a");
4815 assert_compiled_single_amplitude(&(&one * &), "a");
4816 assert_compiled_single_amplitude(&(& / &one), "a");
4817 assert_compiled_single_amplitude(&.pow(&one), "a");
4818 assert_compiled_single_amplitude(&.powi(1), "a");
4819 assert_compiled_single_amplitude(&.powf(1.0), "a");
4820
4821 let times_zero = & * &zero;
4822 assert_compiled_constant(×_zero, Complex64::ZERO);
4823 assert!(times_zero.parameters().contains_key("ar"));
4824 assert!(times_zero.parameters().contains_key("ai"));
4825
4826 assert_compiled_constant(&(&zero * &), Complex64::ZERO);
4827 assert_compiled_constant(&(&zero / &Expression::from(2.0)), Complex64::ZERO);
4828 assert_compiled_constant(&.powi(0), Complex64::ONE);
4829 assert_compiled_constant(
4830 &Expression::from(2.0).pow(&Expression::zero()),
4831 Complex64::ONE,
4832 );
4833 assert_compiled_constant(&Expression::from(2.0).powf(0.0), Complex64::ONE);
4834
4835 let unsafe_zero_division = (&zero / &).compiled_expression().to_string();
4836 assert!(unsafe_zero_division.contains("÷"));
4837 assert!(unsafe_zero_division.contains("a(id=0)"));
4838 }
4839
4840 #[test]
4841 fn test_compiled_expression_folds_unary_constant_functions() {
4842 assert_compiled_constant(&Expression::from(0.0).exp(), Complex64::ONE);
4843 assert_compiled_constant(&Expression::from(0.0).sin(), Complex64::ZERO);
4844 assert_compiled_constant(&Expression::from(0.0).cos(), Complex64::ONE);
4845 assert_compiled_constant(&Expression::from(1.0).log(), Complex64::ZERO);
4846 assert_compiled_constant(&Expression::from(4.0).sqrt(), Complex64::new(2.0, 0.0));
4847 assert_compiled_constant(&Expression::from(0.0).cis(), Complex64::ONE);
4848 }
4849
4850 #[test]
4851 fn test_evaluator_expression_reconstructs_expression() {
4852 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4853 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4854 let evaluator = expr.load(&dataset).unwrap();
4855
4856 assert_eq!(
4857 evaluator.expression().compiled_expression(),
4858 expr.compiled_expression()
4859 );
4860 }
4861
4862 #[test]
4863 fn test_active_mask_override_ignores_current_ir_specialization() {
4864 let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
4865 .unwrap()
4866 .norm_sqr();
4867 let dataset = Arc::new(test_dataset());
4868 let evaluator = expr.load(&dataset).unwrap();
4869 let params = vec![2.0];
4870
4871 evaluator.deactivate("amp");
4872 assert_eq!(
4873 evaluator
4874 .evaluate(¶ms)
4875 .expect("evaluation should succeed")[0],
4876 Complex64::new(0.0, 0.0)
4877 );
4878
4879 let overridden = evaluator
4880 .evaluate_local_with_active_mask(¶ms, &[true])
4881 .unwrap();
4882 assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
4883
4884 let overridden_fused = evaluator
4885 .evaluate_with_gradient_local_with_active_mask(¶ms, &[true])
4886 .unwrap();
4887 assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
4888 assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
4889 }
4890 #[test]
4891 fn test_expression_ir_dependence_diagnostics_surface() {
4892 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4893 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4894 .norm_sqr();
4895 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4896 let evaluator = expr.load(&dataset).unwrap();
4897 let annotations = evaluator
4898 .expression_node_dependence_annotations()
4899 .expect("annotations should exist");
4900 assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
4901 assert!(annotations
4902 .iter()
4903 .all(|dependence| *dependence == ExpressionDependence::Mixed));
4904 assert_eq!(
4905 evaluator
4906 .expression_root_dependence()
4907 .expect("root dependence should exist"),
4908 ExpressionDependence::Mixed
4909 );
4910 }
4911 #[test]
4912 fn test_expression_ir_default_dependence_hint_is_mixed() {
4913 let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
4914 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4915 let evaluator = expr.load(&dataset).unwrap();
4916 assert_eq!(
4917 evaluator
4918 .expression_root_dependence()
4919 .expect("root dependence should exist"),
4920 ExpressionDependence::Mixed
4921 );
4922 }
4923 #[test]
4924 fn test_expression_ir_parameter_only_dependence_hint_propagates() {
4925 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
4926 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4927 let evaluator = expr.load(&dataset).unwrap();
4928 assert_eq!(
4929 evaluator
4930 .expression_root_dependence()
4931 .expect("root dependence should exist"),
4932 ExpressionDependence::ParameterOnly
4933 );
4934 }
4935 #[test]
4936 fn test_expression_ir_cache_only_dependence_hint_propagates() {
4937 let expr = CacheOnlyScalar::new("k").unwrap();
4938 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4939 let evaluator = expr.load(&dataset).unwrap();
4940 assert_eq!(
4941 evaluator
4942 .expression_root_dependence()
4943 .expect("root dependence should exist"),
4944 ExpressionDependence::CacheOnly
4945 );
4946 }
4947 #[test]
4948 fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
4949 let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4950 .unwrap()
4951 .imag();
4952 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4953 let evaluator = expr.load(&dataset).unwrap();
4954 let ir = evaluator.expression_ir();
4955
4956 assert!(matches!(
4957 ir.nodes()[ir.root()],
4958 ir::IrNode::Constant(value) if value == Complex64::ZERO
4959 ));
4960 assert_eq!(
4961 evaluator
4962 .evaluate(&[2.5])
4963 .expect("evaluation should succeed")[0],
4964 Complex64::ZERO
4965 );
4966 }
4967 #[test]
4968 fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
4969 let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4970 .unwrap()
4971 .conj();
4972 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4973 let evaluator = expr.load(&dataset).unwrap();
4974 let ir = evaluator.expression_ir();
4975
4976 assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
4977 assert_eq!(
4978 evaluator
4979 .evaluate(&[2.5])
4980 .expect("evaluation should succeed")[0],
4981 Complex64::new(2.5, 0.0)
4982 );
4983 }
4984 #[test]
4985 fn test_expression_ir_dependence_warnings_surface() {
4986 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
4987 + &CacheOnlyScalar::new("k").unwrap();
4988 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4989 let evaluator = expr.load(&dataset).unwrap();
4990 assert!(evaluator
4991 .expression_dependence_warnings()
4992 .expect("warnings should exist")
4993 .iter()
4994 .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
4995 }
4996 #[test]
4997 fn test_expression_ir_normalization_plan_explain_surface() {
4998 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
4999 * &CacheOnlyScalar::new("k").unwrap();
5000 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5001 let evaluator = expr.load(&dataset).unwrap();
5002 let explain = evaluator
5003 .expression_normalization_plan_explain()
5004 .expect("plan should exist");
5005 assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
5006 assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
5007 assert_eq!(
5008 explain.cached_separable_nodes,
5009 explain.separable_mul_candidate_nodes
5010 );
5011 assert!(explain.residual_terms.iter().all(|index| {
5012 !explain
5013 .separable_mul_candidate_nodes
5014 .iter()
5015 .any(|candidate| candidate == index)
5016 }));
5017 }
5018 #[test]
5019 fn test_expression_ir_normalization_execution_sets_surface() {
5020 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5021 * &CacheOnlyScalar::new("k").unwrap();
5022 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5023 let evaluator = expr.load(&dataset).unwrap();
5024 let sets = evaluator
5025 .expression_normalization_execution_sets()
5026 .expect("sets should exist");
5027 assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5028 assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5029 assert!(sets.residual_amplitudes.is_empty());
5030 }
5031 #[test]
5032 fn test_expression_ir_normalization_execution_sets_partial_surface() {
5033 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5034 * &CacheOnlyScalar::new("k").unwrap())
5035 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5036 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5037 let evaluator = expr.load(&dataset).unwrap();
5038 let sets = evaluator
5039 .expression_normalization_execution_sets()
5040 .expect("sets should exist");
5041 assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5042 assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5043 assert_eq!(sets.residual_amplitudes, vec![2]);
5044 }
5045 #[test]
5046 fn test_expression_ir_precomputed_cached_integrals_at_load() {
5047 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5048 * &CacheOnlyScalar::new("k").unwrap();
5049 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5050 let evaluator = expr.load(&dataset).unwrap();
5051 let precomputed = evaluator
5052 .expression_precomputed_cached_integrals()
5053 .expect("integrals should exist");
5054 assert_eq!(precomputed.len(), 1);
5055 let cache_reference = CacheOnlyScalar::new("k_ref")
5056 .unwrap()
5057 .load(&dataset)
5058 .unwrap();
5059 let cache_values = cache_reference
5060 .evaluate_local(&[])
5061 .expect("evaluation should succeed");
5062 let expected_weighted_sum = cache_values
5063 .iter()
5064 .zip(dataset.weights_local().iter())
5065 .fold(Complex64::ZERO, |acc, (value, event)| {
5066 acc + (*value * *event)
5067 });
5068 assert_relative_eq!(
5069 precomputed[0].weighted_cache_sum.re,
5070 expected_weighted_sum.re
5071 );
5072 assert_relative_eq!(
5073 precomputed[0].weighted_cache_sum.im,
5074 expected_weighted_sum.im
5075 );
5076 }
5077 #[test]
5078 fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
5079 let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5080 * &CacheOnlyScalar::new("k").unwrap();
5081 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5082 let evaluator = expr.load(&dataset).unwrap();
5083 assert!(evaluator
5084 .expression_precomputed_cached_integrals()
5085 .expect("integrals should exist")
5086 .is_empty());
5087 }
5088 #[test]
5089 fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
5090 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5091 * &CacheOnlyScalar::new("k").unwrap();
5092 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5093 let evaluator = expr.load(&dataset).unwrap();
5094 assert_eq!(
5095 evaluator
5096 .expression_precomputed_cached_integrals()
5097 .expect("integrals should exist")
5098 .len(),
5099 1
5100 );
5101
5102 evaluator.isolate_many(&["p"]);
5103 assert!(evaluator
5104 .expression_precomputed_cached_integrals()
5105 .expect("integrals should exist")
5106 .is_empty());
5107 }
/// Mutating the evaluator's dataset must invalidate and recompute the cached integrals.
#[test]
fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
    let separable = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let shared = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let mut evaluator = separable.load(&shared).unwrap();
    // Release the test's handle so `Arc::get_mut` below can obtain exclusive access.
    drop(shared);

    let before = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(before.len(), 1);

    let owned_dataset = Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test");
    owned_dataset.clear_events_local();

    // With no events left, the single weighted cache sum collapses to zero.
    let after = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(after.len(), 1);
    assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
    assert_ne!(before[0].weighted_cache_sum, after[0].weighted_cache_sum);
}
/// The precomputed gradient term of a separable `p * k` product equals the weighted cache sum.
#[test]
fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
    let separable = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let two_events = Arc::new(Dataset::new(vec![
        Arc::new(test_event()),
        Arc::new(test_event()),
    ]));
    let evaluator = separable.load(&two_events).unwrap();

    let integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(integrals.len(), 1);

    let terms = evaluator
        .expression_precomputed_cached_integral_gradient_terms(&[1.25])
        .expect("evaluation should succeed");
    assert_eq!(terms.len(), 1);

    // Assuming the parameter-only factor evaluates to `p`, d(p * k)/dp = k, so the weighted
    // gradient must match the weighted cache sum.
    let gradient = &terms[0].weighted_gradient;
    let expected = integrals[0].weighted_cache_sum;
    assert_eq!(gradient.len(), 1);
    assert_relative_eq!(gradient[0].re, expected.re, epsilon = 1e-6);
    assert_relative_eq!(gradient[0].im, expected.im, epsilon = 1e-6);
}
/// Non-separable products produce no precomputed cached-integral gradient terms.
#[test]
fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
    let non_separable = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let single_event = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let evaluator = non_separable.load(&single_event).unwrap();

    let gradient_terms = evaluator
        .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
        .expect("evaluation should succeed");
    assert!(gradient_terms.is_empty());
}
/// Checks that the lowered cached-factor programs reproduce the IR interpreter's cached
/// weighted value and gradient sums for a mixed separable (`p * k`) + residual (`m`) expression.
#[test]
fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
    let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap())
        + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let resources = evaluator.resources.read();
    let state = evaluator
        .ensure_cached_integral_cache_state(&resources)
        .expect("state should be available");
    let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
    let parameters = resources
        .parameter_map
        .assemble(&[0.55, 0.2, -0.15])
        .expect("parameters should assemble");

    // Value path: start from all-zero amplitude values, fill only the cached-parameter
    // amplitudes, then compare the IR and lowered evaluations on identical inputs.
    let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
    evaluator.fill_amplitude_values(
        &mut amplitude_values,
        &state.execution_sets.cached_parameter_amplitudes,
        &parameters,
        &resources.caches[0],
    );
    let cached_value_ir =
        evaluator.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values);
    let cached_value_lowered = evaluator
        .evaluate_cached_weighted_value_sum_lowered(
            &state,
            lowered_artifacts.as_ref(),
            &amplitude_values,
        )
        .expect("cached value lowering should succeed");
    assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);

    // Gradient path: mask in exactly the cached-parameter amplitudes.
    let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
    for &index in &state.execution_sets.cached_parameter_amplitudes {
        cached_parameter_mask[index] = true;
    }
    let mut amplitude_gradients = (0..evaluator.amplitudes.len())
        .map(|_| DVector::zeros(parameters.len()))
        .collect::<Vec<_>>();
    evaluator.fill_amplitude_gradients(
        &mut amplitude_gradients,
        &cached_parameter_mask,
        &parameters,
        &resources.caches[0],
    );
    let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
        &state,
        &amplitude_values,
        &amplitude_gradients,
        parameters.len(),
    );
    let cached_gradient_lowered = evaluator
        .evaluate_cached_weighted_gradient_sum_lowered(
            &state,
            lowered_artifacts.as_ref(),
            &amplitude_values,
            &amplitude_gradients,
            parameters.len(),
        )
        .expect("cached gradient lowering should succeed");

    // `zip` stops at the shorter side, so check lengths first to avoid a partial comparison.
    assert_eq!(cached_gradient_lowered.len(), cached_gradient_ir.len());
    for (lowered, ir) in cached_gradient_lowered
        .iter()
        .zip(cached_gradient_ir.iter())
    {
        assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
    }
}
/// Checks that the lowered residual runtime (value and gradient programs) reproduces the IR
/// residual evaluation when only the residual amplitudes are populated.
#[test]
fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
    let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap())
        + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let resources = evaluator.resources.read();
    let state = evaluator
        .ensure_cached_integral_cache_state(&resources)
        .expect("state should be available");
    let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
    let parameters = resources
        .parameter_map
        .assemble(&[0.55, 0.2, -0.15])
        .expect("parameters should assemble");

    // Value path: start from all-zero amplitude values and fill only the residual amplitudes.
    let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
    evaluator.fill_amplitude_values(
        &mut amplitude_values,
        &state.execution_sets.residual_amplitudes,
        &parameters,
        &resources.caches[0],
    );
    let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &amplitude_values);
    let residual_program = lowered_artifacts
        .residual_runtime
        .as_ref()
        .map(|runtime| runtime.value_program())
        .expect("residual value lowering should succeed");
    let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
    let residual_value_lowered =
        residual_program.evaluate_into(&amplitude_values, &mut value_slots);
    assert_relative_eq!(
        residual_value_lowered.re,
        residual_value_ir.re,
        epsilon = 1e-12
    );
    assert_relative_eq!(
        residual_value_lowered.im,
        residual_value_ir.im,
        epsilon = 1e-12
    );

    // Gradient path: mask in exactly the residual amplitudes.
    let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
    for &index in &state.execution_sets.residual_amplitudes {
        residual_active_mask[index] = true;
    }
    let mut amplitude_gradients = (0..evaluator.amplitudes.len())
        .map(|_| DVector::zeros(parameters.len()))
        .collect::<Vec<_>>();
    evaluator.fill_amplitude_gradients(
        &mut amplitude_gradients,
        &residual_active_mask,
        &parameters,
        &resources.caches[0],
    );
    let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
        &state,
        &amplitude_values,
        &amplitude_gradients,
        parameters.len(),
    );

    let program = lowered_artifacts
        .residual_runtime
        .as_ref()
        .map(|runtime| runtime.gradient_program())
        .expect("gradient lowering should succeed");
    let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
    // The flat gradient scratch buffer is sized scratch_slots × parameter count.
    let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
    let residual_gradient_lowered = program.evaluate_gradient_into_flat(
        &amplitude_values,
        &amplitude_gradients,
        &mut value_slots,
        &mut gradient_slots,
        parameters.len(),
    );

    // `zip` stops at the shorter side, so check lengths first to avoid a partial comparison.
    assert_eq!(residual_gradient_lowered.len(), residual_gradient_ir.len());
    for (lowered, ir) in residual_gradient_lowered
        .iter()
        .zip(residual_gradient_ir.iter())
    {
        assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
        assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
    }
}
/// A dataset change forces a new specialization, but the lowered artifacts are reused.
#[test]
fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
    let p_times_k = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = p_times_k + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let shared = Arc::new(test_dataset());
    let mut evaluator = expr.load(&shared).unwrap();
    // Release the test's handle so `Arc::get_mut` below can obtain exclusive access.
    drop(shared);

    // A freshly loaded evaluator holds one specialization and one lowered-artifact entry.
    assert_eq!(evaluator.specialization_cache_len(), 1);
    assert_eq!(evaluator.lowered_artifact_cache_len(), 1);

    // Start the metric counters from zero so only the post-change work is observed.
    evaluator.reset_expression_compile_metrics();
    evaluator.reset_expression_specialization_metrics();

    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();

    let integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(integrals.len(), 1);
    assert_eq!(integrals[0].weighted_cache_sum, Complex64::ZERO);

    // A second specialization was built, yet the lowered-artifact cache did not grow.
    assert_eq!(evaluator.specialization_cache_len(), 2);
    assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
    let expected_specialization = ExpressionSpecializationMetrics {
        cache_hits: 0,
        cache_misses: 1,
    };
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        expected_specialization
    );

    let metrics = evaluator.expression_compile_metrics();
    assert_eq!(metrics.specialization_cache_hits, 0);
    assert_eq!(metrics.specialization_cache_misses, 1);
    assert_eq!(metrics.specialization_lowering_cache_hits, 1);
    assert_eq!(metrics.specialization_lowering_cache_misses, 0);
    assert!(metrics.specialization_ir_compile_nanos > 0);
    assert!(metrics.specialization_cached_integrals_nanos > 0);
    assert_eq!(metrics.specialization_lowering_nanos, 0);
}
5371
/// The fast weighted gradient sum must agree with a per-event weighted reduction of
/// `evaluate_gradient_local` for an expression mixing separable and residual terms.
#[test]
fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    // Two cached integrals are expected for this expression.
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .len(),
        2
    );
    let params = vec![0.2, -0.3, 1.1, -0.7];
    // Baseline: sum of real parts of per-event gradients, each scaled by its event weight.
    let expected = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::<f64>::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let actual = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    // `zip` stops at the shorter side, so check lengths first to avoid a partial comparison.
    assert_eq!(actual.len(), expected.len());
    for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
        assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
    }
}
5410
/// The fast weighted value sum must agree with a per-event weighted reduction of
/// `evaluate_local` for an expression mixing separable and residual terms.
#[test]
fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    // Two cached integrals are expected for this expression.
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .len(),
        2
    );
    let params = vec![0.2, -0.3, 1.1, -0.7];
    // Baseline: sum of real parts of per-event values, each scaled by its event weight.
    let expected = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    let actual = evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluation should succeed");
    assert_relative_eq!(actual, expected, epsilon = 1e-10);
}
5441
/// Pins the weighted value and gradient sums to hand-computed reference values on a
/// three-event dataset, including a negatively weighted event.
#[test]
fn test_weighted_sums_match_hardcoded_reference_values() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);

    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
                aux: vec![],
                weight: 0.5,
            }),
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
                aux: vec![],
                weight: -1.25,
            }),
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
                aux: vec![],
                weight: 2.0,
            }),
        ],
        metadata,
    ));
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![0.7, -1.1, 0.9, -0.4];

    let weighted_value_sum = evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluation should succeed");
    assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);

    let weighted_gradient_sum = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    // Pin the parameter order so the reference gradient components line up.
    let free_parameters = evaluator
        .parameters()
        .free()
        .names()
        .into_iter()
        .map(|name| name.to_string())
        .collect::<Vec<_>>();
    assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
    let expected_gradient = [43.925, 7.25, 28.525, 0.0];
    assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
    for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
        assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
    }
}
/// A subtracted separable term is cached with coefficient -1, and the weighted gradient sum
/// still matches the per-event baseline.
#[test]
fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
    let expr = Expression::one()
        - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap());
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .len(),
        1
    );
    // The subtraction is carried on the cached term as a signed coefficient.
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")[0]
            .coefficient,
        -1
    );
    let params = vec![0.75];
    // Baseline: sum of real parts of per-event gradients, each scaled by its event weight.
    let expected = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::<f64>::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let actual = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    // `zip` stops at the shorter side, so check lengths first to avoid a partial comparison.
    assert_eq!(actual.len(), expected.len());
    for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
        assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
    }
}
/// A subtracted separable term is cached with coefficient -1, and the weighted value sum
/// still matches the per-event baseline.
#[test]
fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
    let expr = Expression::one()
        - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap());
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .len(),
        1
    );
    // The subtraction is carried on the cached term as a signed coefficient.
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")[0]
            .coefficient,
        -1
    );
    let params = vec![0.75];
    // Baseline: sum of real parts of per-event values, each scaled by its event weight.
    let expected = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    let actual = evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluation should succeed");
    assert_relative_eq!(actual, expected, epsilon = 1e-10);
}
/// The normalization-plan diagnostics and root dependence class track activation changes.
#[test]
fn test_expression_ir_diagnostics_follow_activation_changes() {
    let p_times_k = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = p_times_k + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let single_event = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let evaluator = expr.load(&single_event).unwrap();

    let root_dependence = || {
        evaluator
            .expression_root_dependence()
            .expect("root dependence should exist")
    };

    // All active: one cached separable node and a mixed root.
    let all_active = evaluator
        .expression_normalization_plan_explain()
        .expect("plan should exist");
    assert_eq!(all_active.cached_separable_nodes.len(), 1);
    assert_eq!(root_dependence(), ExpressionDependence::Mixed);

    // Only `p` active: nothing cached, and the root depends on parameters alone.
    evaluator.isolate_many(&["p"]);
    let param_only = evaluator
        .expression_normalization_plan_explain()
        .expect("plan should exist");
    assert!(param_only.cached_separable_nodes.is_empty());
    assert_eq!(root_dependence(), ExpressionDependence::ParameterOnly);
}
/// Revisiting a previously seen activation mask restores its specialization from cache
/// instead of recompiling, and the compile metrics reflect hits and misses accordingly.
#[test]
fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
    let p_times_k = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = p_times_k + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let single_event = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let evaluator = expr.load(&single_event).unwrap();

    // State right after loading.
    let initial = evaluator.expression_compile_metrics();
    assert!(initial.initial_ir_compile_nanos > 0);
    assert!(initial.initial_cached_integrals_nanos > 0);
    assert!(initial.initial_lowering_nanos > 0);
    assert_eq!(initial.specialization_cache_hits, 0);
    assert_eq!(initial.specialization_cache_misses, 0);
    assert_eq!(initial.specialization_lowering_cache_hits, 0);
    assert_eq!(initial.specialization_lowering_cache_misses, 1);

    assert_eq!(evaluator.specialization_cache_len(), 1);
    assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
    let after_load_specialization = ExpressionSpecializationMetrics {
        cache_hits: 0,
        cache_misses: 1,
    };
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        after_load_specialization
    );
    let all_active_cached_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");

    // First visit to the `p`-only mask: a new specialization is compiled (cache miss).
    evaluator.isolate_many(&["p"]);
    assert_eq!(evaluator.specialization_cache_len(), 2);
    let after_miss_specialization = ExpressionSpecializationMetrics {
        cache_hits: 0,
        cache_misses: 2,
    };
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        after_miss_specialization
    );
    let after_miss = evaluator.expression_compile_metrics();
    assert_eq!(after_miss.specialization_cache_hits, 0);
    assert_eq!(after_miss.specialization_cache_misses, 1);
    assert_eq!(after_miss.specialization_lowering_cache_hits, 0);
    assert_eq!(after_miss.specialization_lowering_cache_misses, 2);
    assert!(after_miss.specialization_ir_compile_nanos > 0);
    assert!(after_miss.specialization_cached_integrals_nanos > 0);
    assert!(after_miss.specialization_lowering_nanos > 0);
    let isolated_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert!(isolated_integrals.is_empty());

    // Returning to the original mask is served from the specialization cache (cache hit).
    evaluator.activate_many(&["k", "m"]);
    assert_eq!(evaluator.specialization_cache_len(), 2);
    let after_hit_specialization = ExpressionSpecializationMetrics {
        cache_hits: 1,
        cache_misses: 2,
    };
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        after_hit_specialization
    );
    let restored_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(restored_integrals, all_active_cached_integrals);
    let after_hit = evaluator.expression_compile_metrics();
    assert_eq!(after_hit.specialization_cache_hits, 1);
    assert_eq!(after_hit.specialization_cache_misses, 1);
    assert_eq!(after_hit.specialization_lowering_cache_hits, 0);
    assert_eq!(after_hit.specialization_lowering_cache_misses, 2);
    assert!(after_hit.specialization_cache_restore_nanos > 0);
}
5695
/// After changing the active amplitude set, both the weighted value sum and the weighted
/// gradient sum must still agree with their per-event weighted baselines.
#[test]
fn test_weighted_sums_match_baseline_after_activation_changes() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![0.2, -0.3, 1.1, -0.7];

    // Leaves `p2` and `c2` out of the active set.
    evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);

    let expected_value = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    assert_relative_eq!(
        evaluator
            .evaluate_weighted_value_sum_local(&params)
            .expect("evaluation should succeed"),
        expected_value,
        epsilon = 1e-10
    );

    // The test name covers both sums, so also check the gradient against its baseline.
    let expected_gradient = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::<f64>::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let actual_gradient = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    assert_eq!(actual_gradient.len(), expected_gradient.len());
    for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
        assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
    }
}
5725
/// `evaluate_local` must work from the per-event caches alone, even after the dataset's
/// event rows have been cleared.
#[test]
fn test_evaluate_local_does_not_depend_on_dataset_rows() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let mut event1 = test_event();
    event1.p4s[0].t = 7.5;
    let mut event2 = test_event();
    event2.p4s[0].t = 8.25;
    let mut event3 = test_event();
    event3.p4s[0].t = 9.0;
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
        Arc::new(DatasetMetadata::default()),
    ));
    let mut evaluator = expr.load(&dataset).unwrap();
    // Release the test's handle so `Arc::get_mut` below can obtain exclusive access.
    drop(dataset);
    let expected_len = evaluator.resources.read().caches.len();
    // Guard against a vacuous 0 == 0 pass: one cache per loaded event must exist.
    assert_eq!(expected_len, 3);
    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();
    let cached = evaluator
        .evaluate_local(&[1.25, -0.75])
        .expect("evaluation should succeed");
    assert_eq!(cached.len(), expected_len);
}
5752
/// `evaluate_gradient_local` must work from the per-event caches alone, even after the
/// dataset's event rows have been cleared.
#[test]
fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let mut event1 = test_event();
    event1.p4s[0].t = 7.5;
    let mut event2 = test_event();
    event2.p4s[0].t = 8.25;
    let mut event3 = test_event();
    event3.p4s[0].t = 9.0;
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
        Arc::new(DatasetMetadata::default()),
    ));
    let mut evaluator = expr.load(&dataset).unwrap();
    // Release the test's handle so `Arc::get_mut` below can obtain exclusive access.
    drop(dataset);
    let expected_len = evaluator.resources.read().caches.len();
    // Guard against a vacuous 0 == 0 pass: one cache per loaded event must exist.
    assert_eq!(expected_len, 3);
    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();
    let cached = evaluator
        .evaluate_gradient_local(&[1.25, -0.75])
        .expect("evaluation should succeed");
    assert_eq!(cached.len(), expected_len);
}
5779
/// The fused value+gradient path must reproduce the separate `evaluate_local` and
/// `evaluate_gradient_local` results event by event.
#[test]
fn test_evaluate_with_gradient_local_matches_separate_paths() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let dataset = Arc::new(Dataset::new(vec![
        Arc::new(test_event()),
        Arc::new(test_event()),
        Arc::new(test_event()),
    ]));
    let evaluator = expr.load(&dataset).unwrap();
    let params = [1.25, -0.75];
    let values = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed");
    let gradients = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed");
    let fused = evaluator
        .evaluate_with_gradient_local(&params)
        .expect("evaluation should succeed");
    // Guard against comparing empty outputs: one result per event is expected.
    assert_eq!(values.len(), dataset.n_events());
    assert_eq!(fused.len(), values.len());
    assert_eq!(fused.len(), gradients.len());
    for ((value_gradient, value), gradient) in
        fused.iter().zip(values.iter()).zip(gradients.iter())
    {
        let (fused_value, fused_gradient) = value_gradient;
        assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
        assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
        assert_eq!(fused_gradient.len(), gradient.len());
        for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
            assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
            assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
        }
    }
}
5816
/// The fused batched value+gradient path must reproduce the separate batched value and
/// gradient paths for the same subset of event indices.
#[test]
fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let dataset = Arc::new(Dataset::new(vec![
        Arc::new(test_event()),
        Arc::new(test_event()),
        Arc::new(test_event()),
        Arc::new(test_event()),
    ]));
    let evaluator = expr.load(&dataset).unwrap();
    let params = [0.5, -1.25];
    // A non-contiguous subset of the four events.
    let indices = vec![0, 2, 3];
    let values = evaluator
        .evaluate_batch_local(&params, &indices)
        .expect("evaluation should succeed");
    let gradients = evaluator
        .evaluate_gradient_batch_local(&params, &indices)
        .expect("evaluation should succeed");
    let fused = evaluator
        .evaluate_with_gradient_batch_local(&params, &indices)
        .expect("evaluation should succeed");
    assert_eq!(fused.len(), values.len());
    assert_eq!(fused.len(), gradients.len());
    for ((value_gradient, value), gradient) in
        fused.iter().zip(values.iter()).zip(gradients.iter())
    {
        let (fused_value, fused_gradient) = value_gradient;
        assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
        assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
        assert_eq!(fused_gradient.len(), gradient.len());
        for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
            assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
            assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
        }
    }
}
5855
/// `precompute_all` must write a positive beam-energy scalar into every per-event cache.
#[test]
fn test_precompute_all_columnar_populates_cache() {
    let mut event1 = test_event();
    event1.p4s[0].t = 7.5;
    let mut event2 = test_event();
    event2.p4s[0].t = 8.25;
    let mut event3 = test_event();
    event3.p4s[0].t = 9.0;
    let dataset = Dataset::new_with_metadata(
        vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
        Arc::new(DatasetMetadata::default()),
    );
    let mut amplitude = TestAmplitude {
        tags: Tags::new(["test"]),
        re: parameter!("real"),
        pid_re: ParameterID::default(),
        im: parameter!("imag"),
        pid_im: ParameterID::default(),
        beam_energy: Default::default(),
    };
    let mut resources = Resources::default();
    amplitude
        .register(&mut resources)
        .expect("test amplitude should register");
    resources.reserve_cache(dataset.n_events());
    amplitude.precompute_all(&dataset, &mut resources);
    // Without this check the loop below would pass vacuously on an empty cache list.
    assert_eq!(resources.caches.len(), dataset.n_events());
    for cache in &resources.caches {
        assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
    }
}
5886
/// Under MPI, `load` must size the cache to this rank's local event count rather than the
/// global one.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_load_reserves_local_cache_size_in_mpi() {
    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    assert!(get_world().is_some(), "MPI world should be initialized");

    let constant = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .expect("constant amplitude should construct");
    let events: Vec<_> = (0..4).map(|_| Arc::new(test_event())).collect();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = constant.load(&dataset).expect("evaluator should load");
    let local_events = dataset.n_events_local();
    let cache_len = evaluator.resources.read().caches.len();

    assert_eq!(
        cache_len, local_events,
        "cache length must match local event count under MPI"
    );
    finalize_mpi();
}
5921
/// Under MPI, precomputed cached integrals hold only this rank's weighted sum; the global
/// total is only obtained after an explicit reduction.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
    use mpi::{collective::SystemOperation, topology::Communicator, traits::*};

    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    let separable = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    // Eight events with energy `i` and weight `0.5 * i` for i = 1..=8 (all exact in f64).
    let events = (1..=8u32)
        .map(|step| {
            let step = f64::from(step);
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, step)],
                aux: vec![],
                weight: 0.5 * step,
            })
        })
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = separable.load(&dataset).expect("evaluator should load");
    let cached_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(cached_integrals.len(), 1);

    // Rank-local reference: Σ weight · energy over the events this rank holds.
    let local_expected =
        dataset
            .weights_local()
            .iter()
            .enumerate()
            .fold(0.0, |acc, (index, weight)| {
                let event = dataset.event_local(index).expect("event should exist");
                acc + *weight * event.p4_at(0).e()
            });
    let cached_local = cached_integrals[0].weighted_cache_sum;
    assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
    assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);

    let weighted_value_sum = evaluator
        .evaluate_weighted_value_sum_local(&[2.0])
        .expect("evaluate should succeed");
    assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);

    let mut global_expected = 0.0;
    world.all_reduce_into(
        &local_expected,
        &mut global_expected,
        SystemOperation::sum(),
    );
    // With more than one rank, the local cached sum must differ from the reduced total.
    if world.size() > 1 {
        assert!(
            (cached_local.re - global_expected).abs() > 1e-12,
            "cached integral should remain rank-local before MPI reduction"
        );
    }
    finalize_mpi();
}
6018
// Compares the MPI-aware weighted value and gradient sums of a mixed expression (products of
// parameter-only scalars, cache-only scalars and a parametric test amplitude) against an
// event-wise baseline: each rank weights the real parts of its own per-event results, sums
// them locally, and the partial sums are all-reduced over the world communicator.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
    use mpi::{collective::SystemOperation, traits::*};

    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);

    // (energy, weight) for each event; the negative weights exercise signed accumulation.
    let events = [
        (1.0, 0.5),
        (2.0, -1.25),
        (3.0, 0.75),
        (4.0, 1.5),
        (5.0, 2.25),
        (6.0, -0.5),
        (7.0, 3.5),
        (8.0, 1.25),
    ]
    .iter()
    .map(|&(energy, weight)| {
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, energy)],
            aux: vec![],
            weight,
        })
    })
    .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");
    let params = vec![0.2, -0.3, 1.1, -0.7];

    // Baseline value: weighted sum of each local event's real part, reduced across ranks.
    let local_expected_value = evaluator
        .evaluate_local(&params)
        .expect("evaluate should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    let mut global_expected_value = 0.0;
    world.all_reduce_into(
        &local_expected_value,
        &mut global_expected_value,
        SystemOperation::sum(),
    );
    let mpi_value = evaluator
        .evaluate_weighted_value_sum_mpi(&params, &world)
        .expect("evaluate should succeed");
    assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);

    // Baseline gradient: weighted sum of the real parts of each local event's gradient.
    let local_expected_gradient = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluate should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
    world.all_reduce_into(
        local_expected_gradient.as_slice(),
        &mut global_expected_gradient,
        SystemOperation::sum(),
    );
    let mpi_gradient = evaluator
        .evaluate_weighted_gradient_sum_mpi(&params, &world)
        .expect("evaluate should succeed");
    // `zip` stops at the shorter side, so compare lengths before comparing entries.
    assert_eq!(mpi_gradient.len(), global_expected_gradient.len());
    for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
        assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
    }

    finalize_mpi();
}
6129
/// Local evaluation of an amplitude with no free parameters yields one value and one
/// gradient for a single-event dataset.
#[test]
fn test_evaluate_local_succeeds_for_constant_amplitude() {
    let constant = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .unwrap();
    let single_event_dataset = Arc::new(Dataset::new_with_metadata(
        vec![Arc::new(test_event())],
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = constant.load(&single_event_dataset).unwrap();

    // Both parameters are fixed, so an empty parameter slice is supplied.
    let value_count = evaluator
        .evaluate_local(&[])
        .expect("evaluation should succeed")
        .len();
    assert_eq!(value_count, 1);

    let gradient_count = evaluator
        .evaluate_gradient_local(&[])
        .expect("evaluation should succeed")
        .len();
    assert_eq!(gradient_count, 1);
}
6152
/// A `ComplexScalar` built from two fixed parameters evaluates to exactly that constant.
#[test]
fn test_constant_amplitude() {
    let constant = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .unwrap();
    let single_event = vec![Arc::new(test_event())];
    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(single_event, metadata));

    let evaluator = constant.load(&dataset).unwrap();
    let first_value = evaluator.evaluate(&[]).expect("evaluation should succeed")[0];
    assert_eq!(first_value, Complex64::new(2.0, 3.0));
}
6169
/// A `ComplexScalar` with two free parameters evaluates to `re + i*im`, taking the real
/// part from the first supplied parameter and the imaginary part from the second.
#[test]
fn test_parametric_amplitude() {
    let amplitude = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("test_param_im"),
    )
    .unwrap();
    let dataset = Arc::new(test_dataset());
    let evaluator = amplitude.load(&dataset).unwrap();

    let (re, im) = (2.0, 3.0);
    let values = evaluator
        .evaluate(&[re, im])
        .expect("evaluation should succeed");
    assert_eq!(values[0], Complex64::new(2.0, 3.0));
}
6185
/// Checks the arithmetic operators (`+`, `-`, `*`, `/`, unary `-`) and the `real`, `imag`,
/// `conj` and `norm_sqr` transforms on constant complex scalars, both applied directly and
/// applied to composite expressions.
///
/// The operands are `expr1 = 2`, `expr2 = i` and `expr3 = 3 + 4i`.
#[test]
fn test_expression_operations() {
    let expr1 = ComplexScalar::new(
        "const1",
        parameter!("const1_re", 2.0),
        parameter!("const1_im", 0.0),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "const2",
        parameter!("const2_re", 0.0),
        parameter!("const2_im", 1.0),
    )
    .unwrap();
    let expr3 = ComplexScalar::new(
        "const3",
        parameter!("const3_re", 3.0),
        parameter!("const3_im", 4.0),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());

    // Operators applied directly to the scalars.
    let expr_add = &expr1 + &expr2;
    let expr_sub = &expr1 - &expr2;
    let expr_mul = &expr1 * &expr2;
    let expr_div = &expr1 / &expr3;
    let expr_neg = -&expr3;

    // The same operators applied to composite expressions.
    let expr_add2 = &expr_add + &expr_mul;
    let expr_sub2 = &expr_add - &expr_mul;
    let expr_mul2 = &expr_add * &expr_mul;
    let expr_div2 = &expr_add / &expr_add2;
    let expr_neg2 = -&expr_mul2;

    // Component and modulus transforms on a leaf and on a composite.
    let expr_real = expr3.real();
    let expr_mul2_real = expr_mul2.real();
    let expr_imag = expr3.imag();
    let expr_mul2_imag = expr_mul2.imag();
    let expr_conj = expr3.conj();
    let expr_mul2_conj = expr_mul2.conj();
    let expr_norm = expr1.norm_sqr();
    let expr_mul2_norm = expr_mul2.norm_sqr();

    let cases = [
        (&expr_add, Complex64::new(2.0, 1.0)),
        (&expr_sub, Complex64::new(2.0, -1.0)),
        (&expr_mul, Complex64::new(0.0, 2.0)),
        (&expr_div, Complex64::new(6.0 / 25.0, -8.0 / 25.0)),
        (&expr_neg, Complex64::new(-3.0, -4.0)),
        (&expr_add2, Complex64::new(2.0, 3.0)),
        (&expr_sub2, Complex64::new(2.0, -1.0)),
        (&expr_mul2, Complex64::new(-2.0, 4.0)),
        (&expr_div2, Complex64::new(7.0 / 13.0, -4.0 / 13.0)),
        (&expr_neg2, Complex64::new(2.0, -4.0)),
        (&expr_real, Complex64::new(3.0, 0.0)),
        (&expr_mul2_real, Complex64::new(-2.0, 0.0)),
        (&expr_imag, Complex64::new(4.0, 0.0)),
        (&expr_mul2_imag, Complex64::new(4.0, 0.0)),
        (&expr_conj, Complex64::new(3.0, -4.0)),
        (&expr_mul2_conj, Complex64::new(-2.0, -4.0)),
        (&expr_norm, Complex64::new(4.0, 0.0)),
        (&expr_mul2_norm, Complex64::new(20.0, 0.0)),
    ];

    // Each expression is loaded fresh and compared exactly on the first event.
    for (expr, expected) in cases.iter() {
        let value = expr
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed")[0];
        assert_eq!(value, *expected);
    }
}
6371
/// Deactivating, isolating and re-activating tagged amplitudes changes which terms of a
/// sum contribute to the evaluated value (`const1 = 1`, `const2 = 2`).
#[test]
fn test_amplitude_activation() {
    let const1 = ComplexScalar::new(
        "const1",
        parameter!("const1_re_act", 1.0),
        parameter!("const1_im_act", 0.0),
    )
    .unwrap();
    let const2 = ComplexScalar::new(
        "const2",
        parameter!("const2_re_act", 2.0),
        parameter!("const2_im_act", 0.0),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());
    let sum = &const1 + &const2;
    let evaluator = sum.load(&dataset).unwrap();
    // Re-evaluates the sum on the first event under the current activation state.
    let first_value = || evaluator.evaluate(&[]).expect("evaluation should succeed")[0];

    // Both terms active.
    assert_eq!(first_value(), Complex64::new(3.0, 0.0));

    // Only `const2` remains.
    evaluator.deactivate_strict("const1").unwrap();
    assert_eq!(first_value(), Complex64::new(2.0, 0.0));

    // Only `const1` remains.
    evaluator.isolate_strict("const1").unwrap();
    assert_eq!(first_value(), Complex64::new(1.0, 0.0));

    // Everything is restored.
    evaluator.activate_all();
    assert_eq!(first_value(), Complex64::new(3.0, 0.0));
}
6410
/// Checks analytic gradients of sums, differences, products, quotients, negation and the
/// `real`/`imag`/`conj`/`norm_sqr` transforms of two parametric complex scalars.
///
/// With `z1 = 2 + 3i` and `z2 = 4 + 5i`, the parameters are ordered
/// `[re(z1), im(z1), re(z2), im(z2)]`. Each expected entry is the `(re, im)` pair of the
/// derivative with respect to that parameter on the first event.
#[test]
fn test_gradient() {
    let expr1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());
    let params = vec![2.0, 3.0, 4.0, 5.0];

    let sum = &expr1 + &expr2;
    let difference = &expr1 - &expr2;
    let product = &expr1 * &expr2;
    let quotient = &expr1 / &expr2;
    let negated_product = -(&expr1 * &expr2);
    let product_real = (&expr1 * &expr2).real();
    let product_imag = (&expr1 * &expr2).imag();
    let product_conj = (&expr1 * &expr2).conj();
    let product_norm_sqr = (&expr1 * &expr2).norm_sqr();

    let cases = [
        // d(z1 + z2): 1 and i for each operand.
        (&sum, [(1.0, 0.0), (0.0, 1.0), (1.0, 0.0), (0.0, 1.0)]),
        // d(z1 - z2): the second operand's partials flip sign.
        (&difference, [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0), (0.0, -1.0)]),
        // d(z1 * z2): z2, i*z2, z1, i*z1.
        (&product, [(4.0, 5.0), (-5.0, 4.0), (2.0, 3.0), (-3.0, 2.0)]),
        // d(z1 / z2): 1/z2, i/z2, -z1/z2^2, -i*z1/z2^2 (|z2|^2 = 41, |z2^2|^2 = 1681).
        (
            &quotient,
            [
                (4.0 / 41.0, -5.0 / 41.0),
                (5.0 / 41.0, 4.0 / 41.0),
                (-102.0 / 1681.0, 107.0 / 1681.0),
                (-107.0 / 1681.0, -102.0 / 1681.0),
            ],
        ),
        // d(-(z1 * z2)): negated product partials.
        (&negated_product, [(-4.0, -5.0), (5.0, -4.0), (-2.0, -3.0), (3.0, -2.0)]),
        // Re(z1 * z2): real parts of the product partials.
        (&product_real, [(4.0, 0.0), (-5.0, 0.0), (2.0, 0.0), (-3.0, 0.0)]),
        // Im(z1 * z2): imaginary parts of the product partials.
        (&product_imag, [(5.0, 0.0), (4.0, 0.0), (3.0, 0.0), (2.0, 0.0)]),
        // conj(z1 * z2): conjugated product partials.
        (&product_conj, [(4.0, -5.0), (-5.0, -4.0), (2.0, -3.0), (-3.0, -2.0)]),
        // |z1 z2|^2 = |z1|^2 |z2|^2: 2*re1*41, 2*im1*41, 2*re2*13, 2*im2*13.
        (&product_norm_sqr, [(164.0, 0.0), (246.0, 0.0), (104.0, 0.0), (130.0, 0.0)]),
    ];

    for (expr, expected) in cases.iter() {
        let gradient = expr
            .load(&dataset)
            .unwrap()
            .evaluate_gradient(&params)
            .expect("evaluation should succeed");
        for (index, &(re, im)) in expected.iter().enumerate() {
            assert_relative_eq!(gradient[0][index].re, re);
            assert_relative_eq!(gradient[0][index].im, im);
        }
    }
}
6573
/// Verifies the analytic gradient of a sum of elementary functions (`sqrt`, `exp`, `powi`,
/// `powf`, `sin * cos`, `log`, `cis`, and `pow` with an expression-valued exponent) against a
/// central finite-difference estimate, separately for every parameter.
#[test]
fn test_expression_function_gradients() {
    let expr1 = ComplexScalar::new(
        "function_parametric_1",
        parameter!("function_test_param_re_1"),
        parameter!("function_test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "function_parametric_2",
        parameter!("function_test_param_re_2"),
        parameter!("function_test_param_im_2"),
    )
    .unwrap();

    let sin = expr1.sin();
    let cos = expr1.cos();
    let trig = &sin * &cos;
    let pow = expr1.pow(&expr2);
    let mut expr = expr1.sqrt();
    expr = &expr + &expr1.exp();
    expr = &expr + &expr1.powi(2);
    expr = &expr + &expr1.powf(1.7);
    expr = &expr + &trig;
    expr = &expr + &expr1.log();
    expr = &expr + &expr1.cis();
    expr = &expr + &pow;

    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![2.0, 0.5, 1.2, -0.3];
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");
    let eps = 1e-6;

    for param_index in 0..params.len() {
        let mut plus = params.clone();
        plus[param_index] += eps;
        let mut minus = params.clone();
        minus[param_index] -= eps;
        // Central difference: (f(x + h) - f(x - h)) / (2h).
        let finite_diff = (evaluator
            .evaluate(&plus)
            .expect("evaluation should succeed")[0]
            - evaluator
                .evaluate(&minus)
                .expect("evaluation should succeed")[0])
            / Complex64::new(2.0 * eps, 0.0);

        assert_relative_eq!(
            gradient[0][param_index].re,
            finite_diff.re,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
        assert_relative_eq!(
            gradient[0][param_index].im,
            finite_diff.im,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
    }
}
6637
/// Multiplying by `Expression::one()` and adding `Expression::zero()` must leave both the
/// value and the gradient unchanged.
///
/// For `amp = re + 2i` at `re = 2`, `|amp|^2 = 8` and `d|amp|^2/d(re) = 2 * re = 4`.
#[test]
fn test_zeros_and_ones() {
    let amp = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();
    let dataset = Arc::new(test_dataset());
    let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
    let evaluator = expr.load(&dataset).unwrap();

    let params = vec![2.0];
    let value = evaluator
        .evaluate(&params)
        .expect("evaluation should succeed");
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(value[0].re, 8.0);
    assert_relative_eq!(value[0].im, 0.0);

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
}
/// In a default build, loading an expression reports IR planning as enabled and exposes
/// lowered value, gradient and value+gradient programs. Evaluation still returns
/// `|2 + 0i|^2 = 4`.
#[test]
fn test_default_build_uses_lowered_expression_runtime() {
    let amplitude = ComplexScalar::new(
        "opt_in_gate",
        parameter!("opt_in_gate_re", 2.0),
        parameter!("opt_in_gate_im", 0.0),
    )
    .unwrap();
    let expr = amplitude.norm_sqr();
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();

    let runtime = evaluator.expression_runtime_diagnostics();
    assert!(runtime.ir_planning_enabled);
    assert!(runtime.lowered_value_program_present);
    assert!(runtime.lowered_gradient_program_present);
    assert!(runtime.lowered_value_gradient_program_present);

    let first_value = evaluator.evaluate(&[]).expect("evaluation should succeed")[0];
    assert_eq!(first_value, Complex64::new(4.0, 0.0));
}
6688
/// `parameter!` with only a name yields a free parameter with no value, bounds or metadata.
#[test]
fn parameter_name_only_creates_free_parameter() {
    let mass = parameter!("mass");

    assert_eq!(mass.name(), "mass");
    // Neither a fixed nor an initial value is set.
    assert_eq!(mass.fixed(), None);
    assert_eq!(mass.initial(), None);
    assert_eq!(mass.bounds(), (None, None));
    // No descriptive metadata is attached.
    assert_eq!(mass.unit(), None);
    assert_eq!(mass.latex(), None);
    assert_eq!(mass.description(), None);
    assert!(mass.is_free() && !mass.is_fixed());
}
6703
/// A positional value after the name fixes the parameter and also sets its initial value.
#[test]
fn parameter_name_and_value_creates_fixed_parameter() {
    let width = parameter!("width", 0.15);

    assert_eq!(width.name(), "width");
    assert_eq!((width.fixed(), width.initial()), (Some(0.15), Some(0.15)));
    assert!(width.is_fixed());
    assert!(!width.is_free());
}
6714
/// `initial:` sets only a starting value; the parameter stays free and unbounded.
#[test]
fn keyword_initial_sets_initial_only() {
    let alpha = parameter!("alpha", initial: 1.25);

    assert_eq!(alpha.name(), "alpha");
    assert_eq!((alpha.fixed(), alpha.initial()), (None, Some(1.25)));
    assert_eq!(alpha.bounds(), (None, None));
    assert!(alpha.is_free());
}
6725
/// `fixed:` pins the parameter to a value, which also becomes its initial value.
#[test]
fn keyword_fixed_sets_fixed_and_initial() {
    let beta = parameter!("beta", fixed: 2.5);

    assert_eq!(beta.name(), "beta");
    assert_eq!((beta.fixed(), beta.initial()), (Some(2.5), Some(2.5)));
    assert!(beta.is_fixed());
}
6735
/// Two plain numbers in `bounds:` become the lower and upper limits.
#[test]
fn bounds_accept_plain_numbers() {
    let x = parameter!("x", bounds: (0.0, 10.0));

    let (lower, upper) = x.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));
}
6742
/// `None` as the lower bound leaves it open while the upper bound is still applied.
#[test]
fn bounds_accept_none_and_number() {
    let x = parameter!("x", bounds: (None, 10.0));

    let (lower, upper) = x.bounds();
    assert!(lower.is_none());
    assert_eq!(upper, Some(10.0));
}
6749
/// `None` as the upper bound leaves it open while the lower bound is still applied.
#[test]
fn bounds_accept_number_and_none() {
    let x = parameter!("x", bounds: (-1.0, None));

    let (lower, upper) = x.bounds();
    assert_eq!(lower, Some(-1.0));
    assert!(upper.is_none());
}
6756
/// `(None, None)` is accepted and leaves the parameter unbounded on both sides.
#[test]
fn bounds_accept_both_none() {
    let x = parameter!("x", bounds: (None, None));

    let (lower, upper) = x.bounds();
    assert!(lower.is_none());
    assert!(upper.is_none());
}
6763
/// `bounds:` accepts arbitrary expressions, which are evaluated before being stored.
#[test]
fn bounds_accept_arbitrary_expressions() {
    let low = 1.0;
    let high = 2.0 * 3.0;
    let x = parameter!("x", bounds: (low - 0.5, high));

    let (lower, upper) = x.bounds();
    assert_eq!(lower, Some(0.5));
    assert_eq!(upper, Some(6.0));
}
6772
/// Several keyword arguments can be combined in one `parameter!` call; each lands in its
/// own field and the parameter stays free.
#[test]
fn multiple_keyword_arguments_work_together() {
    let gamma = parameter!(
        "gamma",
        initial: 1.0,
        bounds: (0.0, 5.0),
        unit: "GeV",
        latex: r"\gamma",
        description: "test parameter",
    );

    assert_eq!(gamma.name(), "gamma");
    // Numeric configuration.
    assert_eq!((gamma.fixed(), gamma.initial()), (None, Some(1.0)));
    assert_eq!(gamma.bounds(), (Some(0.0), Some(5.0)));
    // Descriptive metadata.
    assert_eq!(gamma.unit().as_deref(), Some("GeV"));
    assert_eq!(gamma.latex().as_deref(), Some(r"\gamma"));
    assert_eq!(gamma.description().as_deref(), Some("test parameter"));
}
6792
/// `fixed:` can be combined with bounds and metadata; the fixed value also becomes the
/// initial value.
#[test]
fn fixed_can_be_combined_with_other_fields() {
    let delta = parameter!(
        "delta",
        fixed: 3.0,
        bounds: (0.0, 10.0),
        unit: "rad",
    );

    assert_eq!(delta.name(), "delta");
    assert_eq!((delta.fixed(), delta.initial()), (Some(3.0), Some(3.0)));
    assert_eq!(delta.bounds(), (Some(0.0), Some(10.0)));
    assert_eq!(delta.unit().as_deref(), Some("rad"));
}
6808
/// The `parameter!` macro accepts a trailing comma after the final keyword argument.
#[test]
fn trailing_comma_is_accepted() {
    let eps = parameter!(
        "eps",
        initial: 0.5,
        bounds: (None, 1.0),
        unit: "arb",
    );

    let (lower, upper) = eps.bounds();
    assert_eq!(eps.initial(), Some(0.5));
    assert!(lower.is_none());
    assert_eq!(upper, Some(1.0));
    assert_eq!(eps.unit().as_deref(), Some("arb"));
}
6822
/// `parameters().free()` reports only the free parameter; the fixed one is excluded.
#[test]
fn test_parameter_registration() {
    let amplitude = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();

    let free_names = amplitude.parameters().free().names();
    assert_eq!(free_names, vec!["test_param_re"]);
}
6835
/// Two amplitudes may share a tag; combining them keeps every (fixed) parameter from both,
/// in the order the amplitudes appear in the expression.
#[test]
fn test_duplicate_amplitude_tag_registration_is_allowed() {
    let first = ComplexScalar::new(
        "same_name",
        parameter!("dup_re1", 1.0),
        parameter!("dup_im1", 0.0),
    )
    .unwrap();
    let second = ComplexScalar::new(
        "same_name",
        parameter!("dup_re2", 2.0),
        parameter!("dup_im2", 0.0),
    )
    .unwrap();

    let combined = first + second;
    let expected = vec!["dup_re1", "dup_im1", "dup_re2", "dup_im2"];
    assert_eq!(combined.parameters().fixed().names(), expected);
}
6856
/// Checks the box-drawing tree produced by `Display` for an expression that mixes unary
/// transforms, exact constants, a complex literal and all binary operators.
///
/// Each nesting level adds a three-column segment to the line prefix: `│  ` while an
/// ancestor still has a later sibling, otherwise three spaces. Children are introduced by
/// `├─ ` or, for the last child, `└─ `.
#[test]
fn test_tree_printing() {
    let amp1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let amp2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();
    let expr = &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
        - Expression::zero() / 1.0
        + (&amp1 * &amp2).norm_sqr();
    assert_eq!(
        expr.to_string(),
        concat!(
            "+\n",
            "├─ -\n",
            "│  ├─ +\n",
            "│  │  ├─ +\n",
            "│  │  │  ├─ Re\n",
            "│  │  │  │  └─ parametric_1(id=0)\n",
            "│  │  │  └─ Im\n",
            "│  │  │     └─ *\n",
            "│  │  │        └─ parametric_2(id=1)\n",
            "│  │  └─ ×\n",
            "│  │     ├─ 1 (exact)\n",
            "│  │     └─ -1.4+2i\n",
            "│  └─ ÷\n",
            "│     ├─ 0 (exact)\n",
            "│     └─ 1 (exact)\n",
            "└─ NormSqr\n",
            "   └─ ×\n",
            "      ├─ parametric_1(id=0)\n",
            "      └─ parametric_2(id=1)\n",
        )
    );
}
6900}