1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 sync::{
5 atomic::{AtomicU64, Ordering},
6 Arc,
7 },
8 time::Instant,
9};
10
11use auto_ops::*;
12use nalgebra::DVector;
13use num::complex::Complex64;
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
19use super::{ir, lowered};
20
/// Global source of [`ComputeAmplitudeId`] values; incremented once per registered amplitude.
static COMPUTE_AMPLITUDE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Global source of [`AmplitudeUseSiteId`] values; incremented once per created use site.
static AMPLITUDE_USE_SITE_COUNTER: AtomicU64 = AtomicU64::new(0);
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
25struct ComputeAmplitudeId(u64);
26
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
28struct AmplitudeUseSiteId(u64);
29
30fn next_compute_amplitude_id() -> ComputeAmplitudeId {
31 ComputeAmplitudeId(COMPUTE_AMPLITUDE_COUNTER.fetch_add(1, Ordering::Relaxed))
32}
33
34fn next_amplitude_use_site_id() -> AmplitudeUseSiteId {
35 AmplitudeUseSiteId(AMPLITUDE_USE_SITE_COUNTER.fetch_add(1, Ordering::Relaxed))
36}
/// Coarse classification of what an expression (or sub-expression) depends on.
///
/// Public mirror of `ir::DependenceClass`; conversions are provided in both directions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionDependence {
    /// Depends only on parameters.
    ParameterOnly,
    /// Depends only on cached values.
    CacheOnly,
    /// Depends on both parameters and cached values.
    Mixed,
}
47impl From<ir::DependenceClass> for ExpressionDependence {
48 fn from(value: ir::DependenceClass) -> Self {
49 match value {
50 ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
51 ir::DependenceClass::CacheOnly => Self::CacheOnly,
52 ir::DependenceClass::Mixed => Self::Mixed,
53 }
54 }
55}
56#[derive(Clone, Debug, PartialEq, Eq)]
57pub struct NormalizationPlanExplain {
59 pub root_dependence: ExpressionDependence,
61 pub warnings: Vec<String>,
63 pub separable_mul_candidate_nodes: Vec<usize>,
65 pub cached_separable_nodes: Vec<usize>,
67 pub residual_terms: Vec<usize>,
69}
/// Public view of which amplitudes fall into each normalization execution set.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationExecutionSetsExplain {
    /// Amplitude indices in the cached parameter-dependent set.
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitude indices in the cached cache-dependent set.
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitude indices evaluated as residuals.
    pub residual_amplitudes: Vec<usize>,
}
80#[derive(Clone, Debug, PartialEq)]
81pub struct PrecomputedCachedIntegral {
83 pub mul_node_index: usize,
85 pub parameter_node_index: usize,
87 pub cache_node_index: usize,
89 pub coefficient: i32,
91 pub weighted_cache_sum: Complex64,
93}
94#[derive(Clone, Debug, PartialEq)]
95pub struct PrecomputedCachedIntegralGradientTerm {
97 pub mul_node_index: usize,
99 pub parameter_node_index: usize,
101 pub cache_node_index: usize,
103 pub coefficient: i32,
105 pub weighted_gradient: DVector<Complex64>,
107}
/// Hashable key identifying the inputs a set of cached integrals was computed for.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CachedIntegralCacheKey {
    /// Which amplitude use sites were active.
    active_mask: Vec<bool>,
    /// Number of locally held events.
    n_events_local: usize,
    /// Length of the local weights buffer.
    weights_local_len: usize,
    /// Raw bits of the weighted sum (presumably `f64::to_bits` — TODO confirm at the producer).
    weighted_sum_bits: u64,
    /// Address of the weights buffer (presumably used to detect a swapped dataset — confirm).
    weights_ptr: usize,
}
116#[derive(Clone, Debug)]
117struct CachedIntegralCacheState {
118 key: CachedIntegralCacheKey,
119 expression_ir: ir::ExpressionIR,
120 values: Vec<PrecomputedCachedIntegral>,
121 execution_sets: ir::NormalizationExecutionSets,
122}
123#[derive(Clone, Debug)]
124struct LoweredArtifactCacheState {
125 parameter_node_indices: Vec<usize>,
126 mul_node_indices: Vec<usize>,
127 lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
128 residual_runtime: Option<lowered::LoweredExpressionRuntime>,
129 lowered_runtime: lowered::LoweredExpressionRuntime,
130}
131#[derive(Clone)]
132struct ExpressionSpecializationState {
133 cached_integrals: Arc<CachedIntegralCacheState>,
134 lowered_artifacts: Arc<LoweredArtifactCacheState>,
135}
/// Hit/miss counters for the expression specialization cache.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionSpecializationMetrics {
    /// Lookups answered from the cache.
    pub cache_hits: usize,
    /// Lookups that had to build a new specialization.
    pub cache_misses: usize,
}
/// Timing (in nanoseconds) and cache counters for expression compilation and specialization.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionCompileMetrics {
    /// Time spent compiling the IR during the initial load.
    pub initial_ir_compile_nanos: u64,
    /// Time spent precomputing cached integrals during the initial load.
    pub initial_cached_integrals_nanos: u64,
    /// Time spent lowering the IR during the initial load.
    pub initial_lowering_nanos: u64,
    /// Specialization-cache hits.
    pub specialization_cache_hits: usize,
    /// Specialization-cache misses.
    pub specialization_cache_misses: usize,
    /// Time spent compiling IR for specializations.
    pub specialization_ir_compile_nanos: u64,
    /// Time spent precomputing cached integrals for specializations.
    pub specialization_cached_integrals_nanos: u64,
    /// Time spent lowering IR for specializations.
    pub specialization_lowering_nanos: u64,
    /// Lowered-artifact cache hits during specialization.
    pub specialization_lowering_cache_hits: usize,
    /// Lowered-artifact cache misses during specialization.
    pub specialization_lowering_cache_misses: usize,
    /// Time spent restoring a specialization from the cache.
    pub specialization_cache_restore_nanos: u64,
}
170impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
171 fn from(value: ir::NormalizationPlanExplain) -> Self {
172 Self {
173 root_dependence: value.root_dependence.into(),
174 warnings: value.warnings,
175 separable_mul_candidate_nodes: value
176 .separable_mul_candidates
177 .into_iter()
178 .map(|candidate| candidate.node_index)
179 .collect(),
180 cached_separable_nodes: value.cached_separable_nodes,
181 residual_terms: value.residual_terms,
182 }
183 }
184}
185impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
186 fn from(value: ir::NormalizationExecutionSets) -> Self {
187 Self {
188 cached_parameter_amplitudes: value.cached_parameter_amplitudes,
189 cached_cache_amplitudes: value.cached_cache_amplitudes,
190 residual_amplitudes: value.residual_amplitudes,
191 }
192 }
193}
194impl From<ExpressionDependence> for ir::DependenceClass {
195 fn from(value: ExpressionDependence) -> Self {
196 match value {
197 ExpressionDependence::ParameterOnly => Self::ParameterOnly,
198 ExpressionDependence::CacheOnly => Self::CacheOnly,
199 ExpressionDependence::Mixed => Self::Mixed,
200 }
201 }
202}
203
204#[cfg(feature = "mpi")]
205use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
206
207#[cfg(feature = "mpi")]
208use crate::mpi::LadduMPI;
209#[cfg(feature = "execution-context-prototype")]
210use crate::ExecutionContext;
211#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
212use crate::ThreadPolicy;
213use crate::{
214 amplitude::{Amplitude, AmplitudeUseSite},
215 data::Dataset,
216 parameters::ParameterMap,
217 resources::{Cache, Parameters, Resources},
218 LadduError, LadduResult,
219};
220
221#[allow(missing_docs)]
223#[derive(Clone, Serialize, Deserialize, Default)]
224pub struct Expression {
225 registry: ExpressionRegistry,
226 tree: ExpressionNode,
227}
228
229#[derive(Clone, Serialize, Deserialize)]
230#[allow(missing_docs)]
231#[derive(Default)]
232pub struct ExpressionRegistry {
233 amplitudes: Vec<Box<dyn Amplitude>>,
234 amplitude_names: Vec<String>,
235 amplitude_ids: Vec<ComputeAmplitudeId>,
236 amplitude_use_sites: Vec<AmplitudeUseSite>,
237 amplitude_use_site_ids: Vec<AmplitudeUseSiteId>,
238 resources: Resources,
239}
240
241impl ExpressionRegistry {
242 fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
243 let mut resources = Resources::default();
244 let aid = amplitude.register(&mut resources)?;
245 let compute_id = next_compute_amplitude_id();
246 let use_site_id = next_amplitude_use_site_id();
247 resources.configure_amplitude_tags(std::slice::from_ref(&aid.0));
248 Ok(Self {
249 amplitudes: vec![amplitude],
250 amplitude_names: vec![aid.0.display_label()],
251 amplitude_ids: vec![compute_id],
252 amplitude_use_sites: vec![AmplitudeUseSite {
253 amplitude_index: 0,
254 tags: aid.0,
255 }],
256 amplitude_use_site_ids: vec![use_site_id],
257 resources,
258 })
259 }
260
261 fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
262 let mut resources = Resources::default();
263 let mut amplitudes = Vec::new();
264 let mut amplitude_ids = Vec::new();
265 let mut compute_id_to_index = HashMap::new();
266 let mut semantic_key_to_index = HashMap::new();
267
268 let mut left_compute_map = Vec::with_capacity(self.amplitudes.len());
269 for (amp_index, (amp, amp_id)) in
270 self.amplitudes.iter().zip(&self.amplitude_ids).enumerate()
271 {
272 let semantic_key = amp.semantic_key();
273 let mut cloned_amp = dyn_clone::clone_box(&**amp);
274 let aid = cloned_amp.register(&mut resources)?;
275 if let Some(key) = semantic_key.clone() {
276 semantic_key_to_index.insert(key, aid.1);
277 }
278 amplitudes.push(cloned_amp);
279 amplitude_ids.push(*amp_id);
280 compute_id_to_index.insert(*amp_id, aid.1);
281 left_compute_map.push(aid.1);
282 debug_assert_eq!(amp_index, aid.1);
283 }
284
285 let mut right_compute_map = Vec::with_capacity(other.amplitudes.len());
286 for (amp, amp_id) in other.amplitudes.iter().zip(&other.amplitude_ids) {
287 if let Some(existing) = compute_id_to_index.get(amp_id) {
288 right_compute_map.push(*existing);
289 continue;
290 }
291 let incoming_semantic_key = amp.semantic_key();
292 if let Some(existing) = incoming_semantic_key
293 .as_ref()
294 .and_then(|key| semantic_key_to_index.get(key))
295 {
296 right_compute_map.push(*existing);
297 continue;
298 }
299 let mut cloned_amp = dyn_clone::clone_box(&**amp);
300 let aid = cloned_amp.register(&mut resources)?;
301 if let Some(key) = incoming_semantic_key.clone() {
302 semantic_key_to_index.insert(key, aid.1);
303 }
304 amplitudes.push(cloned_amp);
305 amplitude_ids.push(*amp_id);
306 compute_id_to_index.insert(*amp_id, aid.1);
307 right_compute_map.push(aid.1);
308 }
309
310 let mut amplitude_use_sites = Vec::new();
311 let mut amplitude_use_site_ids = Vec::new();
312 let mut amplitude_names = Vec::new();
313 let mut use_site_id_to_index = HashMap::new();
314
315 let mut left_map = Vec::with_capacity(self.amplitude_use_sites.len());
316 for (use_site, use_site_id) in self
317 .amplitude_use_sites
318 .iter()
319 .zip(&self.amplitude_use_site_ids)
320 {
321 let mapped_index = left_compute_map[use_site.amplitude_index];
322 let new_index = amplitude_use_sites.len();
323 left_map.push(new_index);
324 use_site_id_to_index.insert(*use_site_id, new_index);
325 amplitude_use_site_ids.push(*use_site_id);
326 amplitude_names.push(use_site.tags.display_label());
327 amplitude_use_sites.push(AmplitudeUseSite {
328 amplitude_index: mapped_index,
329 tags: use_site.tags.clone(),
330 });
331 }
332
333 let mut right_map = Vec::with_capacity(other.amplitude_use_sites.len());
334 for (use_site, use_site_id) in other
335 .amplitude_use_sites
336 .iter()
337 .zip(&other.amplitude_use_site_ids)
338 {
339 if let Some(existing) = use_site_id_to_index.get(use_site_id) {
340 right_map.push(*existing);
341 continue;
342 }
343 let mapped_index = right_compute_map[use_site.amplitude_index];
344 let new_index = amplitude_use_sites.len();
345 right_map.push(new_index);
346 use_site_id_to_index.insert(*use_site_id, new_index);
347 amplitude_use_site_ids.push(*use_site_id);
348 amplitude_names.push(use_site.tags.display_label());
349 amplitude_use_sites.push(AmplitudeUseSite {
350 amplitude_index: mapped_index,
351 tags: use_site.tags.clone(),
352 });
353 }
354 let use_site_tags = amplitude_use_sites
355 .iter()
356 .map(|use_site| use_site.tags.clone())
357 .collect::<Vec<_>>();
358 resources.configure_amplitude_tags(&use_site_tags);
359
360 Ok((
361 Self {
362 amplitudes,
363 amplitude_names,
364 amplitude_ids,
365 amplitude_use_sites,
366 amplitude_use_site_ids,
367 resources,
368 },
369 left_map,
370 right_map,
371 ))
372 }
373}
374
375#[allow(missing_docs)]
377#[derive(Clone, Serialize, Deserialize, Default, Debug)]
378pub enum ExpressionNode {
379 #[default]
380 Zero,
382 One,
384 Constant(f64),
386 ComplexConstant(Complex64),
388 Amp(usize),
390 Add(Box<ExpressionNode>, Box<ExpressionNode>),
392 Sub(Box<ExpressionNode>, Box<ExpressionNode>),
394 Mul(Box<ExpressionNode>, Box<ExpressionNode>),
396 Div(Box<ExpressionNode>, Box<ExpressionNode>),
398 Neg(Box<ExpressionNode>),
400 Real(Box<ExpressionNode>),
402 Imag(Box<ExpressionNode>),
404 Conj(Box<ExpressionNode>),
406 NormSqr(Box<ExpressionNode>),
408 Sqrt(Box<ExpressionNode>),
409 Pow(Box<ExpressionNode>, Box<ExpressionNode>),
410 PowI(Box<ExpressionNode>, i32),
411 PowF(Box<ExpressionNode>, f64),
412 Exp(Box<ExpressionNode>),
413 Sin(Box<ExpressionNode>),
414 Cos(Box<ExpressionNode>),
415 Log(Box<ExpressionNode>),
416 Cis(Box<ExpressionNode>),
417}
418
419impl ExpressionNode {
420 fn remap(&self, mapping: &[usize]) -> Self {
421 match self {
422 Self::Amp(idx) => Self::Amp(mapping[*idx]),
423 Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
424 Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
425 Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
426 Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
427 Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
428 Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
429 Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
430 Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
431 Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
432 Self::Zero => Self::Zero,
433 Self::One => Self::One,
434 Self::Constant(v) => Self::Constant(*v),
435 Self::ComplexConstant(v) => Self::ComplexConstant(*v),
436 Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
437 Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
438 Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
439 Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
440 Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
441 Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
442 Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
443 Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
444 Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
445 }
446 }
447}
448
449impl From<f64> for Expression {
450 fn from(value: f64) -> Self {
451 if value == 0.0 {
452 Self {
453 registry: ExpressionRegistry::default(),
454 tree: ExpressionNode::Zero,
455 }
456 } else if value == 1.0 {
457 Self {
458 registry: ExpressionRegistry::default(),
459 tree: ExpressionNode::One,
460 }
461 } else {
462 Self {
463 registry: ExpressionRegistry::default(),
464 tree: ExpressionNode::Constant(value),
465 }
466 }
467 }
468}
469impl From<&f64> for Expression {
470 fn from(value: &f64) -> Self {
471 (*value).into()
472 }
473}
474impl From<Complex64> for Expression {
475 fn from(value: Complex64) -> Self {
476 if value == Complex64::ZERO {
477 Self {
478 registry: ExpressionRegistry::default(),
479 tree: ExpressionNode::Zero,
480 }
481 } else if value == Complex64::ONE {
482 Self {
483 registry: ExpressionRegistry::default(),
484 tree: ExpressionNode::One,
485 }
486 } else {
487 Self {
488 registry: ExpressionRegistry::default(),
489 tree: ExpressionNode::ComplexConstant(value),
490 }
491 }
492 }
493}
494impl From<&Complex64> for Expression {
495 fn from(value: &Complex64) -> Self {
496 (*value).into()
497 }
498}
499
500impl Expression {
501 pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
503 let registry = ExpressionRegistry::singleton(amplitude)?;
504 Ok(Self {
505 tree: ExpressionNode::Amp(0),
506 registry,
507 })
508 }
509
510 pub fn zero() -> Self {
512 Self {
513 registry: ExpressionRegistry::default(),
514 tree: ExpressionNode::Zero,
515 }
516 }
517
518 pub fn one() -> Self {
520 Self {
521 registry: ExpressionRegistry::default(),
522 tree: ExpressionNode::One,
523 }
524 }
525
526 fn binary_op(
527 a: &Expression,
528 b: &Expression,
529 build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
530 ) -> Expression {
531 let (registry, left_map, right_map) = a
532 .registry
533 .merge(&b.registry)
534 .expect("merging expression registries should not fail");
535 let left_tree = a.tree.remap(&left_map);
536 let right_tree = b.tree.remap(&right_map);
537 Expression {
538 registry,
539 tree: build(Box::new(left_tree), Box::new(right_tree)),
540 }
541 }
542
543 fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
544 Expression {
545 registry: a.registry.clone(),
546 tree: build(Box::new(a.tree.clone())),
547 }
548 }
549
550 pub fn parameters(&self) -> ParameterMap {
552 self.registry.resources.parameters()
553 }
554
555 pub fn n_free(&self) -> usize {
557 self.registry.resources.n_free_parameters()
558 }
559
560 pub fn n_fixed(&self) -> usize {
562 self.registry.resources.n_fixed_parameters()
563 }
564
565 pub fn n_parameters(&self) -> usize {
567 self.registry.resources.n_parameters()
568 }
569
570 pub fn compiled_expression(&self) -> CompiledExpression {
576 let active_amplitudes = vec![true; self.registry.amplitude_use_sites.len()];
577 let amplitude_dependencies = self
578 .registry
579 .amplitude_use_sites
580 .iter()
581 .map(|use_site| {
582 ir::DependenceClass::from(
583 self.registry.amplitudes[use_site.amplitude_index].dependence_hint(),
584 )
585 })
586 .collect::<Vec<_>>();
587 let amplitude_realness = self
588 .registry
589 .amplitude_use_sites
590 .iter()
591 .map(|use_site| self.registry.amplitudes[use_site.amplitude_index].real_valued_hint())
592 .collect::<Vec<_>>();
593 let expression_ir = ir::compile_expression_ir_with_real_hints(
594 &self.tree,
595 &active_amplitudes,
596 &litude_dependencies,
597 &litude_realness,
598 );
599 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
600 }
601
602 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
604 self.registry.resources.fix_parameter(name, value)
605 }
606
607 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
609 self.registry.resources.free_parameter(name)
610 }
611
612 pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
614 self.registry.resources.rename_parameter(old, new)
615 }
616
617 pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
619 self.registry.resources.rename_parameters(mapping)
620 }
621
622 pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
624 let mut resources = self.registry.resources.clone();
625 let metadata = dataset.metadata();
626 resources.reserve_cache(dataset.n_events_local());
627 resources.refresh_active_indices();
628 let parameter_map = resources.parameter_map.clone();
629 let mut amplitudes: Vec<Box<dyn Amplitude>> = self
630 .registry
631 .amplitudes
632 .iter()
633 .map(|amp| dyn_clone::clone_box(&**amp))
634 .collect();
635 {
636 for amplitude in amplitudes.iter_mut() {
637 amplitude.bind(metadata)?;
638 amplitude.precompute_all(dataset, &mut resources);
639 }
640 }
641 let ir_compile_start = Instant::now();
642 let expression_ir = {
643 let mut active_amplitudes = vec![false; self.registry.amplitude_use_sites.len()];
644 for &index in resources.active_indices() {
645 active_amplitudes[index] = true;
646 }
647 let amplitude_dependencies = self
648 .registry
649 .amplitude_use_sites
650 .iter()
651 .map(|use_site| {
652 ir::DependenceClass::from(
653 amplitudes[use_site.amplitude_index].dependence_hint(),
654 )
655 })
656 .collect::<Vec<_>>();
657 let amplitude_realness = self
658 .registry
659 .amplitude_use_sites
660 .iter()
661 .map(|use_site| amplitudes[use_site.amplitude_index].real_valued_hint())
662 .collect::<Vec<_>>();
663 ir::compile_expression_ir_with_real_hints(
664 &self.tree,
665 &active_amplitudes,
666 &litude_dependencies,
667 &litude_realness,
668 )
669 };
670 let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
671 let cached_integrals_start = Instant::now();
672 let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
673 &expression_ir,
674 &litudes,
675 &self.registry.amplitude_use_sites,
676 &resources,
677 dataset,
678 parameter_map.free().len(),
679 )?;
680 let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
681 let lowering_start = Instant::now();
682 let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
683 &expression_ir,
684 &cached_integrals,
685 )?);
686 let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
687 let execution_sets = expression_ir.normalization_execution_sets().clone();
688 let cached_integral_key =
689 Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
690 let cached_integral_state = Arc::new(CachedIntegralCacheState {
691 key: cached_integral_key.clone(),
692 expression_ir,
693 values: cached_integrals,
694 execution_sets,
695 });
696 let specialization_state = ExpressionSpecializationState {
697 cached_integrals: cached_integral_state.clone(),
698 lowered_artifacts: lowered_artifacts.clone(),
699 };
700 let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
701 let lowered_artifact_cache =
702 HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
703 Ok(Evaluator {
704 amplitudes,
705 amplitude_use_sites: self.registry.amplitude_use_sites.clone(),
706 resources: Arc::new(RwLock::new(resources)),
707 dataset: dataset.clone(),
708 expression: self.tree.clone(),
709 ir_planning: ExpressionIrPlanningState {
710 cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
711 specialization_cache: Arc::new(RwLock::new(specialization_cache)),
712 specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
713 cache_hits: 0,
714 cache_misses: 1,
715 })),
716 lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
717 active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
718 specialization_status: Arc::new(RwLock::new(Some(
719 ExpressionSpecializationStatus {
720 origin: ExpressionSpecializationOrigin::InitialLoad,
721 },
722 ))),
723 compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
724 initial_ir_compile_nanos,
725 initial_cached_integrals_nanos,
726 initial_lowering_nanos,
727 specialization_lowering_cache_misses: 1,
728 ..Default::default()
729 })),
730 },
731 registry: self.registry.clone(),
732 })
733 }
734
735 pub fn real(&self) -> Self {
737 Self::unary_op(self, ExpressionNode::Real)
738 }
739 pub fn imag(&self) -> Self {
741 Self::unary_op(self, ExpressionNode::Imag)
742 }
743 pub fn conj(&self) -> Self {
745 Self::unary_op(self, ExpressionNode::Conj)
746 }
747 pub fn norm_sqr(&self) -> Self {
749 Self::unary_op(self, ExpressionNode::NormSqr)
750 }
751 pub fn sqrt(&self) -> Self {
753 Self::unary_op(self, ExpressionNode::Sqrt)
754 }
755 pub fn pow(&self, power: &Expression) -> Self {
757 Self::binary_op(self, power, ExpressionNode::Pow)
758 }
759 pub fn powi(&self, power: i32) -> Self {
761 Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
762 }
763 pub fn powf(&self, power: f64) -> Self {
765 Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
766 }
767 pub fn exp(&self) -> Self {
769 Self::unary_op(self, ExpressionNode::Exp)
770 }
771 pub fn sin(&self) -> Self {
773 Self::unary_op(self, ExpressionNode::Sin)
774 }
775 pub fn cos(&self) -> Self {
777 Self::unary_op(self, ExpressionNode::Cos)
778 }
779 pub fn log(&self) -> Self {
781 Self::unary_op(self, ExpressionNode::Log)
782 }
783 pub fn cis(&self) -> Self {
785 Self::unary_op(self, ExpressionNode::Cis)
786 }
787
788 fn write_tree(
790 &self,
791 t: &ExpressionNode,
792 f: &mut std::fmt::Formatter<'_>,
793 parent_prefix: &str,
794 immediate_prefix: &str,
795 parent_suffix: &str,
796 ) -> std::fmt::Result {
797 let display_string = match t {
798 ExpressionNode::Amp(idx) => {
799 let name = self
800 .registry
801 .amplitude_names
802 .get(*idx)
803 .cloned()
804 .unwrap_or_else(|| "<unregistered>".to_string());
805 format!("{name}(id={idx})")
806 }
807 ExpressionNode::Add(_, _) => "+".to_string(),
808 ExpressionNode::Sub(_, _) => "-".to_string(),
809 ExpressionNode::Mul(_, _) => "×".to_string(),
810 ExpressionNode::Div(_, _) => "÷".to_string(),
811 ExpressionNode::Neg(_) => "-".to_string(),
812 ExpressionNode::Real(_) => "Re".to_string(),
813 ExpressionNode::Imag(_) => "Im".to_string(),
814 ExpressionNode::Conj(_) => "*".to_string(),
815 ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
816 ExpressionNode::Zero => "0 (exact)".to_string(),
817 ExpressionNode::One => "1 (exact)".to_string(),
818 ExpressionNode::Constant(v) => v.to_string(),
819 ExpressionNode::ComplexConstant(v) => v.to_string(),
820 ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
821 ExpressionNode::Pow(_, _) => "Pow".to_string(),
822 ExpressionNode::PowI(_, power) => format!("PowI({power})"),
823 ExpressionNode::PowF(_, power) => format!("PowF({power})"),
824 ExpressionNode::Exp(_) => "Exp".to_string(),
825 ExpressionNode::Sin(_) => "Sin".to_string(),
826 ExpressionNode::Cos(_) => "Cos".to_string(),
827 ExpressionNode::Log(_) => "Log".to_string(),
828 ExpressionNode::Cis(_) => "Cis".to_string(),
829 };
830 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
831 match t {
832 ExpressionNode::Amp(_)
833 | ExpressionNode::Zero
834 | ExpressionNode::One
835 | ExpressionNode::Constant(_)
836 | ExpressionNode::ComplexConstant(_) => {}
837 ExpressionNode::Add(a, b)
838 | ExpressionNode::Sub(a, b)
839 | ExpressionNode::Mul(a, b)
840 | ExpressionNode::Div(a, b)
841 | ExpressionNode::Pow(a, b) => {
842 let terms = [a, b];
843 let mut it = terms.iter().peekable();
844 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
845 while let Some(child) = it.next() {
846 match it.peek() {
847 Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│ "),
848 None => self.write_tree(child, f, &child_prefix, "└─ ", " "),
849 }?;
850 }
851 }
852 ExpressionNode::Neg(a)
853 | ExpressionNode::Real(a)
854 | ExpressionNode::Imag(a)
855 | ExpressionNode::Conj(a)
856 | ExpressionNode::NormSqr(a)
857 | ExpressionNode::Sqrt(a)
858 | ExpressionNode::PowI(a, _)
859 | ExpressionNode::PowF(a, _)
860 | ExpressionNode::Exp(a)
861 | ExpressionNode::Sin(a)
862 | ExpressionNode::Cos(a)
863 | ExpressionNode::Log(a)
864 | ExpressionNode::Cis(a) => {
865 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
866 self.write_tree(a, f, &child_prefix, "└─ ", " ")?;
867 }
868 }
869 Ok(())
870 }
871}
872
873impl Debug for Expression {
874 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
875 self.write_tree(&self.tree, f, "", "", "")
876 }
877}
878
879impl Display for Expression {
880 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881 self.write_tree(&self.tree, f, "", "", "")
882 }
883}
884
885#[rustfmt::skip]
886impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
887 Expression::binary_op(a, b, ExpressionNode::Add)
888});
889#[rustfmt::skip]
890impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
891 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
892});
893#[rustfmt::skip]
894impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
895 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
896});
897#[rustfmt::skip]
898impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
899 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
900});
901#[rustfmt::skip]
902impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
903 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
904});
905
906#[rustfmt::skip]
907impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
908 Expression::binary_op(a, b, ExpressionNode::Sub)
909});
910#[rustfmt::skip]
911impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
912 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
913});
914#[rustfmt::skip]
915impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
916 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
917});
918#[rustfmt::skip]
919impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
920 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
921});
922#[rustfmt::skip]
923impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
924 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
925});
926
927#[rustfmt::skip]
928impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
929 Expression::binary_op(a, b, ExpressionNode::Mul)
930});
931#[rustfmt::skip]
932impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
933 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
934});
935#[rustfmt::skip]
936impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
937 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
938});
939#[rustfmt::skip]
940impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
941 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
942});
943#[rustfmt::skip]
944impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
945 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
946});
947
948#[rustfmt::skip]
949impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
950 Expression::binary_op(a, b, ExpressionNode::Div)
951});
952#[rustfmt::skip]
953impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
954 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
955});
956#[rustfmt::skip]
957impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
958 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
959});
960#[rustfmt::skip]
961impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
962 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
963});
964#[rustfmt::skip]
965impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
966 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
967});
968
969#[rustfmt::skip]
970impl_op_ex!(- |a: &Expression| -> Expression {
971 Expression::unary_op(a, ExpressionNode::Neg)
972});
973#[derive(Clone, Debug)]
976#[doc(hidden)]
977pub struct ExpressionValueProgramSnapshot {
978 lowered_program: lowered::LoweredProgram,
979}
980
981#[derive(Clone, Debug, PartialEq)]
982pub enum CompiledExpressionNode {
984 Constant(Complex64),
986 Amplitude {
988 index: usize,
990 name: String,
992 },
993 Unary {
995 op: String,
997 input: usize,
999 },
1000 Binary {
1002 op: String,
1004 left: usize,
1006 right: usize,
1008 },
1009}
1010
1011#[derive(Clone, Debug, PartialEq)]
1012pub struct CompiledExpression {
1018 nodes: Vec<CompiledExpressionNode>,
1019 root: usize,
1020}
1021
1022impl CompiledExpression {
1023 fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
1024 let nodes = ir
1025 .nodes()
1026 .iter()
1027 .map(|node| match node {
1028 ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
1029 ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
1030 index: *index,
1031 name: amplitude_names
1032 .get(*index)
1033 .cloned()
1034 .unwrap_or_else(|| "<unregistered>".to_string()),
1035 },
1036 ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
1037 op: compiled_unary_op_label(*op),
1038 input: *input,
1039 },
1040 ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
1041 op: compiled_binary_op_label(*op),
1042 left: *left,
1043 right: *right,
1044 },
1045 })
1046 .collect();
1047 Self {
1048 nodes,
1049 root: ir.root(),
1050 }
1051 }
1052
1053 pub fn nodes(&self) -> &[CompiledExpressionNode] {
1055 &self.nodes
1056 }
1057
1058 pub fn root(&self) -> usize {
1060 self.root
1061 }
1062
1063 fn node_label(&self, index: usize) -> String {
1064 let Some(node) = self.nodes.get(index) else {
1065 return format!("#{index} <missing>");
1066 };
1067 let label = match node {
1068 CompiledExpressionNode::Constant(value) => format!("const {value}"),
1069 CompiledExpressionNode::Amplitude { index, name } => {
1070 format!("{name}(id={index})")
1071 }
1072 CompiledExpressionNode::Unary { op, .. }
1073 | CompiledExpressionNode::Binary { op, .. } => op.clone(),
1074 };
1075 format!("#{index} {label}")
1076 }
1077
1078 fn write_tree(
1080 &self,
1081 index: usize,
1082 f: &mut std::fmt::Formatter<'_>,
1083 parent_prefix: &str,
1084 immediate_prefix: &str,
1085 parent_suffix: &str,
1086 expanded: &mut [bool],
1087 ) -> std::fmt::Result {
1088 let already_expanded = expanded.get(index).copied().unwrap_or(false);
1089 if let Some(slot) = expanded.get_mut(index) {
1090 *slot = true;
1091 }
1092 let ref_suffix = if already_expanded { " (ref)" } else { "" };
1093 writeln!(
1094 f,
1095 "{}{}{}{}",
1096 parent_prefix,
1097 immediate_prefix,
1098 self.node_label(index),
1099 ref_suffix
1100 )?;
1101 if already_expanded {
1102 return Ok(());
1103 }
1104 let Some(node) = self.nodes.get(index) else {
1105 return Ok(());
1106 };
1107 match node {
1108 CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
1109 CompiledExpressionNode::Unary { input, .. } => {
1110 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1111 self.write_tree(*input, f, &child_prefix, "└─ ", " ", expanded)?;
1112 }
1113 CompiledExpressionNode::Binary { left, right, .. } => {
1114 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1115 self.write_tree(*left, f, &child_prefix, "├─ ", "│ ", expanded)?;
1116 self.write_tree(*right, f, &child_prefix, "└─ ", " ", expanded)?;
1117 }
1118 }
1119 Ok(())
1120 }
1121}
1122
1123impl Display for CompiledExpression {
1124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1125 if self.nodes.is_empty() {
1126 return writeln!(f, "<empty>");
1127 }
1128 let mut expanded = vec![false; self.nodes.len()];
1129 self.write_tree(self.root, f, "", "", "", &mut expanded)
1130 }
1131}
1132
1133fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
1134 match op {
1135 ir::IrUnaryOp::Neg => "-".to_string(),
1136 ir::IrUnaryOp::Real => "Re".to_string(),
1137 ir::IrUnaryOp::Imag => "Im".to_string(),
1138 ir::IrUnaryOp::Conj => "*".to_string(),
1139 ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
1140 ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
1141 ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
1142 ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
1143 ir::IrUnaryOp::Exp => "Exp".to_string(),
1144 ir::IrUnaryOp::Sin => "Sin".to_string(),
1145 ir::IrUnaryOp::Cos => "Cos".to_string(),
1146 ir::IrUnaryOp::Log => "Log".to_string(),
1147 ir::IrUnaryOp::Cis => "Cis".to_string(),
1148 }
1149}
1150
1151fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
1152 match op {
1153 ir::IrBinaryOp::Add => "+".to_string(),
1154 ir::IrBinaryOp::Sub => "-".to_string(),
1155 ir::IrBinaryOp::Mul => "×".to_string(),
1156 ir::IrBinaryOp::Div => "÷".to_string(),
1157 ir::IrBinaryOp::Pow => "Pow".to_string(),
1158 }
1159}
/// How the currently installed expression specialization was obtained.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionSpecializationOrigin {
    /// Built from scratch on a specialization-cache miss, and the cache held exactly one entry
    /// (the new one) after insertion.
    InitialLoad,
    /// Built from scratch on a specialization-cache miss while other specializations were
    /// already cached.
    CacheMissRebuild,
    /// Reinstalled from a previously built entry in the specialization cache.
    CacheHitRestore,
}
/// Status of the most recently installed expression specialization.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ExpressionSpecializationStatus {
    /// Whether the specialization was freshly built or restored from the cache.
    pub origin: ExpressionSpecializationOrigin,
}
/// Snapshot of the evaluator's IR-planning and lowered-runtime state, for diagnostics.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExpressionRuntimeDiagnostics {
    /// Whether IR-based planning is enabled for this evaluator.
    pub ir_planning_enabled: bool,
    /// Whether a lowered value-only program is available.
    pub lowered_value_program_present: bool,
    /// Whether a lowered gradient program is available.
    pub lowered_gradient_program_present: bool,
    /// Whether a lowered combined value-and-gradient program is available.
    pub lowered_value_gradient_program_present: bool,
    /// Number of precomputed cached-integral terms in the installed specialization.
    pub cached_parameter_factor_count: usize,
    /// Number of cached-integral parameter factors that were successfully lowered.
    pub lowered_cached_parameter_factor_count: usize,
    /// Whether a lowered residual runtime (expression with cached terms zeroed) exists.
    pub residual_runtime_present: bool,
    /// Number of entries in the specialization cache.
    pub specialization_cache_entries: usize,
    /// Number of entries in the lowered-artifact cache.
    pub lowered_artifact_cache_entries: usize,
    /// Origin of the currently installed specialization, if one has been installed.
    pub specialization_status: Option<ExpressionSpecializationStatus>,
}
/// Shared, lock-protected state for IR specialization of an [`Evaluator`]'s expression.
///
/// Every field is wrapped in `Arc<RwLock<..>>`, so cloning an evaluator shares this state
/// rather than copying it.
#[derive(Clone)]
struct ExpressionIrPlanningState {
    /// Cached-integral state of the currently installed specialization, if any.
    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
    /// Previously built specializations keyed by active mask plus dataset identity.
    specialization_cache:
        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
    /// Specialization-cache hit/miss counters.
    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
    /// Lowered runtime artifacts keyed by the active-amplitude mask.
    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
    /// Lowered artifacts of the currently installed specialization, if any.
    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
    /// How the currently installed specialization was obtained, if one has been installed.
    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
    /// Compile/lowering timing and cache counters.
    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
}
#[allow(missing_docs)]
/// Evaluates an expression of amplitudes over a dataset using IR-specialized lowered programs.
#[derive(Clone)]
pub struct Evaluator {
    /// Amplitude implementations; each use site's `amplitude_index` indexes into this list.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// One entry per amplitude occurrence in the expression. Per-event amplitude value and
    /// gradient buffers are indexed by use site.
    amplitude_use_sites: Vec<AmplitudeUseSite>,
    /// Shared resources (active mask, per-event caches, parameter map).
    pub resources: Arc<RwLock<Resources>>,
    /// Dataset being evaluated.
    pub dataset: Arc<Dataset>,
    /// Expression tree that is compiled to IR for each active mask.
    pub expression: ExpressionNode,
    /// Specialization caches, metrics, and the currently installed specialization.
    ir_planning: ExpressionIrPlanningState,
    /// Expression registry (supplies amplitude names for display).
    registry: ExpressionRegistry,
}
1229
1230#[allow(missing_docs)]
1231impl Evaluator {
1232 pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
1234 *self.ir_planning.specialization_metrics.read()
1235 }
1236 pub fn reset_expression_specialization_metrics(&self) {
1238 *self.ir_planning.specialization_metrics.write() =
1239 ExpressionSpecializationMetrics::default();
1240 }
1241 pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
1243 *self.ir_planning.compile_metrics.read()
1244 }
1245 pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
1247 let active_artifacts = self.active_lowered_artifacts();
1248 let cached_parameter_factor_count = self
1249 .ir_planning
1250 .cached_integrals
1251 .read()
1252 .as_ref()
1253 .map(|state| state.values.len())
1254 .unwrap_or(0);
1255 let lowered_cached_parameter_factor_count = active_artifacts
1256 .as_ref()
1257 .map(|artifacts| {
1258 artifacts
1259 .lowered_parameter_factors
1260 .iter()
1261 .filter(|factor| factor.is_some())
1262 .count()
1263 })
1264 .unwrap_or(0);
1265 let residual_runtime_present = active_artifacts
1266 .as_ref()
1267 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
1268 .is_some();
1269 ExpressionRuntimeDiagnostics {
1270 ir_planning_enabled: true,
1271 lowered_value_program_present: true,
1272 lowered_gradient_program_present: true,
1273 lowered_value_gradient_program_present: true,
1274 cached_parameter_factor_count,
1275 lowered_cached_parameter_factor_count,
1276 residual_runtime_present,
1277 specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
1278 lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
1279 specialization_status: *self.ir_planning.specialization_status.read(),
1280 }
1281 }
1282 pub fn reset_expression_compile_metrics(&self) {
1284 let mut metrics = self.ir_planning.compile_metrics.write();
1285 metrics.specialization_cache_hits = 0;
1286 metrics.specialization_cache_misses = 0;
1287 metrics.specialization_ir_compile_nanos = 0;
1288 metrics.specialization_cached_integrals_nanos = 0;
1289 metrics.specialization_lowering_nanos = 0;
1290 metrics.specialization_lowering_cache_hits = 0;
1291 metrics.specialization_lowering_cache_misses = 0;
1292 metrics.specialization_cache_restore_nanos = 0;
1293 }
1294 #[cfg(test)]
1295 fn expression_ir(&self) -> ir::ExpressionIR {
1296 self.ir_planning
1297 .cached_integrals
1298 .read()
1299 .as_ref()
1300 .map(|state| state.expression_ir.clone())
1301 .expect("cached integral state should exist for evaluator IR access")
1302 }
1303 fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
1304 self.active_lowered_artifacts()
1305 .expect("active lowered artifacts should exist for the current specialization")
1306 .lowered_runtime
1307 .clone()
1308 }
1309 fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
1310 self.ir_planning.active_lowered_artifacts.read().clone()
1311 }
1312 fn lowered_runtime_slot_count(&self) -> usize {
1313 let runtime = self.lowered_runtime();
1314 [
1315 runtime.value_program().scratch_slots(),
1316 runtime.gradient_program().scratch_slots(),
1317 runtime.value_gradient_program().scratch_slots(),
1318 ]
1319 .into_iter()
1320 .max()
1321 .unwrap_or(0)
1322 }
1323 fn lowered_value_runtime_slot_count(&self) -> usize {
1324 self.lowered_runtime().value_program().scratch_slots()
1325 }
1326
1327 #[doc(hidden)]
1328 pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
1329 ExpressionValueProgramSnapshot {
1330 lowered_program: self.lowered_runtime().value_program().clone(),
1331 }
1332 }
1333
1334 #[doc(hidden)]
1335 pub fn expression_value_program_snapshot_for_active_mask(
1336 &self,
1337 active_mask: &[bool],
1338 ) -> LadduResult<ExpressionValueProgramSnapshot> {
1339 let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1340 let lowered_program =
1341 lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
1342 LadduError::Custom(format!(
1343 "Failed to lower value-only active-mask runtime: {error:?}"
1344 ))
1345 })?;
1346 Ok(ExpressionValueProgramSnapshot { lowered_program })
1347 }
1348
1349 #[doc(hidden)]
1350 pub fn expression_value_program_snapshot_slot_count(
1351 &self,
1352 snapshot: &ExpressionValueProgramSnapshot,
1353 ) -> usize {
1354 let _ = self;
1355 snapshot.lowered_program.scratch_slots()
1356 }
1357
1358 pub fn compiled_expression(&self) -> CompiledExpression {
1361 let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
1362 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
1363 }
1364
1365 pub fn expression(&self) -> Expression {
1367 Expression {
1368 tree: self.expression.clone(),
1369 registry: self.registry.clone(),
1370 }
1371 }
1372 fn lowered_gradient_runtime_slot_count(&self) -> usize {
1373 self.lowered_runtime().gradient_program().scratch_slots()
1374 }
1375 fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
1376 self.lowered_runtime()
1377 .value_gradient_program()
1378 .scratch_slots()
1379 }
1380
1381 fn expression_value_slot_count(&self) -> usize {
1382 self.lowered_value_runtime_slot_count()
1383 }
1384 fn expression_gradient_slot_count(&self) -> usize {
1385 self.lowered_gradient_runtime_slot_count()
1386 }
1387 fn expression_value_gradient_slot_count(&self) -> usize {
1388 self.lowered_value_gradient_runtime_slot_count()
1389 }
1390
1391 #[doc(hidden)]
1392 pub fn expression_value_gradient_slot_count_public(&self) -> usize {
1393 self.expression_value_gradient_slot_count()
1394 }
1395 #[cfg(test)]
1396 fn specialization_cache_len(&self) -> usize {
1397 self.ir_planning.specialization_cache.read().len()
1398 }
1399 #[cfg(test)]
1400 fn lowered_artifact_cache_len(&self) -> usize {
1401 self.ir_planning.lowered_artifact_cache.read().len()
1402 }
1403 fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
1404 debug_assert!(Self::lowered_artifact_signature_matches(
1405 &specialization.lowered_artifacts,
1406 &specialization.cached_integrals.values,
1407 ));
1408 *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
1409 *self.ir_planning.active_lowered_artifacts.write() =
1410 Some(specialization.lowered_artifacts.clone());
1411 debug_assert_eq!(
1412 self.active_lowered_artifacts()
1413 .as_ref()
1414 .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
1415 Some(true)
1416 );
1417 debug_assert_eq!(
1418 self.lowered_runtime().value_program().scratch_slots(),
1419 specialization
1420 .lowered_artifacts
1421 .lowered_runtime
1422 .value_program()
1423 .scratch_slots()
1424 );
1425 }
1426 fn lower_expression_runtime_artifacts(
1427 expression_ir: &ir::ExpressionIR,
1428 values: &[PrecomputedCachedIntegral],
1429 ) -> LadduResult<LoweredArtifactCacheState> {
1430 let parameter_node_indices = values
1431 .iter()
1432 .map(|value| value.parameter_node_index)
1433 .collect();
1434 let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
1435 let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
1436 let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
1437 let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
1438 expression_ir,
1439 )
1440 .map_err(|error| {
1441 LadduError::Custom(format!(
1442 "Failed to lower expression runtime for specialized IR: {error:?}"
1443 ))
1444 })?;
1445 Ok(LoweredArtifactCacheState {
1446 parameter_node_indices,
1447 mul_node_indices,
1448 lowered_parameter_factors,
1449 residual_runtime,
1450 lowered_runtime,
1451 })
1452 }
1453 fn lowered_artifact_signature_matches(
1454 artifacts: &LoweredArtifactCacheState,
1455 values: &[PrecomputedCachedIntegral],
1456 ) -> bool {
1457 artifacts.parameter_node_indices.len() == values.len()
1458 && artifacts.mul_node_indices.len() == values.len()
1459 && artifacts
1460 .parameter_node_indices
1461 .iter()
1462 .copied()
1463 .eq(values.iter().map(|value| value.parameter_node_index))
1464 && artifacts
1465 .mul_node_indices
1466 .iter()
1467 .copied()
1468 .eq(values.iter().map(|value| value.mul_node_index))
1469 }
1470 fn build_expression_specialization(
1471 &self,
1472 resources: &Resources,
1473 key: CachedIntegralCacheKey,
1474 ) -> LadduResult<ExpressionSpecializationState> {
1475 let ir_compile_start = Instant::now();
1476 let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
1477 let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
1478 let cached_integrals_start = Instant::now();
1479 let values = Self::precompute_cached_integrals_at_load(
1480 &expression_ir,
1481 &self.amplitudes,
1482 &self.amplitude_use_sites,
1483 resources,
1484 &self.dataset,
1485 self.resources.read().n_free_parameters(),
1486 )?;
1487 let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
1488 let execution_sets = expression_ir.normalization_execution_sets().clone();
1489 let active_mask_key = resources.active.clone();
1490 let cached_lowered_artifacts = {
1491 let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
1492 lowered_artifact_cache
1493 .get(&active_mask_key)
1494 .cloned()
1495 .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
1496 };
1497 let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
1498 self.ir_planning
1499 .compile_metrics
1500 .write()
1501 .specialization_lowering_cache_hits += 1;
1502 artifacts
1503 } else {
1504 let lowering_start = Instant::now();
1505 let artifacts = Arc::new(
1506 Self::lower_expression_runtime_artifacts(&expression_ir, &values)
1507 .expect("specialized lowered runtime should build"),
1508 );
1509 let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
1510 self.ir_planning
1511 .lowered_artifact_cache
1512 .write()
1513 .insert(active_mask_key, artifacts.clone());
1514 let mut compile_metrics = self.ir_planning.compile_metrics.write();
1515 compile_metrics.specialization_lowering_cache_misses += 1;
1516 compile_metrics.specialization_lowering_nanos += lowering_nanos;
1517 artifacts
1518 };
1519 let mut compile_metrics = self.ir_planning.compile_metrics.write();
1520 compile_metrics.specialization_cache_misses += 1;
1521 compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
1522 compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
1523 Ok(ExpressionSpecializationState {
1524 cached_integrals: Arc::new(CachedIntegralCacheState {
1525 key,
1526 expression_ir,
1527 values,
1528 execution_sets,
1529 }),
1530 lowered_artifacts,
1531 })
1532 }
1533 fn ensure_expression_specialization(
1534 &self,
1535 resources: &Resources,
1536 ) -> LadduResult<ExpressionSpecializationState> {
1537 let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
1538 if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
1539 if state.key == key {
1540 return Ok(ExpressionSpecializationState {
1541 cached_integrals: state.clone(),
1542 lowered_artifacts: self
1543 .active_lowered_artifacts()
1544 .expect("active lowered artifacts should exist for cached specialization"),
1545 });
1546 }
1547 }
1548 let cached_specialization = {
1549 let specialization_cache = self.ir_planning.specialization_cache.read();
1550 specialization_cache.get(&key).cloned()
1551 };
1552 if let Some(specialization) = cached_specialization {
1553 let restore_start = Instant::now();
1554 self.ir_planning.specialization_metrics.write().cache_hits += 1;
1555 self.install_expression_specialization(&specialization);
1556 *self.ir_planning.specialization_status.write() =
1557 Some(ExpressionSpecializationStatus {
1558 origin: ExpressionSpecializationOrigin::CacheHitRestore,
1559 });
1560 let restore_nanos = restore_start.elapsed().as_nanos() as u64;
1561 let mut compile_metrics = self.ir_planning.compile_metrics.write();
1562 compile_metrics.specialization_cache_hits += 1;
1563 compile_metrics.specialization_cache_restore_nanos += restore_nanos;
1564 return Ok(specialization);
1565 }
1566 let specialization = self.build_expression_specialization(resources, key.clone())?;
1567 self.ir_planning.specialization_metrics.write().cache_misses += 1;
1568 self.ir_planning
1569 .specialization_cache
1570 .write()
1571 .insert(key, specialization.clone());
1572 self.install_expression_specialization(&specialization);
1573 let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
1574 ExpressionSpecializationOrigin::InitialLoad
1575 } else {
1576 ExpressionSpecializationOrigin::CacheMissRebuild
1577 };
1578 *self.ir_planning.specialization_status.write() =
1579 Some(ExpressionSpecializationStatus { origin });
1580 Ok(specialization)
1581 }
1582 fn rebuild_runtime_specializations(&self, resources: &Resources) {
1583 let _ = self.ensure_expression_specialization(resources);
1584 }
1585 fn refresh_runtime_specializations(&self) {
1586 let resources = self.resources.read();
1587 self.rebuild_runtime_specializations(&resources);
1588 }
1589 fn cached_integral_cache_key(
1590 active_mask: Vec<bool>,
1591 dataset: &Dataset,
1592 ) -> CachedIntegralCacheKey {
1593 let (weights_ptr, weights_local_len) = dataset.local_weight_cache_key();
1594 CachedIntegralCacheKey {
1595 active_mask,
1596 n_events_local: dataset.n_events_local(),
1597 weights_local_len,
1598 weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
1599 weights_ptr,
1600 }
1601 }
1602 fn precompute_cached_integrals_at_load(
1603 expression_ir: &ir::ExpressionIR,
1604 amplitudes: &[Box<dyn Amplitude>],
1605 amplitude_use_sites: &[AmplitudeUseSite],
1606 resources: &Resources,
1607 dataset: &Dataset,
1608 n_free_parameters: usize,
1609 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
1610 let descriptors = expression_ir.cached_integral_descriptors();
1611 if descriptors.is_empty() {
1612 return Ok(Vec::new());
1613 }
1614 let execution_sets = expression_ir.normalization_execution_sets();
1615 let seed_parameters = vec![0.0; n_free_parameters];
1616 let parameters = resources.parameter_map.assemble(&seed_parameters)?;
1617 let mut amplitude_values = vec![Complex64::ZERO; amplitude_use_sites.len()];
1618 let mut compute_values = vec![Complex64::ZERO; amplitudes.len()];
1619 let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
1620 let active_set = resources.active_indices();
1621 let cache_active_indices = execution_sets
1622 .cached_cache_amplitudes
1623 .iter()
1624 .copied()
1625 .filter(|index| active_set.binary_search(index).is_ok())
1626 .collect::<Vec<_>>();
1627 let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
1628 for (cache, event) in resources.caches.iter().zip(dataset.weights_local().iter()) {
1629 amplitude_values.fill(Complex64::ZERO);
1630 compute_values.fill(Complex64::ZERO);
1631 let mut computed = vec![false; amplitudes.len()];
1632 for &use_site_idx in &cache_active_indices {
1633 let amp_idx = amplitude_use_sites[use_site_idx].amplitude_index;
1634 if !computed[amp_idx] {
1635 compute_values[amp_idx] = amplitudes[amp_idx].compute(¶meters, cache);
1636 computed[amp_idx] = true;
1637 }
1638 amplitude_values[use_site_idx] = compute_values[amp_idx];
1639 }
1640 expression_ir.evaluate_into(&litude_values, &mut value_slots);
1641 let weight = *event;
1642 for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
1643 weighted_cache_sums[descriptor_index] +=
1644 value_slots[descriptor.cache_node_index] * weight;
1645 }
1646 }
1647 Ok(descriptors
1648 .iter()
1649 .zip(weighted_cache_sums)
1650 .map(
1651 |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
1652 mul_node_index: descriptor.mul_node_index,
1653 parameter_node_index: descriptor.parameter_node_index,
1654 cache_node_index: descriptor.cache_node_index,
1655 coefficient: descriptor.coefficient,
1656 weighted_cache_sum,
1657 },
1658 )
1659 .collect())
1660 }
1661 fn lower_cached_parameter_factors(
1662 expression_ir: &ir::ExpressionIR,
1663 ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
1664 expression_ir
1665 .cached_integral_descriptors()
1666 .iter()
1667 .map(|descriptor| {
1668 lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
1669 expression_ir,
1670 descriptor.parameter_node_index,
1671 )
1672 .ok()
1673 })
1674 .collect()
1675 }
1676 fn lower_residual_runtime(
1677 expression_ir: &ir::ExpressionIR,
1678 descriptors: &[PrecomputedCachedIntegral],
1679 ) -> Option<lowered::LoweredExpressionRuntime> {
1680 let mut zeroed_nodes = vec![false; expression_ir.node_count()];
1681 for descriptor in descriptors {
1682 if descriptor.mul_node_index < zeroed_nodes.len() {
1683 zeroed_nodes[descriptor.mul_node_index] = true;
1684 }
1685 }
1686 lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
1687 expression_ir,
1688 &zeroed_nodes,
1689 )
1690 .ok()
1691 }
1692
1693 pub fn amplitude_value_slot_count(&self) -> usize {
1700 self.amplitude_use_sites.len()
1701 }
1702
1703 #[inline]
1712 pub fn fill_amplitude_values(
1713 &self,
1714 amplitude_values: &mut [Complex64],
1715 active_indices: &[usize],
1716 parameters: &Parameters,
1717 cache: &Cache,
1718 ) {
1719 amplitude_values.fill(Complex64::ZERO);
1720 let mut compute_values = vec![Complex64::ZERO; self.amplitudes.len()];
1721 let mut computed = vec![false; self.amplitudes.len()];
1722 for &use_site_idx in active_indices {
1723 let amp_idx = self.amplitude_use_sites[use_site_idx].amplitude_index;
1724 if !computed[amp_idx] {
1725 compute_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
1726 computed[amp_idx] = true;
1727 }
1728 amplitude_values[use_site_idx] = compute_values[amp_idx];
1729 }
1730 }
1731
1732 #[inline]
1733 fn fill_amplitude_gradients(
1734 &self,
1735 gradient_values: &mut [DVector<Complex64>],
1736 active_mask: &[bool],
1737 parameters: &Parameters,
1738 cache: &Cache,
1739 ) {
1740 let mut compute_gradients = vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1741 let mut computed = vec![false; self.amplitudes.len()];
1742 for ((use_site, active), grad) in self
1743 .amplitude_use_sites
1744 .iter()
1745 .zip(active_mask.iter())
1746 .zip(gradient_values.iter_mut())
1747 {
1748 grad.fill(Complex64::ZERO);
1749 if *active {
1750 let amp_idx = use_site.amplitude_index;
1751 if !computed[amp_idx] {
1752 self.amplitudes[amp_idx].compute_gradient(
1753 parameters,
1754 cache,
1755 &mut compute_gradients[amp_idx],
1756 );
1757 computed[amp_idx] = true;
1758 }
1759 grad.copy_from(&compute_gradients[amp_idx]);
1760 }
1761 }
1762 }
1763
1764 #[inline]
1772 pub fn fill_amplitude_values_and_gradients(
1773 &self,
1774 amplitude_values: &mut [Complex64],
1775 gradient_values: &mut [DVector<Complex64>],
1776 active_indices: &[usize],
1777 active_mask: &[bool],
1778 parameters: &Parameters,
1779 cache: &Cache,
1780 ) {
1781 self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
1782 self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
1783 }
1784
1785 #[cfg(feature = "execution-context-prototype")]
1786 #[inline]
1787 fn evaluate_cache_gradient_with_scratch(
1788 &self,
1789 amplitude_values: &mut [Complex64],
1790 gradient_values: &mut [DVector<Complex64>],
1791 value_slots: &mut [Complex64],
1792 gradient_slots: &mut [DVector<Complex64>],
1793 active_indices: &[usize],
1794 active_mask: &[bool],
1795 parameters: &Parameters,
1796 cache: &Cache,
1797 ) -> DVector<Complex64> {
1798 self.fill_amplitude_values_and_gradients(
1799 amplitude_values,
1800 gradient_values,
1801 active_indices,
1802 active_mask,
1803 parameters,
1804 cache,
1805 );
1806 self.evaluate_expression_gradient_with_scratch(
1807 amplitude_values,
1808 gradient_values,
1809 value_slots,
1810 gradient_slots,
1811 )
1812 }
1813
1814 #[cfg(feature = "execution-context-prototype")]
1815 #[allow(dead_code)]
1816 #[inline]
1817 fn evaluate_cache_value_gradient_with_scratch(
1818 &self,
1819 amplitude_values: &mut [Complex64],
1820 gradient_values: &mut [DVector<Complex64>],
1821 value_slots: &mut [Complex64],
1822 gradient_slots: &mut [DVector<Complex64>],
1823 active_indices: &[usize],
1824 active_mask: &[bool],
1825 parameters: &Parameters,
1826 cache: &Cache,
1827 ) -> (Complex64, DVector<Complex64>) {
1828 self.fill_amplitude_values_and_gradients(
1829 amplitude_values,
1830 gradient_values,
1831 active_indices,
1832 active_mask,
1833 parameters,
1834 cache,
1835 );
1836 self.evaluate_expression_value_gradient_with_scratch(
1837 amplitude_values,
1838 gradient_values,
1839 value_slots,
1840 gradient_slots,
1841 )
1842 }
1843
1844 pub fn expression_slot_count(&self) -> usize {
1845 self.lowered_runtime_slot_count()
1846 }
1847 fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
1848 let amplitude_dependencies = self
1849 .amplitude_use_sites
1850 .iter()
1851 .map(|use_site| {
1852 ir::DependenceClass::from(
1853 self.amplitudes[use_site.amplitude_index].dependence_hint(),
1854 )
1855 })
1856 .collect::<Vec<_>>();
1857 let amplitude_realness = self
1858 .amplitude_use_sites
1859 .iter()
1860 .map(|use_site| self.amplitudes[use_site.amplitude_index].real_valued_hint())
1861 .collect::<Vec<_>>();
1862 ir::compile_expression_ir_with_real_hints(
1863 &self.expression,
1864 active_mask,
1865 &litude_dependencies,
1866 &litude_realness,
1867 )
1868 }
1869 fn lower_expression_runtime_for_active_mask(
1870 &self,
1871 active_mask: &[bool],
1872 ) -> LadduResult<lowered::LoweredExpressionRuntime> {
1873 let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1874 lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
1875 LadduError::Custom(format!(
1876 "Failed to lower active-mask runtime specialization: {error:?}"
1877 ))
1878 })
1879 }
1880 fn ensure_cached_integral_cache_state(
1881 &self,
1882 resources: &Resources,
1883 ) -> LadduResult<Arc<CachedIntegralCacheState>> {
1884 Ok(self
1885 .ensure_expression_specialization(resources)?
1886 .cached_integrals)
1887 }
1888
1889 fn evaluate_expression_runtime_value_with_scratch(
1890 &self,
1891 amplitude_values: &[Complex64],
1892 scratch: &mut [Complex64],
1893 ) -> Complex64 {
1894 let lowered_runtime = self.lowered_runtime();
1895 lowered_runtime
1896 .value_program()
1897 .evaluate_into(amplitude_values, scratch)
1898 }
1899
1900 #[doc(hidden)]
1901 pub fn evaluate_expression_value_with_program_snapshot(
1902 &self,
1903 program_snapshot: &ExpressionValueProgramSnapshot,
1904 amplitude_values: &[Complex64],
1905 scratch: &mut [Complex64],
1906 ) -> Complex64 {
1907 program_snapshot
1908 .lowered_program
1909 .evaluate_into(amplitude_values, scratch)
1910 }
1911
1912 fn evaluate_expression_runtime_gradient_with_scratch(
1913 &self,
1914 amplitude_values: &[Complex64],
1915 gradient_values: &[DVector<Complex64>],
1916 value_scratch: &mut [Complex64],
1917 gradient_scratch: &mut [DVector<Complex64>],
1918 ) -> DVector<Complex64> {
1919 let lowered_runtime = self.lowered_runtime();
1920 lowered_runtime.gradient_program().evaluate_gradient_into(
1921 amplitude_values,
1922 gradient_values,
1923 value_scratch,
1924 gradient_scratch,
1925 )
1926 }
1927
1928 fn evaluate_expression_runtime_value_gradient_with_scratch(
1929 &self,
1930 amplitude_values: &[Complex64],
1931 gradient_values: &[DVector<Complex64>],
1932 value_scratch: &mut [Complex64],
1933 gradient_scratch: &mut [DVector<Complex64>],
1934 ) -> (Complex64, DVector<Complex64>) {
1935 let lowered_runtime = self.lowered_runtime();
1936 lowered_runtime
1937 .value_gradient_program()
1938 .evaluate_value_gradient_into(
1939 amplitude_values,
1940 gradient_values,
1941 value_scratch,
1942 gradient_scratch,
1943 )
1944 }
1945
1946 fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
1947 let lowered_runtime = self.lowered_runtime();
1948 let program = lowered_runtime.value_program();
1949 let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
1950 program.evaluate_into(amplitude_values, &mut scratch)
1951 }
1952
1953 fn evaluate_expression_runtime_gradient(
1954 &self,
1955 amplitude_values: &[Complex64],
1956 gradient_values: &[DVector<Complex64>],
1957 ) -> DVector<Complex64> {
1958 let lowered_runtime = self.lowered_runtime();
1959 let program = lowered_runtime.gradient_program();
1960 let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
1961 let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
1962 let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
1963 program.evaluate_gradient_into_flat(
1964 amplitude_values,
1965 gradient_values,
1966 &mut value_scratch,
1967 &mut gradient_scratch,
1968 grad_dim,
1969 )
1970 }
1971 pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
1973 let resources = self.resources.read();
1974 Ok(self
1975 .ensure_cached_integral_cache_state(&resources)?
1976 .expression_ir
1977 .root_dependence()
1978 .into())
1979 }
1980 pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
1982 let resources = self.resources.read();
1983 Ok(self
1984 .ensure_cached_integral_cache_state(&resources)?
1985 .expression_ir
1986 .node_dependence_annotations()
1987 .iter()
1988 .copied()
1989 .map(Into::into)
1990 .collect())
1991 }
1992 pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
1994 let resources = self.resources.read();
1995 Ok(self
1996 .ensure_cached_integral_cache_state(&resources)?
1997 .expression_ir
1998 .dependence_warnings()
1999 .to_vec())
2000 }
2001 pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
2003 let resources = self.resources.read();
2004 Ok(self
2005 .ensure_cached_integral_cache_state(&resources)?
2006 .expression_ir
2007 .normalization_plan_explain()
2008 .into())
2009 }
2010 pub fn expression_normalization_execution_sets(
2012 &self,
2013 ) -> LadduResult<NormalizationExecutionSetsExplain> {
2014 let resources = self.resources.read();
2015 Ok(self
2016 .ensure_cached_integral_cache_state(&resources)?
2017 .execution_sets
2018 .clone()
2019 .into())
2020 }
2021 pub fn expression_precomputed_cached_integrals(
2023 &self,
2024 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2025 let resources = self.resources.read();
2026 Ok(self
2027 .ensure_cached_integral_cache_state(&resources)?
2028 .values
2029 .clone())
2030 }
2031 pub fn expression_precomputed_cached_integral_gradient_terms(
2036 &self,
2037 parameters: &[f64],
2038 ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
2039 let resources = self.resources.read();
2040 let state = self.ensure_cached_integral_cache_state(&resources)?;
2041 if state.values.is_empty() {
2042 return Ok(Vec::new());
2043 }
2044
2045 let Some(cache) = resources.caches.first() else {
2046 return Ok(state
2047 .values
2048 .iter()
2049 .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
2050 mul_node_index: descriptor.mul_node_index,
2051 parameter_node_index: descriptor.parameter_node_index,
2052 cache_node_index: descriptor.cache_node_index,
2053 coefficient: descriptor.coefficient,
2054 weighted_gradient: DVector::zeros(parameters.len()),
2055 })
2056 .collect());
2057 };
2058
2059 let parameter_values = resources.parameter_map.assemble(parameters)?;
2060 let mut amplitude_values = vec![Complex64::ZERO; self.amplitude_use_sites.len()];
2061 self.fill_amplitude_values(
2062 &mut amplitude_values,
2063 resources.active_indices(),
2064 ¶meter_values,
2065 cache,
2066 );
2067 let mut amplitude_gradients = (0..self.amplitude_use_sites.len())
2068 .map(|_| DVector::zeros(parameters.len()))
2069 .collect::<Vec<_>>();
2070 self.fill_amplitude_gradients(
2071 &mut amplitude_gradients,
2072 &resources.active,
2073 ¶meter_values,
2074 cache,
2075 );
2076 let lowered_artifacts = self.active_lowered_artifacts();
2077 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2078 let mut gradient_slots = (0..state.expression_ir.node_count())
2079 .map(|_| DVector::zeros(parameters.len()))
2080 .collect::<Vec<_>>();
2081 let max_lowered_slots = lowered_artifacts
2082 .as_ref()
2083 .map(|artifacts| {
2084 artifacts
2085 .lowered_parameter_factors
2086 .iter()
2087 .filter_map(|runtime| {
2088 runtime
2089 .as_ref()
2090 .and_then(|runtime| runtime.gradient_program())
2091 .map(|program| program.scratch_slots())
2092 })
2093 .max()
2094 .unwrap_or(0)
2095 })
2096 .unwrap_or(0);
2097 let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
2098 let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
2099 let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
2100 artifacts.lowered_parameter_factors.len() == state.values.len()
2101 && artifacts.lowered_parameter_factors.iter().all(|runtime| {
2102 runtime
2103 .as_ref()
2104 .and_then(|runtime| runtime.gradient_program())
2105 .is_some()
2106 })
2107 });
2108
2109 if !use_lowered {
2110 let _ = state.expression_ir.evaluate_gradient_into(
2111 &litude_values,
2112 &litude_gradients,
2113 &mut value_slots,
2114 &mut gradient_slots,
2115 );
2116 }
2117
2118 if use_lowered {
2119 let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
2120 Ok(state
2121 .values
2122 .iter()
2123 .cloned()
2124 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2125 .map(|(descriptor, runtime)| {
2126 let parameter_gradient = runtime
2127 .as_ref()
2128 .and_then(|runtime| runtime.gradient_program())
2129 .map(|program| {
2130 program.evaluate_gradient_into(
2131 &litude_values,
2132 &litude_gradients,
2133 &mut lowered_value_slots[..program.scratch_slots()],
2134 &mut lowered_gradient_slots[..program.scratch_slots()],
2135 )
2136 })
2137 .unwrap_or_else(|| DVector::zeros(parameters.len()));
2138 let weighted_gradient = parameter_gradient.map(|value| {
2139 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2140 });
2141 PrecomputedCachedIntegralGradientTerm {
2142 mul_node_index: descriptor.mul_node_index,
2143 parameter_node_index: descriptor.parameter_node_index,
2144 cache_node_index: descriptor.cache_node_index,
2145 coefficient: descriptor.coefficient,
2146 weighted_gradient,
2147 }
2148 })
2149 .collect())
2150 } else {
2151 Ok(state
2152 .values
2153 .iter()
2154 .map(|descriptor| {
2155 let parameter_gradient = gradient_slots
2156 .get(descriptor.parameter_node_index)
2157 .cloned()
2158 .unwrap_or_else(|| DVector::zeros(parameters.len()));
2159 let weighted_gradient = parameter_gradient.map(|value| {
2160 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2161 });
2162 PrecomputedCachedIntegralGradientTerm {
2163 mul_node_index: descriptor.mul_node_index,
2164 parameter_node_index: descriptor.parameter_node_index,
2165 cache_node_index: descriptor.cache_node_index,
2166 coefficient: descriptor.coefficient,
2167 weighted_gradient,
2168 }
2169 })
2170 .collect())
2171 }
2172 }
2173 fn evaluate_cached_weighted_value_sum_ir(
2174 &self,
2175 state: &CachedIntegralCacheState,
2176 amplitude_values: &[Complex64],
2177 ) -> f64 {
2178 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2179 let _ = state
2180 .expression_ir
2181 .evaluate_into(amplitude_values, &mut value_slots);
2182 state
2183 .values
2184 .iter()
2185 .map(|descriptor| {
2186 let parameter_factor = value_slots[descriptor.parameter_node_index];
2187 (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2188 .re
2189 })
2190 .sum()
2191 }
2192 fn evaluate_cached_weighted_value_sum_lowered(
2193 &self,
2194 state: &CachedIntegralCacheState,
2195 lowered_artifacts: &LoweredArtifactCacheState,
2196 amplitude_values: &[Complex64],
2197 ) -> Option<f64> {
2198 let max_slots = lowered_artifacts
2199 .lowered_parameter_factors
2200 .iter()
2201 .filter_map(|runtime| {
2202 runtime
2203 .as_ref()
2204 .and_then(|runtime| runtime.value_program())
2205 .map(|program| program.scratch_slots())
2206 })
2207 .max()
2208 .unwrap_or(0);
2209 let mut value_slots = vec![Complex64::ZERO; max_slots];
2210 let mut total = 0.0;
2211 for (descriptor, runtime) in state
2212 .values
2213 .iter()
2214 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2215 {
2216 let parameter_factor = runtime
2217 .as_ref()
2218 .and_then(|runtime| runtime.value_program())
2219 .map(|program| {
2220 program.evaluate_into(
2221 amplitude_values,
2222 &mut value_slots[..program.scratch_slots()],
2223 )
2224 })?;
2225 total +=
2226 (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2227 .re;
2228 }
2229 Some(total)
2230 }
2231 fn evaluate_cached_weighted_gradient_sum_ir(
2232 &self,
2233 state: &CachedIntegralCacheState,
2234 amplitude_values: &[Complex64],
2235 amplitude_gradients: &[DVector<Complex64>],
2236 grad_dim: usize,
2237 ) -> DVector<f64> {
2238 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2239 let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2240 let _ = state.expression_ir.evaluate_gradient_into(
2241 amplitude_values,
2242 amplitude_gradients,
2243 &mut value_slots,
2244 &mut gradient_slots,
2245 );
2246 state
2247 .values
2248 .iter()
2249 .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
2250 let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
2251 let coefficient = descriptor.coefficient as f64;
2252 for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
2253 *accum_item +=
2254 (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2255 }
2256 accum
2257 })
2258 }
2259 fn evaluate_cached_weighted_gradient_sum_lowered(
2260 &self,
2261 state: &CachedIntegralCacheState,
2262 lowered_artifacts: &LoweredArtifactCacheState,
2263 amplitude_values: &[Complex64],
2264 amplitude_gradients: &[DVector<Complex64>],
2265 grad_dim: usize,
2266 ) -> Option<DVector<f64>> {
2267 let max_value_slots = lowered_artifacts
2268 .lowered_parameter_factors
2269 .iter()
2270 .filter_map(|runtime| {
2271 runtime
2272 .as_ref()
2273 .and_then(|runtime| runtime.gradient_program())
2274 .map(|program| program.scratch_slots())
2275 })
2276 .max()
2277 .unwrap_or(0);
2278 let mut value_slots = vec![Complex64::ZERO; max_value_slots];
2279 let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
2280 let mut total = DVector::zeros(grad_dim);
2281 for (descriptor, runtime) in state
2282 .values
2283 .iter()
2284 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2285 {
2286 let parameter_gradient = runtime
2287 .as_ref()
2288 .and_then(|runtime| runtime.gradient_program())
2289 .map(|program| {
2290 program.evaluate_gradient_into_flat(
2291 amplitude_values,
2292 amplitude_gradients,
2293 &mut value_slots[..program.scratch_slots()],
2294 &mut gradient_slots[..program.scratch_slots() * grad_dim],
2295 grad_dim,
2296 )
2297 })?;
2298 let coefficient = descriptor.coefficient as f64;
2299 for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
2300 *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2301 }
2302 }
2303 Some(total)
2304 }
2305 fn evaluate_residual_value_ir(
2306 &self,
2307 state: &CachedIntegralCacheState,
2308 amplitude_values: &[Complex64],
2309 ) -> Complex64 {
2310 let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2311 for descriptor in &state.values {
2312 if descriptor.mul_node_index < zeroed_nodes.len() {
2313 zeroed_nodes[descriptor.mul_node_index] = true;
2314 }
2315 }
2316 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2317 state.expression_ir.evaluate_into_with_zeroed_nodes(
2318 amplitude_values,
2319 &mut value_slots,
2320 &zeroed_nodes,
2321 )
2322 }
2323 fn evaluate_residual_gradient_ir(
2324 &self,
2325 state: &CachedIntegralCacheState,
2326 amplitude_values: &[Complex64],
2327 amplitude_gradients: &[DVector<Complex64>],
2328 grad_dim: usize,
2329 ) -> DVector<Complex64> {
2330 let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2331 for descriptor in &state.values {
2332 if descriptor.mul_node_index < zeroed_nodes.len() {
2333 zeroed_nodes[descriptor.mul_node_index] = true;
2334 }
2335 }
2336 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2337 let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2338 state
2339 .expression_ir
2340 .evaluate_gradient_into_with_zeroed_nodes(
2341 amplitude_values,
2342 amplitude_gradients,
2343 &mut value_slots,
2344 &mut gradient_slots,
2345 &zeroed_nodes,
2346 )
2347 }
2348
2349 fn evaluate_weighted_value_sum_local_components(
2350 &self,
2351 parameters: &[f64],
2352 ) -> LadduResult<(f64, f64)> {
2353 let resources = self.resources.read();
2354 let parameters = resources.parameter_map.assemble(parameters)?;
2355 let amplitude_len = self.amplitude_use_sites.len();
2356 let state = self.ensure_cached_integral_cache_state(&resources)?;
2357 let lowered_artifacts = self.active_lowered_artifacts();
2358 let residual_value_slot_count = lowered_artifacts
2359 .as_ref()
2360 .and_then(|artifacts| {
2361 artifacts
2362 .residual_runtime
2363 .as_ref()
2364 .map(|runtime| runtime.value_program())
2365 .map(|program| program.scratch_slots())
2366 })
2367 .unwrap_or_else(|| self.expression_slot_count());
2368 let residual_value_program = lowered_artifacts
2369 .as_ref()
2370 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2371 .map(|runtime| runtime.value_program());
2372 let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
2373 let residual_active_indices = &state.execution_sets.residual_amplitudes;
2374 debug_assert!(cached_parameter_indices.iter().all(|&index| resources
2375 .active
2376 .get(index)
2377 .copied()
2378 .unwrap_or(false)));
2379 debug_assert!(residual_active_indices.iter().all(|&index| resources
2380 .active
2381 .get(index)
2382 .copied()
2383 .unwrap_or(false)));
2384 let cached_value_sum = {
2385 if let Some(cache) = resources.caches.first() {
2386 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2387 self.fill_amplitude_values(
2388 &mut amplitude_values,
2389 cached_parameter_indices,
2390 ¶meters,
2391 cache,
2392 );
2393 lowered_artifacts
2394 .as_ref()
2395 .and_then(|artifacts| {
2396 self.evaluate_cached_weighted_value_sum_lowered(
2397 &state,
2398 artifacts,
2399 &litude_values,
2400 )
2401 })
2402 .unwrap_or_else(|| {
2403 self.evaluate_cached_weighted_value_sum_ir(&state, &litude_values)
2404 })
2405 } else {
2406 0.0
2407 }
2408 };
2409
2410 #[cfg(feature = "rayon")]
2411 let residual_sum: f64 = {
2412 resources
2413 .caches
2414 .par_iter()
2415 .zip(self.dataset.weights_local().par_iter())
2416 .map_init(
2417 || {
2418 (
2419 vec![Complex64::ZERO; amplitude_len],
2420 vec![Complex64::ZERO; residual_value_slot_count],
2421 )
2422 },
2423 |(amplitude_values, value_slots), (cache, event)| {
2424 self.fill_amplitude_values(
2425 amplitude_values,
2426 residual_active_indices,
2427 ¶meters,
2428 cache,
2429 );
2430 {
2431 let value = residual_value_program
2432 .as_ref()
2433 .map(|program| {
2434 program.evaluate_into(
2435 amplitude_values,
2436 &mut value_slots[..program.scratch_slots()],
2437 )
2438 })
2439 .unwrap_or_else(|| {
2440 self.evaluate_residual_value_ir(&state, amplitude_values)
2441 });
2442 *event * value.re
2443 }
2444 },
2445 )
2446 .sum()
2447 };
2448
2449 #[cfg(not(feature = "rayon"))]
2450 let residual_sum: f64 = {
2451 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2452 let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
2453 resources
2454 .caches
2455 .iter()
2456 .zip(self.dataset.weights_local().iter())
2457 .map(|(cache, event)| {
2458 self.fill_amplitude_values(
2459 &mut amplitude_values,
2460 &residual_active_indices,
2461 ¶meters,
2462 cache,
2463 );
2464 {
2465 let value = residual_value_program
2466 .as_ref()
2467 .map(|program| {
2468 program.evaluate_into(
2469 &litude_values,
2470 &mut value_slots[..program.scratch_slots()],
2471 )
2472 })
2473 .unwrap_or_else(|| {
2474 self.evaluate_residual_value_ir(&state, &litude_values)
2475 });
2476 *event * value.re
2477 }
2478 })
2479 .sum()
2480 };
2481 Ok((residual_sum, cached_value_sum))
2482 }
2483
2484 pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
2488 let (residual_sum, cached_value_sum) =
2489 self.evaluate_weighted_value_sum_local_components(parameters)?;
2490 Ok(residual_sum + cached_value_sum)
2491 }
2492
2493 #[cfg(feature = "mpi")]
2494 pub fn evaluate_weighted_value_sum_mpi(
2498 &self,
2499 parameters: &[f64],
2500 world: &SimpleCommunicator,
2501 ) -> LadduResult<f64> {
2502 let (residual_sum_local, cached_value_sum_local) =
2503 self.evaluate_weighted_value_sum_local_components(parameters)?;
2504 let mut residual_sum = 0.0;
2505 world.all_reduce_into(
2506 &residual_sum_local,
2507 &mut residual_sum,
2508 mpi::collective::SystemOperation::sum(),
2509 );
2510 let mut cached_value_sum = 0.0;
2511 world.all_reduce_into(
2512 &cached_value_sum_local,
2513 &mut cached_value_sum,
2514 mpi::collective::SystemOperation::sum(),
2515 );
2516 Ok(residual_sum + cached_value_sum)
2517 }
2518
2519 fn evaluate_weighted_gradient_sum_local_components(
2523 &self,
2524 parameters: &[f64],
2525 ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
2526 let resources = self.resources.read();
2527 let parameters = resources.parameter_map.assemble(parameters)?;
2528 let amplitude_len = self.amplitude_use_sites.len();
2529 let grad_dim = parameters.len();
2530 let state = self.ensure_cached_integral_cache_state(&resources)?;
2531 let lowered_artifacts = self.active_lowered_artifacts();
2532 let active_index_set = resources.active_indices();
2533 let cached_parameter_indices = state
2534 .execution_sets
2535 .cached_parameter_amplitudes
2536 .iter()
2537 .copied()
2538 .filter(|index| active_index_set.binary_search(index).is_ok())
2539 .collect::<Vec<_>>();
2540 let residual_active_indices = state
2541 .execution_sets
2542 .residual_amplitudes
2543 .iter()
2544 .copied()
2545 .filter(|index| active_index_set.binary_search(index).is_ok())
2546 .collect::<Vec<_>>();
2547 let mut cached_parameter_mask = vec![false; amplitude_len];
2548 for &index in &cached_parameter_indices {
2549 cached_parameter_mask[index] = true;
2550 }
2551 let mut residual_active_mask = vec![false; amplitude_len];
2552 for &index in &residual_active_indices {
2553 residual_active_mask[index] = true;
2554 }
2555 let residual_gradient_program = lowered_artifacts
2556 .as_ref()
2557 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2558 .map(|runtime| runtime.gradient_program());
2559 let residual_gradient_slot_count = residual_gradient_program
2560 .as_ref()
2561 .map(|program| program.scratch_slots())
2562 .unwrap_or_else(|| state.expression_ir.node_count());
2563 let cached_term_sum = {
2564 if let Some(cache) = resources.caches.first() {
2565 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2566 self.fill_amplitude_values(
2567 &mut amplitude_values,
2568 &cached_parameter_indices,
2569 ¶meters,
2570 cache,
2571 );
2572 let mut amplitude_gradients = (0..amplitude_len)
2573 .map(|_| DVector::zeros(grad_dim))
2574 .collect::<Vec<_>>();
2575 self.fill_amplitude_gradients(
2576 &mut amplitude_gradients,
2577 &cached_parameter_mask,
2578 ¶meters,
2579 cache,
2580 );
2581 lowered_artifacts
2582 .as_ref()
2583 .and_then(|artifacts| {
2584 self.evaluate_cached_weighted_gradient_sum_lowered(
2585 &state,
2586 artifacts,
2587 &litude_values,
2588 &litude_gradients,
2589 grad_dim,
2590 )
2591 })
2592 .unwrap_or_else(|| {
2593 self.evaluate_cached_weighted_gradient_sum_ir(
2594 &state,
2595 &litude_values,
2596 &litude_gradients,
2597 grad_dim,
2598 )
2599 })
2600 } else {
2601 DVector::zeros(grad_dim)
2602 }
2603 };
2604
2605 #[cfg(feature = "rayon")]
2606 let residual_sum = {
2607 resources
2608 .caches
2609 .par_iter()
2610 .zip(self.dataset.weights_local().par_iter())
2611 .map_init(
2612 || {
2613 (
2614 vec![Complex64::ZERO; amplitude_len],
2615 vec![DVector::zeros(grad_dim); amplitude_len],
2616 vec![Complex64::ZERO; residual_gradient_slot_count],
2617 vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
2618 )
2619 },
2620 |(amplitude_values, gradient_values, value_slots, gradient_slots),
2621 (cache, event)| {
2622 self.fill_amplitude_values_and_gradients(
2623 amplitude_values,
2624 gradient_values,
2625 &residual_active_indices,
2626 &residual_active_mask,
2627 ¶meters,
2628 cache,
2629 );
2630 let gradient = residual_gradient_program
2631 .as_ref()
2632 .map(|program| {
2633 program.evaluate_gradient_into_flat(
2634 amplitude_values,
2635 gradient_values,
2636 value_slots,
2637 gradient_slots,
2638 grad_dim,
2639 )
2640 })
2641 .unwrap_or_else(|| {
2642 self.evaluate_residual_gradient_ir(
2643 &state,
2644 amplitude_values,
2645 gradient_values,
2646 grad_dim,
2647 )
2648 });
2649 gradient.map(|value| value.re).scale(*event)
2650 },
2651 )
2652 .reduce(
2653 || DVector::zeros(grad_dim),
2654 |mut accum, value| {
2655 accum += value;
2656 accum
2657 },
2658 )
2659 };
2660
2661 #[cfg(not(feature = "rayon"))]
2662 let residual_sum = {
2663 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2664 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
2665 let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
2666 let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
2667 resources
2668 .caches
2669 .iter()
2670 .zip(self.dataset.weights_local().iter())
2671 .map(|(cache, event)| {
2672 self.fill_amplitude_values_and_gradients(
2673 &mut amplitude_values,
2674 &mut gradient_values,
2675 &residual_active_indices,
2676 &residual_active_mask,
2677 ¶meters,
2678 cache,
2679 );
2680 let gradient = residual_gradient_program
2681 .as_ref()
2682 .map(|program| {
2683 program.evaluate_gradient_into_flat(
2684 &litude_values,
2685 &gradient_values,
2686 &mut value_slots,
2687 &mut gradient_slots,
2688 grad_dim,
2689 )
2690 })
2691 .unwrap_or_else(|| {
2692 self.evaluate_residual_gradient_ir(
2693 &state,
2694 &litude_values,
2695 &gradient_values,
2696 grad_dim,
2697 )
2698 });
2699 gradient.map(|value| value.re).scale(*event)
2700 })
2701 .sum()
2702 };
2703 Ok((residual_sum, cached_term_sum))
2704 }
2705
2706 pub fn evaluate_weighted_gradient_sum_local(
2710 &self,
2711 parameters: &[f64],
2712 ) -> LadduResult<DVector<f64>> {
2713 let (residual_sum, cached_term_sum) =
2714 self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2715 Ok(residual_sum + cached_term_sum)
2716 }
2717
2718 #[cfg(feature = "mpi")]
2719 pub fn evaluate_weighted_gradient_sum_mpi(
2723 &self,
2724 parameters: &[f64],
2725 world: &SimpleCommunicator,
2726 ) -> LadduResult<DVector<f64>> {
2727 let (residual_sum_local, cached_term_sum_local) =
2728 self.evaluate_weighted_gradient_sum_local_components(parameters)?;
2729 let mut residual_sum = vec![0.0; residual_sum_local.len()];
2730 world.all_reduce_into(
2731 residual_sum_local.as_slice(),
2732 &mut residual_sum,
2733 mpi::collective::SystemOperation::sum(),
2734 );
2735 let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
2736 world.all_reduce_into(
2737 cached_term_sum_local.as_slice(),
2738 &mut cached_term_sum,
2739 mpi::collective::SystemOperation::sum(),
2740 );
2741 let mut total = DVector::from_vec(residual_sum);
2742 total += DVector::from_vec(cached_term_sum);
2743 Ok(total)
2744 }
2745
2746 pub fn evaluate_expression_value_with_scratch(
2747 &self,
2748 amplitude_values: &[Complex64],
2749 scratch: &mut [Complex64],
2750 ) -> Complex64 {
2751 self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
2752 }
2753
2754 pub fn evaluate_expression_gradient_with_scratch(
2755 &self,
2756 amplitude_values: &[Complex64],
2757 gradient_values: &[DVector<Complex64>],
2758 value_scratch: &mut [Complex64],
2759 gradient_scratch: &mut [DVector<Complex64>],
2760 ) -> DVector<Complex64> {
2761 self.evaluate_expression_runtime_gradient_with_scratch(
2762 amplitude_values,
2763 gradient_values,
2764 value_scratch,
2765 gradient_scratch,
2766 )
2767 }
2768
2769 pub fn evaluate_expression_value_gradient_with_scratch(
2770 &self,
2771 amplitude_values: &[Complex64],
2772 gradient_values: &[DVector<Complex64>],
2773 value_scratch: &mut [Complex64],
2774 gradient_scratch: &mut [DVector<Complex64>],
2775 ) -> (Complex64, DVector<Complex64>) {
2776 self.evaluate_expression_runtime_value_gradient_with_scratch(
2777 amplitude_values,
2778 gradient_values,
2779 value_scratch,
2780 gradient_scratch,
2781 )
2782 }
2783
2784 pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
2785 self.evaluate_expression_runtime_value(amplitude_values)
2786 }
2787
2788 pub fn evaluate_expression_gradient(
2789 &self,
2790 amplitude_values: &[Complex64],
2791 gradient_values: &[DVector<Complex64>],
2792 ) -> DVector<Complex64> {
2793 self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
2794 }
2795
2796 pub fn parameters(&self) -> ParameterMap {
2798 self.resources.read().parameters()
2799 }
2800
2801 pub fn n_free(&self) -> usize {
2803 self.resources.read().n_free_parameters()
2804 }
2805
2806 pub fn n_fixed(&self) -> usize {
2808 self.resources.read().n_fixed_parameters()
2809 }
2810
2811 pub fn n_parameters(&self) -> usize {
2813 self.resources.read().n_parameters()
2814 }
2815
2816 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
2817 self.resources.read().fix_parameter(name, value)
2818 }
2819
2820 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
2821 self.resources.read().free_parameter(name)
2822 }
2823
2824 pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
2825 self.resources.write().rename_parameter(old, new)
2826 }
2827
2828 pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
2829 self.resources.write().rename_parameters(mapping)
2830 }
2831
2832 pub fn activate<T: AsRef<str>>(&self, name: T) {
2834 self.resources.write().activate(name);
2835 self.refresh_runtime_specializations();
2836 }
2837 pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2839 self.resources.write().activate_strict(name)?;
2840 self.refresh_runtime_specializations();
2841 Ok(())
2842 }
2843
2844 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
2846 self.resources.write().activate_many(names);
2847 self.refresh_runtime_specializations();
2848 }
2849 pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2851 self.resources.write().activate_many_strict(names)?;
2852 self.refresh_runtime_specializations();
2853 Ok(())
2854 }
2855
2856 pub fn activate_all(&self) {
2858 self.resources.write().activate_all();
2859 self.refresh_runtime_specializations();
2860 }
2861
2862 pub fn deactivate<T: AsRef<str>>(&self, name: T) {
2864 self.resources.write().deactivate(name);
2865 self.refresh_runtime_specializations();
2866 }
2867
2868 pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2870 self.resources.write().deactivate_strict(name)?;
2871 self.refresh_runtime_specializations();
2872 Ok(())
2873 }
2874
2875 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
2877 self.resources.write().deactivate_many(names);
2878 self.refresh_runtime_specializations();
2879 }
2880 pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2882 self.resources.write().deactivate_many_strict(names)?;
2883 self.refresh_runtime_specializations();
2884 Ok(())
2885 }
2886
2887 pub fn deactivate_all(&self) {
2889 self.resources.write().deactivate_all();
2890 self.refresh_runtime_specializations();
2891 }
2892
2893 pub fn isolate<T: AsRef<str>>(&self, name: T) {
2895 self.resources.write().isolate(name);
2896 self.refresh_runtime_specializations();
2897 }
2898
2899 pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
2901 self.resources.write().isolate_strict(name)?;
2902 self.refresh_runtime_specializations();
2903 Ok(())
2904 }
2905
2906 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
2908 self.resources.write().isolate_many(names);
2909 self.refresh_runtime_specializations();
2910 }
2911
2912 pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
2914 self.resources.write().isolate_many_strict(names)?;
2915 self.refresh_runtime_specializations();
2916 Ok(())
2917 }
2918
2919 pub fn active_mask(&self) -> Vec<bool> {
2921 self.resources.read().active.clone()
2922 }
2923
2924 pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
2926 let resources = {
2927 let mut resources = self.resources.write();
2928 if mask.len() != resources.active.len() {
2929 return Err(LadduError::LengthMismatch {
2930 context: "active amplitude mask".to_string(),
2931 expected: resources.active.len(),
2932 actual: mask.len(),
2933 });
2934 }
2935 resources.apply_active_mask(mask)?;
2936 resources.clone()
2937 };
2938 self.rebuild_runtime_specializations(&resources);
2939 Ok(())
2940 }
2941
2942 pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
2950 let resources = self.resources.read();
2951 let parameters = resources.parameter_map.assemble(parameters)?;
2952 let amplitude_len = self.amplitude_use_sites.len();
2953 let active_indices = resources.active_indices().to_vec();
2954 let slot_count = self.expression_value_slot_count();
2955 let program_snapshot = self.expression_value_program_snapshot();
2956 #[cfg(feature = "rayon")]
2957 {
2958 Ok(resources
2959 .caches
2960 .par_iter()
2961 .map_init(
2962 || {
2963 (
2964 vec![Complex64::ZERO; amplitude_len],
2965 vec![Complex64::ZERO; slot_count],
2966 )
2967 },
2968 |(amplitude_values, expr_slots), cache| {
2969 self.fill_amplitude_values(
2970 amplitude_values,
2971 &active_indices,
2972 ¶meters,
2973 cache,
2974 );
2975 self.evaluate_expression_value_with_program_snapshot(
2976 &program_snapshot,
2977 amplitude_values,
2978 expr_slots,
2979 )
2980 },
2981 )
2982 .collect())
2983 }
2984 #[cfg(not(feature = "rayon"))]
2985 {
2986 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
2987 let mut expr_slots = vec![Complex64::ZERO; slot_count];
2988 Ok(resources
2989 .caches
2990 .iter()
2991 .map(|cache| {
2992 self.fill_amplitude_values(
2993 &mut amplitude_values,
2994 &active_indices,
2995 ¶meters,
2996 cache,
2997 );
2998 self.evaluate_expression_value_with_program_snapshot(
2999 &program_snapshot,
3000 &litude_values,
3001 &mut expr_slots,
3002 )
3003 })
3004 .collect())
3005 }
3006 }
3007
3008 pub fn evaluate_local_with_active_mask(
3010 &self,
3011 parameters: &[f64],
3012 active_mask: &[bool],
3013 ) -> LadduResult<Vec<Complex64>> {
3014 let resources = self.resources.read();
3015 if active_mask.len() != resources.active.len() {
3016 return Err(LadduError::LengthMismatch {
3017 context: "active amplitude mask".to_string(),
3018 expected: resources.active.len(),
3019 actual: active_mask.len(),
3020 });
3021 }
3022 let parameters = resources.parameter_map.assemble(parameters)?;
3023 let amplitude_len = self.amplitude_use_sites.len();
3024 let active_indices = active_mask
3025 .iter()
3026 .enumerate()
3027 .filter_map(|(index, &active)| if active { Some(index) } else { None })
3028 .collect::<Vec<_>>();
3029 let program_snapshot =
3030 self.expression_value_program_snapshot_for_active_mask(active_mask)?;
3031 let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
3032 #[cfg(feature = "rayon")]
3033 {
3034 Ok(resources
3035 .caches
3036 .par_iter()
3037 .map_init(
3038 || {
3039 (
3040 vec![Complex64::ZERO; amplitude_len],
3041 vec![Complex64::ZERO; slot_count],
3042 )
3043 },
3044 |(amplitude_values, expr_slots), cache| {
3045 self.fill_amplitude_values(
3046 amplitude_values,
3047 &active_indices,
3048 ¶meters,
3049 cache,
3050 );
3051 self.evaluate_expression_value_with_program_snapshot(
3052 &program_snapshot,
3053 amplitude_values,
3054 expr_slots,
3055 )
3056 },
3057 )
3058 .collect())
3059 }
3060 #[cfg(not(feature = "rayon"))]
3061 {
3062 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3063 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3064 Ok(resources
3065 .caches
3066 .iter()
3067 .map(|cache| {
3068 self.fill_amplitude_values(
3069 &mut amplitude_values,
3070 &active_indices,
3071 ¶meters,
3072 cache,
3073 );
3074 self.evaluate_expression_value_with_program_snapshot(
3075 &program_snapshot,
3076 &litude_values,
3077 &mut expr_slots,
3078 )
3079 })
3080 .collect())
3081 }
3082 }
3083
3084 #[cfg(feature = "execution-context-prototype")]
3086 pub fn evaluate_local_with_ctx(
3087 &self,
3088 parameters: &[f64],
3089 execution_context: &ExecutionContext,
3090 ) -> Vec<Complex64> {
3091 let resources = self.resources.read();
3092 let parameters = resources
3093 .parameter_map
3094 .assemble(parameters)
3095 .expect("parameter slice must match evaluator resources");
3096 let amplitude_len = self.amplitude_use_sites.len();
3097 let active_indices = resources.active_indices().to_vec();
3098 let slot_count = self.expression_value_slot_count();
3099 let program_snapshot = self.expression_value_program_snapshot();
3100 #[cfg(feature = "rayon")]
3101 {
3102 if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3103 return execution_context.install(|| {
3104 resources
3105 .caches
3106 .par_iter()
3107 .map_init(
3108 || {
3109 (
3110 vec![Complex64::ZERO; amplitude_len],
3111 vec![Complex64::ZERO; slot_count],
3112 )
3113 },
3114 |(amplitude_values, expr_slots), cache| {
3115 self.fill_amplitude_values(
3116 amplitude_values,
3117 &active_indices,
3118 ¶meters,
3119 cache,
3120 );
3121 self.evaluate_expression_value_with_program_snapshot(
3122 &program_snapshot,
3123 amplitude_values,
3124 expr_slots,
3125 )
3126 },
3127 )
3128 .collect()
3129 });
3130 }
3131 }
3132 execution_context.with_scratch(|scratch| {
3133 let (amplitude_values, expr_slots) =
3134 scratch.reserve_value_workspaces(amplitude_len, slot_count);
3135 resources
3136 .caches
3137 .iter()
3138 .map(|cache| {
3139 self.fill_amplitude_values(
3140 amplitude_values,
3141 &active_indices,
3142 ¶meters,
3143 cache,
3144 );
3145 self.evaluate_expression_value_with_program_snapshot(
3146 &program_snapshot,
3147 amplitude_values,
3148 expr_slots,
3149 )
3150 })
3151 .collect()
3152 })
3153 }
3154
3155 #[cfg(feature = "mpi")]
3163 fn evaluate_mpi(
3164 &self,
3165 parameters: &[f64],
3166 world: &SimpleCommunicator,
3167 ) -> LadduResult<Vec<Complex64>> {
3168 let local_evaluation = self.evaluate_local(parameters)?;
3169 let n_events = self.dataset.n_events();
3170 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3171 let (counts, displs) = world.get_counts_displs(n_events);
3172 {
3173 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3176 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3177 }
3178 Ok(buffer)
3179 }
3180
/// Evaluates the expression on this rank's local events through `execution_context` and
/// all-gathers the per-event values into a buffer covering all `n_events` events.
#[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
fn evaluate_mpi_with_ctx(
    &self,
    parameters: &[f64],
    world: &SimpleCommunicator,
    execution_context: &ExecutionContext,
) -> Vec<Complex64> {
    let local_values = self.evaluate_local_with_ctx(parameters, execution_context);
    let total_events = self.dataset.n_events();
    let (counts, displs) = world.get_counts_displs(total_events);
    let mut gathered = vec![Complex64::ZERO; total_events];
    {
        // The partition mutably borrows `gathered`; scope it so the buffer can be returned.
        let mut partition = PartitionMut::new(&mut gathered, counts, displs);
        world.all_gather_varcount_into(&local_values, &mut partition);
    }
    gathered
}
3200
3201 pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3204 #[cfg(feature = "mpi")]
3205 {
3206 if let Some(world) = crate::mpi::get_world() {
3207 return self.evaluate_mpi(parameters, &world);
3208 }
3209 }
3210 self.evaluate_local(parameters)
3211 }
3212
/// Evaluates the expression for every event, scheduling the work through `execution_context`.
///
/// With the `mpi` feature and an active MPI world, the per-rank results are gathered across
/// ranks; otherwise only the local dataset is evaluated.
#[cfg(feature = "execution-context-prototype")]
pub fn evaluate_with_ctx(
    &self,
    parameters: &[f64],
    execution_context: &ExecutionContext,
) -> Vec<Complex64> {
    #[cfg(feature = "mpi")]
    {
        let maybe_world = crate::mpi::get_world();
        if let Some(world) = maybe_world {
            return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
        }
    }
    self.evaluate_local_with_ctx(parameters, execution_context)
}
3232
3233 pub fn evaluate_batch_local(
3236 &self,
3237 parameters: &[f64],
3238 indices: &[usize],
3239 ) -> LadduResult<Vec<Complex64>> {
3240 let resources = self.resources.read();
3241 let parameters = resources.parameter_map.assemble(parameters)?;
3242 let amplitude_len = self.amplitude_use_sites.len();
3243 let active_indices = resources.active_indices().to_vec();
3244 let slot_count = self.expression_value_slot_count();
3245 let program_snapshot = self.expression_value_program_snapshot();
3246 #[cfg(feature = "rayon")]
3247 {
3248 Ok(indices
3249 .par_iter()
3250 .map_init(
3251 || {
3252 (
3253 vec![Complex64::ZERO; amplitude_len],
3254 vec![Complex64::ZERO; slot_count],
3255 )
3256 },
3257 |(amplitude_values, expr_slots), &idx| {
3258 let cache = &resources.caches[idx];
3259 self.fill_amplitude_values(
3260 amplitude_values,
3261 &active_indices,
3262 ¶meters,
3263 cache,
3264 );
3265 self.evaluate_expression_value_with_program_snapshot(
3266 &program_snapshot,
3267 amplitude_values,
3268 expr_slots,
3269 )
3270 },
3271 )
3272 .collect())
3273 }
3274 #[cfg(not(feature = "rayon"))]
3275 {
3276 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3277 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3278 Ok(indices
3279 .iter()
3280 .map(|&idx| {
3281 let cache = &resources.caches[idx];
3282 self.fill_amplitude_values(
3283 &mut amplitude_values,
3284 &active_indices,
3285 ¶meters,
3286 cache,
3287 );
3288 self.evaluate_expression_value_with_program_snapshot(
3289 &program_snapshot,
3290 &litude_values,
3291 &mut expr_slots,
3292 )
3293 })
3294 .collect())
3295 }
3296 }
3297
/// Evaluates the expression for the global event `indices` across MPI ranks.
///
/// Each rank translates `indices` to its local indices, evaluates those, and the partial
/// results are combined with `all_gather_batched_partitioned`.
///
/// # Errors
///
/// Propagates any error from [`Self::evaluate_batch_local`].
#[cfg(feature = "mpi")]
fn evaluate_batch_mpi(
    &self,
    parameters: &[f64],
    indices: &[usize],
    world: &SimpleCommunicator,
) -> LadduResult<Vec<Complex64>> {
    let n_events = self.dataset.n_events();
    let local_indices = world.locals_from_globals(indices, n_events);
    let local_values = self.evaluate_batch_local(parameters, &local_indices)?;
    let gathered = world.all_gather_batched_partitioned(&local_values, indices, n_events, None);
    Ok(gathered)
}
3312
3313 pub fn evaluate_batch(
3316 &self,
3317 parameters: &[f64],
3318 indices: &[usize],
3319 ) -> LadduResult<Vec<Complex64>> {
3320 #[cfg(feature = "mpi")]
3321 {
3322 if let Some(world) = crate::mpi::get_world() {
3323 return self.evaluate_batch_mpi(parameters, indices, &world);
3324 }
3325 }
3326 self.evaluate_batch_local(parameters, indices)
3327 }
3328
3329 pub fn evaluate_gradient_local(
3337 &self,
3338 parameters: &[f64],
3339 ) -> LadduResult<Vec<DVector<Complex64>>> {
3340 let resources = self.resources.read();
3341 let parameters = resources.parameter_map.assemble(parameters)?;
3342 let amplitude_len = self.amplitude_use_sites.len();
3343 let grad_dim = parameters.len();
3344 let active_indices = resources.active_indices().to_vec();
3345 let lowered_runtime = self.lowered_runtime();
3346 let gradient_program = lowered_runtime.gradient_program();
3347 let slot_count = self.expression_gradient_slot_count();
3348 #[cfg(feature = "rayon")]
3349 {
3350 Ok(resources
3351 .caches
3352 .par_iter()
3353 .map_init(
3354 || {
3355 (
3356 vec![Complex64::ZERO; amplitude_len],
3357 vec![DVector::zeros(grad_dim); amplitude_len],
3358 vec![Complex64::ZERO; slot_count],
3359 vec![Complex64::ZERO; slot_count * grad_dim],
3360 )
3361 },
3362 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3363 self.fill_amplitude_values_and_gradients(
3364 amplitude_values,
3365 gradient_values,
3366 &active_indices,
3367 &resources.active,
3368 ¶meters,
3369 cache,
3370 );
3371 gradient_program.evaluate_gradient_into_flat(
3372 amplitude_values,
3373 gradient_values,
3374 value_slots,
3375 gradient_slots,
3376 grad_dim,
3377 )
3378 },
3379 )
3380 .collect())
3381 }
3382 #[cfg(not(feature = "rayon"))]
3383 {
3384 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3385 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3386 let mut value_slots = vec![Complex64::ZERO; slot_count];
3387 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3388 Ok(resources
3389 .caches
3390 .iter()
3391 .map(|cache| {
3392 self.fill_amplitude_values_and_gradients(
3393 &mut amplitude_values,
3394 &mut gradient_values,
3395 &active_indices,
3396 &resources.active,
3397 ¶meters,
3398 cache,
3399 );
3400 gradient_program.evaluate_gradient_into_flat(
3401 &litude_values,
3402 &gradient_values,
3403 &mut value_slots,
3404 &mut gradient_slots,
3405 grad_dim,
3406 )
3407 })
3408 .collect())
3409 }
3410 }
3411
/// Evaluates the gradient of the expression for every local event, using the scheduling and
/// scratch storage provided by `execution_context`.
///
/// When `rayon` is enabled and the context's thread policy is not `ThreadPolicy::Single`, the
/// events are processed with a parallel iterator run inside `execution_context.install`. Each
/// rayon worker gets its own freshly allocated workspaces. Otherwise the events are processed
/// sequentially, reusing workspaces reserved from the context's scratch area.
///
/// # Panics
///
/// Panics if `parameters` cannot be assembled by the evaluator's parameter map. This entry
/// point returns a plain `Vec` rather than a `Result`.
#[cfg(feature = "execution-context-prototype")]
pub fn evaluate_gradient_local_with_ctx(
    &self,
    parameters: &[f64],
    execution_context: &ExecutionContext,
) -> Vec<DVector<Complex64>> {
    let resources = self.resources.read();
    let parameters = resources
        .parameter_map
        .assemble(parameters)
        .expect("parameter slice must match evaluator resources");
    let amplitude_len = self.amplitude_use_sites.len();
    let grad_dim = parameters.len();
    let active_indices = resources.active_indices().to_vec();
    // NOTE(review): the non-ctx gradient paths size their scratch with
    // `expression_gradient_slot_count()` and flat gradient slots. This path uses
    // `expression_slot_count()` with one DVector per slot via
    // `evaluate_cache_gradient_with_scratch`. Confirm the two counts are meant to differ.
    let slot_count = self.expression_slot_count();
    #[cfg(feature = "rayon")]
    {
        if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
            return execution_context.install(|| {
                resources
                    .caches
                    .par_iter()
                    .map_init(
                        // Per-worker workspaces, in order: amplitude values, amplitude
                        // gradients, value slots, and one gradient vector per slot.
                        || {
                            (
                                vec![Complex64::ZERO; amplitude_len],
                                vec![DVector::zeros(grad_dim); amplitude_len],
                                vec![Complex64::ZERO; slot_count],
                                vec![DVector::zeros(grad_dim); slot_count],
                            )
                        },
                        |(amplitude_values, gradient_values, value_slots, gradient_slots),
                         cache| {
                            self.evaluate_cache_gradient_with_scratch(
                                amplitude_values,
                                gradient_values,
                                value_slots,
                                gradient_slots,
                                &active_indices,
                                &resources.active,
                                &parameters,
                                cache,
                            )
                        },
                    )
                    .collect()
            });
        }
    }
    // Sequential path: reuse the context's scratch workspaces for every event.
    execution_context.with_scratch(|scratch| {
        // Note the tuple order returned by the scratch API (values, value slots, gradients,
        // gradient slots) differs from the argument order used below.
        let (amplitude_values, value_slots, gradient_values, gradient_slots) =
            scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
        resources
            .caches
            .iter()
            .map(|cache| {
                self.evaluate_cache_gradient_with_scratch(
                    amplitude_values,
                    gradient_values,
                    value_slots,
                    gradient_slots,
                    &active_indices,
                    &resources.active,
                    &parameters,
                    cache,
                )
            })
            .collect()
    })
}
3483
/// Evaluates the per-event gradient on this rank's local events and all-gathers the results,
/// returning one gradient per event of the full dataset.
///
/// Gradients are flattened to `n_events * parameters.len()` complex numbers for the collective
/// and split back into vectors afterwards. When there are no parameters, every gradient is
/// empty. The collective still runs so that all ranks take part, but the result is built
/// directly, because splitting with a chunk size of zero would panic.
///
/// # Errors
///
/// Propagates any error from [`Self::evaluate_gradient_local`].
#[cfg(feature = "mpi")]
fn evaluate_gradient_mpi(
    &self,
    parameters: &[f64],
    world: &SimpleCommunicator,
) -> LadduResult<Vec<DVector<Complex64>>> {
    let local_evaluation = self.evaluate_gradient_local(parameters)?;
    let n_events = self.dataset.n_events();
    let grad_dim = parameters.len();
    let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * grad_dim];
    let (counts, displs) = world.get_flattened_counts_displs(n_events, grad_dim);
    {
        let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
        world.all_gather_varcount_into(
            &local_evaluation
                .iter()
                .flat_map(|v| v.data.as_vec())
                .copied()
                .collect::<Vec<_>>(),
            &mut partitioned_buffer,
        );
    }
    if grad_dim == 0 {
        // `chunks(0)` panics; with no parameters each event's gradient is the empty vector,
        // matching what `evaluate_gradient_local` produces.
        return Ok(vec![DVector::zeros(0); n_events]);
    }
    Ok(buffer
        .chunks_exact(grad_dim)
        .map(DVector::from_row_slice)
        .collect())
}
3519
/// Evaluates the per-event gradient on this rank's local events through `execution_context`
/// and all-gathers the results, returning one gradient per event of the full dataset.
///
/// When there are no parameters, every gradient is empty. The collective still runs so that
/// all ranks take part, but the result is built directly, because splitting with a chunk size
/// of zero would panic.
#[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
fn evaluate_gradient_mpi_with_ctx(
    &self,
    parameters: &[f64],
    world: &SimpleCommunicator,
    execution_context: &ExecutionContext,
) -> Vec<DVector<Complex64>> {
    let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
    let n_events = self.dataset.n_events();
    let grad_dim = parameters.len();
    let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * grad_dim];
    let (counts, displs) = world.get_flattened_counts_displs(n_events, grad_dim);
    {
        let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
        world.all_gather_varcount_into(
            &local_evaluation
                .iter()
                .flat_map(|v| v.data.as_vec())
                .copied()
                .collect::<Vec<_>>(),
            &mut partitioned_buffer,
        );
    }
    if grad_dim == 0 {
        // `chunks(0)` panics; with no parameters each event's gradient is the empty vector.
        return vec![DVector::zeros(0); n_events];
    }
    buffer
        .chunks_exact(grad_dim)
        .map(DVector::from_row_slice)
        .collect()
}
3549
3550 pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
3553 #[cfg(feature = "mpi")]
3554 {
3555 if let Some(world) = crate::mpi::get_world() {
3556 return self.evaluate_gradient_mpi(parameters, &world);
3557 }
3558 }
3559 self.evaluate_gradient_local(parameters)
3560 }
3561
/// Evaluates the gradient of the expression for every event, scheduling the work through
/// `execution_context`.
///
/// With the `mpi` feature and an active MPI world, the gradients are gathered across ranks;
/// otherwise only the local dataset is evaluated.
///
/// # Panics
///
/// Panics if the parameters cannot be assembled, because the underlying
/// `evaluate_gradient_local_with_ctx` does.
#[cfg(feature = "execution-context-prototype")]
pub fn evaluate_gradient_with_ctx(
    &self,
    parameters: &[f64],
    execution_context: &ExecutionContext,
) -> Vec<DVector<Complex64>> {
    #[cfg(feature = "mpi")]
    {
        let maybe_world = crate::mpi::get_world();
        if let Some(world) = maybe_world {
            return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
        }
    }
    self.evaluate_gradient_local_with_ctx(parameters, execution_context)
}
3581
3582 pub fn evaluate_gradient_batch_local(
3585 &self,
3586 parameters: &[f64],
3587 indices: &[usize],
3588 ) -> LadduResult<Vec<DVector<Complex64>>> {
3589 let resources = self.resources.read();
3590 let parameters = resources.parameter_map.assemble(parameters)?;
3591 let amplitude_len = self.amplitude_use_sites.len();
3592 let grad_dim = parameters.len();
3593 let active_indices = resources.active_indices().to_vec();
3594 let lowered_runtime = self.lowered_runtime();
3595 let gradient_program = lowered_runtime.gradient_program();
3596 let slot_count = self.expression_gradient_slot_count();
3597 #[cfg(feature = "rayon")]
3598 {
3599 Ok(indices
3600 .par_iter()
3601 .map_init(
3602 || {
3603 (
3604 vec![Complex64::ZERO; amplitude_len],
3605 vec![DVector::zeros(grad_dim); amplitude_len],
3606 vec![Complex64::ZERO; slot_count],
3607 vec![Complex64::ZERO; slot_count * grad_dim],
3608 )
3609 },
3610 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3611 let cache = &resources.caches[idx];
3612 self.fill_amplitude_values_and_gradients(
3613 amplitude_values,
3614 gradient_values,
3615 &active_indices,
3616 &resources.active,
3617 ¶meters,
3618 cache,
3619 );
3620 gradient_program.evaluate_gradient_into_flat(
3621 amplitude_values,
3622 gradient_values,
3623 value_slots,
3624 gradient_slots,
3625 grad_dim,
3626 )
3627 },
3628 )
3629 .collect())
3630 }
3631 #[cfg(not(feature = "rayon"))]
3632 {
3633 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3634 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3635 let mut value_slots = vec![Complex64::ZERO; slot_count];
3636 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3637 Ok(indices
3638 .iter()
3639 .map(|&idx| {
3640 let cache = &resources.caches[idx];
3641 self.fill_amplitude_values_and_gradients(
3642 &mut amplitude_values,
3643 &mut gradient_values,
3644 &active_indices,
3645 &resources.active,
3646 ¶meters,
3647 cache,
3648 );
3649 gradient_program.evaluate_gradient_into_flat(
3650 &litude_values,
3651 &gradient_values,
3652 &mut value_slots,
3653 &mut gradient_slots,
3654 grad_dim,
3655 )
3656 })
3657 .collect())
3658 }
3659 }
3660
/// Evaluates the gradient for the global event `indices` across MPI ranks.
///
/// Each rank evaluates its local subset of `indices` and flattens the gradients. The flat
/// results are combined with `all_gather_batched_partitioned` and split back into one vector
/// per requested index. When there are no parameters, every gradient is empty and is built
/// directly, because splitting with a chunk size of zero would panic. The collective still
/// runs so that all ranks take part.
///
/// # Errors
///
/// Propagates any error from [`Self::evaluate_gradient_batch_local`].
#[cfg(feature = "mpi")]
fn evaluate_gradient_batch_mpi(
    &self,
    parameters: &[f64],
    indices: &[usize],
    world: &SimpleCommunicator,
) -> LadduResult<Vec<DVector<Complex64>>> {
    let total = self.dataset.n_events();
    let grad_dim = parameters.len();
    let locals = world.locals_from_globals(indices, total);
    // Copy the gradient entries straight into one flat buffer; no per-event temporary Vec.
    let flattened_local_evaluation = self
        .evaluate_gradient_batch_local(parameters, &locals)?
        .iter()
        .flat_map(|g| g.data.as_vec())
        .copied()
        .collect::<Vec<Complex64>>();
    let gathered = world.all_gather_batched_partitioned(
        &flattened_local_evaluation,
        indices,
        total,
        Some(grad_dim),
    );
    if grad_dim == 0 {
        // `chunks(0)` panics; with no parameters each requested gradient is empty.
        return Ok(vec![DVector::zeros(0); indices.len()]);
    }
    Ok(gathered
        .chunks_exact(grad_dim)
        .map(DVector::from_row_slice)
        .collect())
}
3688
3689 pub fn evaluate_gradient_batch(
3693 &self,
3694 parameters: &[f64],
3695 indices: &[usize],
3696 ) -> LadduResult<Vec<DVector<Complex64>>> {
3697 #[cfg(feature = "mpi")]
3698 {
3699 if let Some(world) = crate::mpi::get_world() {
3700 return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
3701 }
3702 }
3703 self.evaluate_gradient_batch_local(parameters, indices)
3704 }
3705
3706 pub fn evaluate_with_gradient_local(
3708 &self,
3709 parameters: &[f64],
3710 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3711 let resources = self.resources.read();
3712 let parameters = resources.parameter_map.assemble(parameters)?;
3713 let amplitude_len = self.amplitude_use_sites.len();
3714 let grad_dim = parameters.len();
3715 let active_indices = resources.active_indices().to_vec();
3716 let lowered_runtime = self.lowered_runtime();
3717 let value_gradient_program = lowered_runtime.value_gradient_program();
3718 let slot_count = self.expression_value_gradient_slot_count();
3719 #[cfg(feature = "rayon")]
3720 {
3721 Ok(resources
3722 .caches
3723 .par_iter()
3724 .map_init(
3725 || {
3726 (
3727 vec![Complex64::ZERO; amplitude_len],
3728 vec![DVector::zeros(grad_dim); amplitude_len],
3729 vec![Complex64::ZERO; slot_count],
3730 vec![Complex64::ZERO; slot_count * grad_dim],
3731 )
3732 },
3733 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3734 self.fill_amplitude_values_and_gradients(
3735 amplitude_values,
3736 gradient_values,
3737 &active_indices,
3738 &resources.active,
3739 ¶meters,
3740 cache,
3741 );
3742 value_gradient_program.evaluate_value_gradient_into_flat(
3743 amplitude_values,
3744 gradient_values,
3745 value_slots,
3746 gradient_slots,
3747 grad_dim,
3748 )
3749 },
3750 )
3751 .collect())
3752 }
3753 #[cfg(not(feature = "rayon"))]
3754 {
3755 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3756 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3757 let mut value_slots = vec![Complex64::ZERO; slot_count];
3758 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3759 Ok(resources
3760 .caches
3761 .iter()
3762 .map(|cache| {
3763 self.fill_amplitude_values_and_gradients(
3764 &mut amplitude_values,
3765 &mut gradient_values,
3766 &active_indices,
3767 &resources.active,
3768 ¶meters,
3769 cache,
3770 );
3771 value_gradient_program.evaluate_value_gradient_into_flat(
3772 &litude_values,
3773 &gradient_values,
3774 &mut value_slots,
3775 &mut gradient_slots,
3776 grad_dim,
3777 )
3778 })
3779 .collect())
3780 }
3781 }
3782
3783 pub fn evaluate_with_gradient_local_with_active_mask(
3785 &self,
3786 parameters: &[f64],
3787 active_mask: &[bool],
3788 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3789 let resources = self.resources.read();
3790 if active_mask.len() != resources.active.len() {
3791 return Err(LadduError::LengthMismatch {
3792 context: "active amplitude mask".to_string(),
3793 expected: resources.active.len(),
3794 actual: active_mask.len(),
3795 });
3796 }
3797 let parameters = resources.parameter_map.assemble(parameters)?;
3798 let amplitude_len = self.amplitude_use_sites.len();
3799 let grad_dim = parameters.len();
3800 let active_indices = active_mask
3801 .iter()
3802 .enumerate()
3803 .filter_map(|(index, &active)| if active { Some(index) } else { None })
3804 .collect::<Vec<_>>();
3805 let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
3806 let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
3807 #[cfg(feature = "rayon")]
3808 {
3809 Ok(resources
3810 .caches
3811 .par_iter()
3812 .map_init(
3813 || {
3814 (
3815 vec![Complex64::ZERO; amplitude_len],
3816 vec![DVector::zeros(grad_dim); amplitude_len],
3817 vec![Complex64::ZERO; slot_count],
3818 vec![Complex64::ZERO; slot_count * grad_dim],
3819 )
3820 },
3821 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
3822 self.fill_amplitude_values_and_gradients(
3823 amplitude_values,
3824 gradient_values,
3825 &active_indices,
3826 active_mask,
3827 ¶meters,
3828 cache,
3829 );
3830 lowered_runtime
3831 .value_gradient_program()
3832 .evaluate_value_gradient_into_flat(
3833 amplitude_values,
3834 gradient_values,
3835 value_slots,
3836 gradient_slots,
3837 grad_dim,
3838 )
3839 },
3840 )
3841 .collect())
3842 }
3843 #[cfg(not(feature = "rayon"))]
3844 {
3845 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3846 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3847 let mut value_slots = vec![Complex64::ZERO; slot_count];
3848 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3849 Ok(resources
3850 .caches
3851 .iter()
3852 .map(|cache| {
3853 self.fill_amplitude_values_and_gradients(
3854 &mut amplitude_values,
3855 &mut gradient_values,
3856 &active_indices,
3857 active_mask,
3858 ¶meters,
3859 cache,
3860 );
3861 lowered_runtime
3862 .value_gradient_program()
3863 .evaluate_value_gradient_into_flat(
3864 &litude_values,
3865 &gradient_values,
3866 &mut value_slots,
3867 &mut gradient_slots,
3868 grad_dim,
3869 )
3870 })
3871 .collect())
3872 }
3873 }
3874
3875 pub fn evaluate_with_gradient_batch_local(
3877 &self,
3878 parameters: &[f64],
3879 indices: &[usize],
3880 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
3881 let resources = self.resources.read();
3882 let parameters = resources.parameter_map.assemble(parameters)?;
3883 let amplitude_len = self.amplitude_use_sites.len();
3884 let grad_dim = parameters.len();
3885 let active_indices = resources.active_indices().to_vec();
3886 let lowered_runtime = self.lowered_runtime();
3887 let value_gradient_program = lowered_runtime.value_gradient_program();
3888 let slot_count = self.expression_value_gradient_slot_count();
3889 #[cfg(feature = "rayon")]
3890 {
3891 Ok(indices
3892 .par_iter()
3893 .map_init(
3894 || {
3895 (
3896 vec![Complex64::ZERO; amplitude_len],
3897 vec![DVector::zeros(grad_dim); amplitude_len],
3898 vec![Complex64::ZERO; slot_count],
3899 vec![Complex64::ZERO; slot_count * grad_dim],
3900 )
3901 },
3902 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
3903 let cache = &resources.caches[idx];
3904 self.fill_amplitude_values_and_gradients(
3905 amplitude_values,
3906 gradient_values,
3907 &active_indices,
3908 &resources.active,
3909 ¶meters,
3910 cache,
3911 );
3912 value_gradient_program.evaluate_value_gradient_into_flat(
3913 amplitude_values,
3914 gradient_values,
3915 value_slots,
3916 gradient_slots,
3917 grad_dim,
3918 )
3919 },
3920 )
3921 .collect())
3922 }
3923 #[cfg(not(feature = "rayon"))]
3924 {
3925 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3926 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3927 let mut value_slots = vec![Complex64::ZERO; slot_count];
3928 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
3929 Ok(indices
3930 .iter()
3931 .map(|&idx| {
3932 let cache = &resources.caches[idx];
3933 self.fill_amplitude_values_and_gradients(
3934 &mut amplitude_values,
3935 &mut gradient_values,
3936 &active_indices,
3937 &resources.active,
3938 ¶meters,
3939 cache,
3940 );
3941 value_gradient_program.evaluate_value_gradient_into_flat(
3942 &litude_values,
3943 &gradient_values,
3944 &mut value_slots,
3945 &mut gradient_slots,
3946 grad_dim,
3947 )
3948 })
3949 .collect())
3950 }
3951 }
3952}
3953
3954#[cfg(test)]
3955mod tests {
3956 use approx::assert_relative_eq;
3957 #[cfg(feature = "mpi")]
3958 use mpi_test::mpi_test;
3959 use serde::{Deserialize, Serialize};
3960
3961 use super::*;
3962 use crate::{
3963 amplitude::{AmplitudeID, Tags, TestAmplitude},
3964 data::{test_dataset, test_event, DatasetMetadata, Event, EventData},
3965 parameter,
3966 parameters::Parameter,
3967 resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
3968 vectors::Vec4,
3969 };
3970
3971 #[derive(Clone, Serialize, Deserialize)]
3972 pub struct ComplexScalar {
3973 name: String,
3974 re: Parameter,
3975 pid_re: ParameterID,
3976 im: Parameter,
3977 pid_im: ParameterID,
3978 }
3979
3980 impl ComplexScalar {
3981 #[allow(clippy::new_ret_no_self)]
3982 pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
3983 Self {
3984 name: name.to_string(),
3985 re,
3986 pid_re: Default::default(),
3987 im,
3988 pid_im: Default::default(),
3989 }
3990 .into_expression()
3991 }
3992 }
3993
3994 #[typetag::serde]
3995 impl Amplitude for ComplexScalar {
3996 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
3997 self.pid_re = resources.register_parameter(&self.re)?;
3998 self.pid_im = resources.register_parameter(&self.im)?;
3999 resources.register_amplitude(&self.name)
4000 }
4001
4002 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4003 Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
4004 }
4005
4006 fn compute_gradient(
4007 &self,
4008 parameters: &Parameters,
4009 _cache: &Cache,
4010 gradient: &mut DVector<Complex64>,
4011 ) {
4012 if let Some(ind) = parameters.free_index(self.pid_re) {
4013 gradient[ind] = Complex64::ONE;
4014 }
4015 if let Some(ind) = parameters.free_index(self.pid_im) {
4016 gradient[ind] = Complex64::I;
4017 }
4018 }
4019 }
4020
4021 #[derive(Clone, Serialize, Deserialize)]
4022 pub struct ParameterOnlyScalar {
4023 name: String,
4024 value: Parameter,
4025 pid: ParameterID,
4026 }
4027
4028 impl ParameterOnlyScalar {
4029 #[allow(clippy::new_ret_no_self)]
4030 pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
4031 Self {
4032 name: name.to_string(),
4033 value,
4034 pid: Default::default(),
4035 }
4036 .into_expression()
4037 }
4038 }
4039
4040 #[typetag::serde]
4041 impl Amplitude for ParameterOnlyScalar {
4042 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4043 self.pid = resources.register_parameter(&self.value)?;
4044 resources.register_amplitude(&self.name)
4045 }
4046
4047 fn dependence_hint(&self) -> ExpressionDependence {
4048 ExpressionDependence::ParameterOnly
4049 }
4050
4051 fn real_valued_hint(&self) -> bool {
4052 true
4053 }
4054
4055 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4056 Complex64::new(parameters.get(self.pid), 0.0)
4057 }
4058 }
4059
4060 #[derive(Clone, Serialize, Deserialize)]
4061 pub struct CacheOnlyScalar {
4062 name: String,
4063 beam_energy: ScalarID,
4064 }
4065
4066 impl CacheOnlyScalar {
4067 #[allow(clippy::new_ret_no_self)]
4068 pub fn new(name: &str) -> LadduResult<Expression> {
4069 Self {
4070 name: name.to_string(),
4071 beam_energy: Default::default(),
4072 }
4073 .into_expression()
4074 }
4075 }
4076
4077 #[typetag::serde]
4078 impl Amplitude for CacheOnlyScalar {
4079 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4080 self.beam_energy =
4081 resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
4082 resources.register_amplitude(&self.name)
4083 }
4084
4085 fn dependence_hint(&self) -> ExpressionDependence {
4086 ExpressionDependence::CacheOnly
4087 }
4088
4089 fn real_valued_hint(&self) -> bool {
4090 true
4091 }
4092
4093 fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
4094 cache.store_scalar(self.beam_energy, event.p4_at(0).e());
4095 }
4096
4097 fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
4098 Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
4099 }
4100 }
4101
4102 #[derive(Clone, Copy)]
4103 enum DeterministicFixtureKind {
4104 Separable,
4105 Partial,
4106 NonSeparable,
4107 }
4108
4109 struct DeterministicFixture {
4110 expression: Expression,
4111 dataset: Arc<Dataset>,
4112 parameters: Vec<f64>,
4113 }
4114
4115 const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
4116 const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
4117
4118 fn deterministic_fixture_dataset() -> Arc<Dataset> {
4119 let metadata = Arc::new(DatasetMetadata::default());
4120 let events = vec![
4121 Arc::new(EventData {
4122 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
4123 aux: vec![],
4124 weight: 0.5,
4125 }),
4126 Arc::new(EventData {
4127 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
4128 aux: vec![],
4129 weight: -1.25,
4130 }),
4131 Arc::new(EventData {
4132 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
4133 aux: vec![],
4134 weight: 2.0,
4135 }),
4136 Arc::new(EventData {
4137 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
4138 aux: vec![],
4139 weight: 0.75,
4140 }),
4141 ];
4142 Arc::new(Dataset::new_with_metadata(events, metadata))
4143 }
4144
4145 fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
4146 let dataset = deterministic_fixture_dataset();
4147 match kind {
4148 DeterministicFixtureKind::Separable => {
4149 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
4150 .expect("separable p1 should build");
4151 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
4152 .expect("separable p2 should build");
4153 let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
4154 let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
4155 DeterministicFixture {
4156 expression: (&p1 * &c1) + &(&p2 * &c2),
4157 dataset,
4158 parameters: vec![0.4, -0.3],
4159 }
4160 }
4161 DeterministicFixtureKind::Partial => {
4162 let p =
4163 ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
4164 let c = CacheOnlyScalar::new("c").expect("partial c should build");
4165 let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
4166 .expect("partial m should build");
4167 DeterministicFixture {
4168 expression: (&p * &c) + &m,
4169 dataset,
4170 parameters: vec![0.55, 0.2, -0.15],
4171 }
4172 }
4173 DeterministicFixtureKind::NonSeparable => {
4174 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
4175 .expect("non-separable m1 should build");
4176 let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
4177 .expect("non-separable m2 should build");
4178 DeterministicFixture {
4179 expression: &m1 * &m2,
4180 dataset,
4181 parameters: vec![0.25, -0.4, 0.6, 0.1],
4182 }
4183 }
4184 }
4185 }
4186
4187 fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
4188 let evaluator = fixture
4189 .expression
4190 .load(&fixture.dataset)
4191 .expect("fixture evaluator should load");
4192 let expected_value = evaluator
4193 .evaluate_local(&fixture.parameters)
4194 .expect("evaluation should succeed")
4195 .iter()
4196 .zip(fixture.dataset.weights_local().iter())
4197 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
4198 let expected_gradient = evaluator
4199 .evaluate_gradient_local(&fixture.parameters)
4200 .expect("evaluation should succeed")
4201 .iter()
4202 .zip(fixture.dataset.weights_local().iter())
4203 .fold(
4204 DVector::zeros(fixture.parameters.len()),
4205 |mut accum, (gradient, event)| {
4206 accum += gradient.map(|value| value.re).scale(*event);
4207 accum
4208 },
4209 );
4210 let actual_value = evaluator
4211 .evaluate_weighted_value_sum_local(&fixture.parameters)
4212 .expect("evaluation should succeed");
4213 let actual_gradient = evaluator
4214 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4215 .expect("evaluation should succeed");
4216 assert_relative_eq!(
4217 actual_value,
4218 expected_value,
4219 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4220 max_relative = DETERMINISTIC_STRICT_REL_TOL
4221 );
4222 for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
4223 assert_relative_eq!(
4224 *actual_item,
4225 *expected_item,
4226 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4227 max_relative = DETERMINISTIC_STRICT_REL_TOL
4228 );
4229 }
4230 }
4231 fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
4232 let evaluator = fixture
4233 .expression
4234 .load(&fixture.dataset)
4235 .expect("fixture evaluator should load");
4236 let state = {
4237 let resources = evaluator.resources.read();
4238 evaluator.ensure_cached_integral_cache_state(&resources)
4239 }
4240 .expect("state should be available");
4241 assert!(
4242 !state.values.is_empty(),
4243 "fixture should exercise cached normalization terms"
4244 );
4245 assert!(
4246 !state.execution_sets.residual_amplitudes.is_empty(),
4247 "fixture should exercise residual normalization amplitudes"
4248 );
4249
4250 let (residual_value_sum, cached_value_sum) = evaluator
4251 .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4252 .expect("evaluation should succeed");
4253 assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4254 assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4255 let combined_value = evaluator
4256 .evaluate_weighted_value_sum_local(&fixture.parameters)
4257 .expect("evaluation should succeed");
4258 assert_relative_eq!(
4259 residual_value_sum + cached_value_sum,
4260 combined_value,
4261 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4262 max_relative = DETERMINISTIC_STRICT_REL_TOL
4263 );
4264
4265 let (residual_gradient_sum, cached_gradient_sum) = evaluator
4266 .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4267 .expect("evaluation should succeed");
4268 let combined_gradient = evaluator
4269 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4270 .expect("evaluation should succeed");
4271 assert!(residual_gradient_sum
4272 .iter()
4273 .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4274 assert!(cached_gradient_sum
4275 .iter()
4276 .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4277 for ((residual_item, cached_item), combined_item) in residual_gradient_sum
4278 .iter()
4279 .zip(cached_gradient_sum.iter())
4280 .zip(combined_gradient.iter())
4281 {
4282 assert_relative_eq!(
4283 residual_item + cached_item,
4284 *combined_item,
4285 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4286 max_relative = DETERMINISTIC_STRICT_REL_TOL
4287 );
4288 }
4289 }
4290
/// Isolating a subset of amplitudes and then restoring the original mask must reproduce the
/// original weighted value sum.
#[test]
fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
    let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
    let evaluator = fixture
        .expression
        .load(&fixture.dataset)
        .expect("fixture evaluator should load");
    let weighted_sum = || {
        evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed")
    };

    let baseline_mask = evaluator.active_mask();
    let baseline_value = weighted_sum();

    // Move the mask away from its original state...
    evaluator.isolate_many(&["p", "c"]);
    assert_ne!(evaluator.active_mask(), baseline_mask);

    // ...then restore it and confirm the restoration took effect.
    evaluator
        .set_active_mask(&baseline_mask)
        .expect("original fixture active mask should restore");
    assert_eq!(evaluator.active_mask(), baseline_mask);

    let restored_value = weighted_sum();
    assert_relative_eq!(
        restored_value,
        baseline_value,
        epsilon = DETERMINISTIC_STRICT_ABS_TOL,
        max_relative = DETERMINISTIC_STRICT_REL_TOL
    );
}
4321
4322 #[test]
4323 fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
4324 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4325 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4326 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4327
4328 assert_weighted_sum_matches_eventwise_baseline(&separable);
4329 assert_weighted_sum_matches_eventwise_baseline(&partial);
4330 assert_weighted_sum_matches_eventwise_baseline(&non_separable);
4331 }
4332 #[test]
4333 fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
4334 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4335 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4336 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4337
4338 let separable_evaluator = separable
4339 .expression
4340 .load(&separable.dataset)
4341 .expect("separable evaluator should load");
4342 let partial_evaluator = partial
4343 .expression
4344 .load(&partial.dataset)
4345 .expect("partial evaluator should load");
4346 let non_separable_evaluator = non_separable
4347 .expression
4348 .load(&non_separable.dataset)
4349 .expect("non-separable evaluator should load");
4350
4351 assert_eq!(
4352 separable_evaluator
4353 .expression_precomputed_cached_integrals()
4354 .expect("integrals should be computed")
4355 .len(),
4356 2
4357 );
4358 assert_eq!(
4359 partial_evaluator
4360 .expression_precomputed_cached_integrals()
4361 .expect("integrals should be computed")
4362 .len(),
4363 1
4364 );
4365 assert!(non_separable_evaluator
4366 .expression_precomputed_cached_integrals()
4367 .expect("integrals should be computed")
4368 .is_empty());
4369 }
4370 #[test]
4371 fn test_partial_fixture_combined_normalization_components_match_total() {
4372 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4373 assert_mixed_normalization_components_match_combined_path(&partial);
4374 }
4375 #[test]
4376 fn test_non_separable_fixture_normalization_components_stay_residual_only() {
4377 let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4378 let evaluator = fixture
4379 .expression
4380 .load(&fixture.dataset)
4381 .expect("fixture evaluator should load");
4382 let resources = evaluator.resources.read();
4383 let state = evaluator
4384 .ensure_cached_integral_cache_state(&resources)
4385 .expect("state should be available");
4386 assert!(state.values.is_empty());
4387
4388 let (residual_value_sum, cached_value_sum) = evaluator
4389 .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4390 .expect("evaluation should succeed");
4391 assert_relative_eq!(
4392 cached_value_sum,
4393 0.0,
4394 epsilon = DETERMINISTIC_STRICT_ABS_TOL
4395 );
4396 assert_relative_eq!(
4397 residual_value_sum,
4398 evaluator
4399 .evaluate_weighted_value_sum_local(&fixture.parameters)
4400 .expect("evaluation should succeed"),
4401 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4402 max_relative = DETERMINISTIC_STRICT_REL_TOL
4403 );
4404
4405 let (residual_gradient_sum, cached_gradient_sum) = evaluator
4406 .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4407 .expect("evaluation should succeed");
4408 assert!(cached_gradient_sum
4409 .iter()
4410 .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
4411 let combined_gradient = evaluator
4412 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4413 .expect("evaluation should succeed");
4414 for (residual_item, combined_item) in
4415 residual_gradient_sum.iter().zip(combined_gradient.iter())
4416 {
4417 assert_relative_eq!(
4418 *residual_item,
4419 *combined_item,
4420 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4421 max_relative = DETERMINISTIC_STRICT_REL_TOL
4422 );
4423 }
4424 }
4425
4426 #[test]
4427 fn test_batch_evaluation() {
4428 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
4429 let mut event1 = test_event();
4430 event1.p4s[0].t = 10.0;
4431 let mut event2 = test_event();
4432 event2.p4s[0].t = 11.0;
4433 let mut event3 = test_event();
4434 event3.p4s[0].t = 12.0;
4435 let dataset = Arc::new(Dataset::new_with_metadata(
4436 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
4437 Arc::new(DatasetMetadata::default()),
4438 ));
4439 let evaluator = expr.load(&dataset).unwrap();
4440 let result = evaluator
4441 .evaluate_batch(&[1.1, 2.2], &[0, 2])
4442 .expect("evaluation should succeed");
4443 assert_eq!(result.len(), 2);
4444 assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
4445 assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
4446 let result_grad = evaluator
4447 .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
4448 .expect("evaluation should succeed");
4449 assert_eq!(result_grad.len(), 2);
4450 assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
4451 assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
4452 assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
4453 assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
4454 }
4455
4456 #[test]
4457 fn test_load_compiles_expression_ir_once() {
4458 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4459 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4460 .norm_sqr();
4461 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4462 let evaluator = expr.load(&dataset).unwrap();
4463 assert!(evaluator.expression_slot_count() > 0);
4464 }
4465 #[test]
4466 fn test_expression_ir_value_matches_lowered_runtime() {
4467 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4468 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4469 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4470 .conj()
4471 .norm_sqr();
4472 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4473 let evaluator = expr.load(&dataset).unwrap();
4474 let resources = evaluator.resources.read();
4475 let parameters = resources
4476 .parameter_map
4477 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4478 .expect("parameters should assemble");
4479 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4480 evaluator.fill_amplitude_values(
4481 &mut amplitude_values,
4482 resources.active_indices(),
4483 ¶meters,
4484 &resources.caches[0],
4485 );
4486 let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4487 let lowered_runtime = evaluator.lowered_runtime();
4488 let lowered_program = lowered_runtime.value_program();
4489 let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4490 let lowered_value =
4491 evaluator.evaluate_expression_value_with_scratch(&litude_values, &mut ir_slots);
4492 let direct_lowered_value =
4493 lowered_program.evaluate_into(&litude_values, &mut lowered_slots);
4494 let ir_value = evaluator
4495 .expression_ir()
4496 .evaluate_into(&litude_values, &mut ir_slots);
4497 assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
4498 assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
4499 assert_relative_eq!(lowered_value.re, ir_value.re);
4500 assert_relative_eq!(lowered_value.im, ir_value.im);
4501 }
4502 #[test]
4503 fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
4504 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
4505 .unwrap()
4506 .norm_sqr();
4507 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4508 let evaluator = expr.load(&dataset).unwrap();
4509 let lowered_runtime = evaluator.lowered_runtime();
4510 assert_eq!(
4511 lowered_runtime.value_program().kind(),
4512 lowered::LoweredProgramKind::Value
4513 );
4514 assert_eq!(
4515 lowered_runtime.gradient_program().kind(),
4516 lowered::LoweredProgramKind::Gradient
4517 );
4518 assert_eq!(
4519 lowered_runtime.value_gradient_program().kind(),
4520 lowered::LoweredProgramKind::ValueGradient
4521 );
4522 }
4523 #[test]
4524 fn test_expression_ir_gradient_matches_lowered_runtime() {
4525 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4526 * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4527 .norm_sqr();
4528 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4529 let evaluator = expr.load(&dataset).unwrap();
4530 let resources = evaluator.resources.read();
4531 let parameters = resources
4532 .parameter_map
4533 .assemble(&[1.0, 0.25, -0.8, 0.5])
4534 .expect("parameters should assemble");
4535 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4536 evaluator.fill_amplitude_values(
4537 &mut amplitude_values,
4538 resources.active_indices(),
4539 ¶meters,
4540 &resources.caches[0],
4541 );
4542 let mut active_mask = vec![false; evaluator.amplitudes.len()];
4543 for &index in resources.active_indices() {
4544 active_mask[index] = true;
4545 }
4546 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4547 .map(|_| DVector::zeros(parameters.len()))
4548 .collect::<Vec<_>>();
4549 evaluator.fill_amplitude_gradients(
4550 &mut amplitude_gradients,
4551 &active_mask,
4552 ¶meters,
4553 &resources.caches[0],
4554 );
4555 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4556 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4557 (0..evaluator.expression_ir().node_count())
4558 .map(|_| DVector::zeros(parameters.len()))
4559 .collect();
4560 let lowered_runtime = evaluator.lowered_runtime();
4561 let lowered_program = lowered_runtime.gradient_program();
4562 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4563 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4564 .scratch_slots())
4565 .map(|_| DVector::zeros(parameters.len()))
4566 .collect();
4567 let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
4568 &litude_values,
4569 &litude_gradients,
4570 &mut ir_value_slots,
4571 &mut ir_gradient_slots,
4572 );
4573 let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
4574 &litude_values,
4575 &litude_gradients,
4576 &mut ir_value_slots,
4577 &mut ir_gradient_slots,
4578 );
4579 let lowered_gradient = lowered_program.evaluate_gradient_into(
4580 &litude_values,
4581 &litude_gradients,
4582 &mut lowered_value_slots,
4583 &mut lowered_gradient_slots,
4584 );
4585 for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
4586 assert_relative_eq!(active.re, lowered.re);
4587 assert_relative_eq!(active.im, lowered.im);
4588 }
4589 for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
4590 assert_relative_eq!(lowered.re, ir.re);
4591 assert_relative_eq!(lowered.im, ir.im);
4592 }
4593 }
4594 #[test]
4595 fn test_expression_ir_value_gradient_matches_lowered_runtime() {
4596 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4597 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4598 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
4599 .norm_sqr();
4600 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4601 let evaluator = expr.load(&dataset).unwrap();
4602 let resources = evaluator.resources.read();
4603 let parameters = resources
4604 .parameter_map
4605 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
4606 .expect("parameters should assemble");
4607 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
4608 evaluator.fill_amplitude_values(
4609 &mut amplitude_values,
4610 resources.active_indices(),
4611 ¶meters,
4612 &resources.caches[0],
4613 );
4614 let mut active_mask = vec![false; evaluator.amplitudes.len()];
4615 for &index in resources.active_indices() {
4616 active_mask[index] = true;
4617 }
4618 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
4619 .map(|_| DVector::zeros(parameters.len()))
4620 .collect::<Vec<_>>();
4621 evaluator.fill_amplitude_gradients(
4622 &mut amplitude_gradients,
4623 &active_mask,
4624 ¶meters,
4625 &resources.caches[0],
4626 );
4627 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
4628 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
4629 (0..evaluator.expression_ir().node_count())
4630 .map(|_| DVector::zeros(parameters.len()))
4631 .collect();
4632 let lowered_runtime = evaluator.lowered_runtime();
4633 let lowered_program = lowered_runtime.value_gradient_program();
4634 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
4635 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
4636 .scratch_slots())
4637 .map(|_| DVector::zeros(parameters.len()))
4638 .collect();
4639
4640 let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
4641 &litude_values,
4642 &litude_gradients,
4643 &mut ir_value_slots,
4644 &mut ir_gradient_slots,
4645 );
4646 let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
4647 &litude_values,
4648 &litude_gradients,
4649 &mut ir_value_slots,
4650 &mut ir_gradient_slots,
4651 );
4652 let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
4653 &litude_values,
4654 &litude_gradients,
4655 &mut lowered_value_slots,
4656 &mut lowered_gradient_slots,
4657 );
4658
4659 assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
4660 assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
4661 for (active, lowered) in active_value_gradient
4662 .1
4663 .iter()
4664 .zip(lowered_value_gradient.1.iter())
4665 {
4666 assert_relative_eq!(active.re, lowered.re);
4667 assert_relative_eq!(active.im, lowered.im);
4668 }
4669 assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
4670 assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
4671 for (lowered, ir) in lowered_value_gradient
4672 .1
4673 .iter()
4674 .zip(ir_value_gradient.1.iter())
4675 {
4676 assert_relative_eq!(lowered.re, ir.re);
4677 assert_relative_eq!(lowered.im, ir.im);
4678 }
4679 }
4680 #[test]
4681 fn test_expression_runtime_diagnostics_reports_lowered_programs() {
4682 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4683 let evaluator = fixture
4684 .expression
4685 .load(&fixture.dataset)
4686 .expect("fixture evaluator should load");
4687
4688 let diagnostics = evaluator.expression_runtime_diagnostics();
4689 assert!(diagnostics.ir_planning_enabled);
4690 assert!(diagnostics.lowered_value_program_present);
4691 assert!(diagnostics.lowered_gradient_program_present);
4692 assert!(diagnostics.lowered_value_gradient_program_present);
4693 assert!(diagnostics.residual_runtime_present);
4694 assert_eq!(
4695 diagnostics.specialization_status,
4696 Some(ExpressionSpecializationStatus {
4697 origin: ExpressionSpecializationOrigin::InitialLoad,
4698 })
4699 );
4700 }
4701 #[test]
4702 fn test_expression_runtime_diagnostics_reports_specialization_origin() {
4703 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4704 let evaluator = fixture
4705 .expression
4706 .load(&fixture.dataset)
4707 .expect("fixture evaluator should load");
4708
4709 assert_eq!(
4710 evaluator
4711 .expression_runtime_diagnostics()
4712 .specialization_status,
4713 Some(ExpressionSpecializationStatus {
4714 origin: ExpressionSpecializationOrigin::InitialLoad,
4715 })
4716 );
4717
4718 evaluator.isolate_many(&["p"]);
4719 assert_eq!(
4720 evaluator
4721 .expression_runtime_diagnostics()
4722 .specialization_status,
4723 Some(ExpressionSpecializationStatus {
4724 origin: ExpressionSpecializationOrigin::CacheMissRebuild,
4725 })
4726 );
4727 }
4728 #[test]
4729 fn test_compiled_expression_display_reports_dag_refs() {
4730 let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4731 let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4732 let term = &a * &b;
4733 let expr = &term + &term;
4734 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4735 let evaluator = expr.load(&dataset).unwrap();
4736
4737 let compiled = evaluator.compiled_expression();
4738 let display = compiled.to_string();
4739
4740 assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4741 assert!(display.contains("#"));
4742 assert!(display.contains("+"));
4743 assert!(display.contains("×"));
4744 assert!(display.contains("a(id=0)"));
4745 assert!(display.contains("b(id=1)"));
4746 assert!(display.contains("(ref)"));
4747 }
4748
4749 #[test]
4750 fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
4751 let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4752 let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4753 let term = &a * &b;
4754 let expr = &term + &term;
4755
4756 let compiled = expr.compiled_expression();
4757 let display = compiled.to_string();
4758
4759 assert_eq!(compiled.root(), compiled.nodes().len() - 1);
4760 assert!(display.contains("#"));
4761 assert!(display.contains("+"));
4762 assert!(display.contains("×"));
4763 assert!(display.contains("a(id=0)"));
4764 assert!(display.contains("b(id=1)"));
4765 assert!(display.contains("(ref)"));
4766 }
4767
4768 #[test]
4769 fn test_compiled_expression_display_uses_current_active_mask() {
4770 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4771 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
4772 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4773 let evaluator = expr.load(&dataset).unwrap();
4774 evaluator.deactivate("b");
4775
4776 let compiled = evaluator.compiled_expression().to_string();
4777
4778 assert!(compiled.contains("a(id=0)"));
4779 assert!(!compiled.contains("b(id=1)"));
4780 assert!(!compiled.contains("const 0"));
4781 assert!(!compiled.contains("+"));
4782 }
4783
4784 fn assert_compiled_single_amplitude(expr: &Expression, expected_label: &str) {
4785 let compiled = expr.compiled_expression();
4786 assert_eq!(compiled.nodes().len(), 1);
4787 assert_eq!(compiled.root(), 0);
4788 match &compiled.nodes()[0] {
4789 CompiledExpressionNode::Amplitude { index, name } => {
4790 assert_eq!(*index, 0);
4791 assert_eq!(name, expected_label);
4792 }
4793 node => panic!("expected one amplitude node, got {node:?}"),
4794 }
4795 }
4796
4797 fn assert_compiled_constant(expr: &Expression, expected: Complex64) {
4798 let compiled = expr.compiled_expression();
4799 assert_eq!(compiled.nodes().len(), 1);
4800 assert_eq!(compiled.root(), 0);
4801 match compiled.nodes()[0] {
4802 CompiledExpressionNode::Constant(value) => {
4803 assert_relative_eq!(value.re, expected.re);
4804 assert_relative_eq!(value.im, expected.im);
4805 }
4806 ref node => panic!("expected one constant node, got {node:?}"),
4807 }
4808 }
4809
4810 #[test]
4811 fn test_compiled_expression_simplifies_arithmetic_identities() {
4812 let amp = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4813 let zero = Expression::zero();
4814 let one = Expression::one();
4815
4816 assert_compiled_single_amplitude(&(& + &zero), "a");
4817 assert_compiled_single_amplitude(&(&zero + &), "a");
4818 assert_compiled_single_amplitude(&(& - &zero), "a");
4819 assert_compiled_single_amplitude(&(& * &one), "a");
4820 assert_compiled_single_amplitude(&(&one * &), "a");
4821 assert_compiled_single_amplitude(&(& / &one), "a");
4822 assert_compiled_single_amplitude(&.pow(&one), "a");
4823 assert_compiled_single_amplitude(&.powi(1), "a");
4824 assert_compiled_single_amplitude(&.powf(1.0), "a");
4825
4826 let times_zero = & * &zero;
4827 assert_compiled_constant(×_zero, Complex64::ZERO);
4828 assert!(times_zero.parameters().contains_key("ar"));
4829 assert!(times_zero.parameters().contains_key("ai"));
4830
4831 assert_compiled_constant(&(&zero * &), Complex64::ZERO);
4832 assert_compiled_constant(&(&zero / &Expression::from(2.0)), Complex64::ZERO);
4833 assert_compiled_constant(&.powi(0), Complex64::ONE);
4834 assert_compiled_constant(
4835 &Expression::from(2.0).pow(&Expression::zero()),
4836 Complex64::ONE,
4837 );
4838 assert_compiled_constant(&Expression::from(2.0).powf(0.0), Complex64::ONE);
4839
4840 let unsafe_zero_division = (&zero / &).compiled_expression().to_string();
4841 assert!(unsafe_zero_division.contains("÷"));
4842 assert!(unsafe_zero_division.contains("a(id=0)"));
4843 }
4844
4845 #[test]
4846 fn test_compiled_expression_folds_unary_constant_functions() {
4847 assert_compiled_constant(&Expression::from(0.0).exp(), Complex64::ONE);
4848 assert_compiled_constant(&Expression::from(0.0).sin(), Complex64::ZERO);
4849 assert_compiled_constant(&Expression::from(0.0).cos(), Complex64::ONE);
4850 assert_compiled_constant(&Expression::from(1.0).log(), Complex64::ZERO);
4851 assert_compiled_constant(&Expression::from(4.0).sqrt(), Complex64::new(2.0, 0.0));
4852 assert_compiled_constant(&Expression::from(0.0).cis(), Complex64::ONE);
4853 }
4854
4855 #[test]
4856 fn test_evaluator_expression_reconstructs_expression() {
4857 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
4858 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4859 let evaluator = expr.load(&dataset).unwrap();
4860
4861 assert_eq!(
4862 evaluator.expression().compiled_expression(),
4863 expr.compiled_expression()
4864 );
4865 }
4866
4867 #[test]
4868 fn test_active_mask_override_ignores_current_ir_specialization() {
4869 let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
4870 .unwrap()
4871 .norm_sqr();
4872 let dataset = Arc::new(test_dataset());
4873 let evaluator = expr.load(&dataset).unwrap();
4874 let params = vec![2.0];
4875
4876 evaluator.deactivate("amp");
4877 assert_eq!(
4878 evaluator
4879 .evaluate(¶ms)
4880 .expect("evaluation should succeed")[0],
4881 Complex64::new(0.0, 0.0)
4882 );
4883
4884 let overridden = evaluator
4885 .evaluate_local_with_active_mask(¶ms, &[true])
4886 .unwrap();
4887 assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
4888
4889 let overridden_fused = evaluator
4890 .evaluate_with_gradient_local_with_active_mask(¶ms, &[true])
4891 .unwrap();
4892 assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
4893 assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
4894 }
4895 #[test]
4896 fn test_expression_ir_dependence_diagnostics_surface() {
4897 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
4898 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
4899 .norm_sqr();
4900 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4901 let evaluator = expr.load(&dataset).unwrap();
4902 let annotations = evaluator
4903 .expression_node_dependence_annotations()
4904 .expect("annotations should exist");
4905 assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
4906 assert!(annotations
4907 .iter()
4908 .all(|dependence| *dependence == ExpressionDependence::Mixed));
4909 assert_eq!(
4910 evaluator
4911 .expression_root_dependence()
4912 .expect("root dependence should exist"),
4913 ExpressionDependence::Mixed
4914 );
4915 }
4916 #[test]
4917 fn test_expression_ir_default_dependence_hint_is_mixed() {
4918 let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
4919 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4920 let evaluator = expr.load(&dataset).unwrap();
4921 assert_eq!(
4922 evaluator
4923 .expression_root_dependence()
4924 .expect("root dependence should exist"),
4925 ExpressionDependence::Mixed
4926 );
4927 }
4928 #[test]
4929 fn test_expression_ir_parameter_only_dependence_hint_propagates() {
4930 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
4931 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4932 let evaluator = expr.load(&dataset).unwrap();
4933 assert_eq!(
4934 evaluator
4935 .expression_root_dependence()
4936 .expect("root dependence should exist"),
4937 ExpressionDependence::ParameterOnly
4938 );
4939 }
4940 #[test]
4941 fn test_expression_ir_cache_only_dependence_hint_propagates() {
4942 let expr = CacheOnlyScalar::new("k").unwrap();
4943 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4944 let evaluator = expr.load(&dataset).unwrap();
4945 assert_eq!(
4946 evaluator
4947 .expression_root_dependence()
4948 .expect("root dependence should exist"),
4949 ExpressionDependence::CacheOnly
4950 );
4951 }
4952 #[test]
4953 fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
4954 let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4955 .unwrap()
4956 .imag();
4957 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4958 let evaluator = expr.load(&dataset).unwrap();
4959 let ir = evaluator.expression_ir();
4960
4961 assert!(matches!(
4962 ir.nodes()[ir.root()],
4963 ir::IrNode::Constant(value) if value == Complex64::ZERO
4964 ));
4965 assert_eq!(
4966 evaluator
4967 .evaluate(&[2.5])
4968 .expect("evaluation should succeed")[0],
4969 Complex64::ZERO
4970 );
4971 }
4972 #[test]
4973 fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
4974 let expr = ParameterOnlyScalar::new("p", parameter!("p"))
4975 .unwrap()
4976 .conj();
4977 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4978 let evaluator = expr.load(&dataset).unwrap();
4979 let ir = evaluator.expression_ir();
4980
4981 assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
4982 assert_eq!(
4983 evaluator
4984 .evaluate(&[2.5])
4985 .expect("evaluation should succeed")[0],
4986 Complex64::new(2.5, 0.0)
4987 );
4988 }
4989 #[test]
4990 fn test_expression_ir_dependence_warnings_surface() {
4991 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
4992 + &CacheOnlyScalar::new("k").unwrap();
4993 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
4994 let evaluator = expr.load(&dataset).unwrap();
4995 assert!(evaluator
4996 .expression_dependence_warnings()
4997 .expect("warnings should exist")
4998 .iter()
4999 .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
5000 }
5001 #[test]
5002 fn test_expression_ir_normalization_plan_explain_surface() {
5003 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5004 * &CacheOnlyScalar::new("k").unwrap();
5005 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5006 let evaluator = expr.load(&dataset).unwrap();
5007 let explain = evaluator
5008 .expression_normalization_plan_explain()
5009 .expect("plan should exist");
5010 assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
5011 assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
5012 assert_eq!(
5013 explain.cached_separable_nodes,
5014 explain.separable_mul_candidate_nodes
5015 );
5016 assert!(explain.residual_terms.iter().all(|index| {
5017 !explain
5018 .separable_mul_candidate_nodes
5019 .iter()
5020 .any(|candidate| candidate == index)
5021 }));
5022 }
5023 #[test]
5024 fn test_expression_ir_normalization_execution_sets_surface() {
5025 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5026 * &CacheOnlyScalar::new("k").unwrap();
5027 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5028 let evaluator = expr.load(&dataset).unwrap();
5029 let sets = evaluator
5030 .expression_normalization_execution_sets()
5031 .expect("sets should exist");
5032 assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5033 assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5034 assert!(sets.residual_amplitudes.is_empty());
5035 }
5036 #[test]
5037 fn test_expression_ir_normalization_execution_sets_partial_surface() {
5038 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5039 * &CacheOnlyScalar::new("k").unwrap())
5040 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5041 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5042 let evaluator = expr.load(&dataset).unwrap();
5043 let sets = evaluator
5044 .expression_normalization_execution_sets()
5045 .expect("sets should exist");
5046 assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5047 assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5048 assert_eq!(sets.residual_amplitudes, vec![2]);
5049 }
5050 #[test]
5051 fn test_expression_ir_precomputed_cached_integrals_at_load() {
5052 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5053 * &CacheOnlyScalar::new("k").unwrap();
5054 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5055 let evaluator = expr.load(&dataset).unwrap();
5056 let precomputed = evaluator
5057 .expression_precomputed_cached_integrals()
5058 .expect("integrals should exist");
5059 assert_eq!(precomputed.len(), 1);
5060 let cache_reference = CacheOnlyScalar::new("k_ref")
5061 .unwrap()
5062 .load(&dataset)
5063 .unwrap();
5064 let cache_values = cache_reference
5065 .evaluate_local(&[])
5066 .expect("evaluation should succeed");
5067 let expected_weighted_sum = cache_values
5068 .iter()
5069 .zip(dataset.weights_local().iter())
5070 .fold(Complex64::ZERO, |acc, (value, event)| {
5071 acc + (*value * *event)
5072 });
5073 assert_relative_eq!(
5074 precomputed[0].weighted_cache_sum.re,
5075 expected_weighted_sum.re
5076 );
5077 assert_relative_eq!(
5078 precomputed[0].weighted_cache_sum.im,
5079 expected_weighted_sum.im
5080 );
5081 }
5082 #[test]
5083 fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
5084 let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5085 * &CacheOnlyScalar::new("k").unwrap();
5086 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5087 let evaluator = expr.load(&dataset).unwrap();
5088 assert!(evaluator
5089 .expression_precomputed_cached_integrals()
5090 .expect("integrals should exist")
5091 .is_empty());
5092 }
5093 #[test]
5094 fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
5095 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5096 * &CacheOnlyScalar::new("k").unwrap();
5097 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5098 let evaluator = expr.load(&dataset).unwrap();
5099 assert_eq!(
5100 evaluator
5101 .expression_precomputed_cached_integrals()
5102 .expect("integrals should exist")
5103 .len(),
5104 1
5105 );
5106
5107 evaluator.isolate_many(&["p"]);
5108 assert!(evaluator
5109 .expression_precomputed_cached_integrals()
5110 .expect("integrals should exist")
5111 .is_empty());
5112 }
/// Cached integrals must be recomputed when the evaluator's dataset changes: after
/// clearing all local events, the weighted cache sum collapses to zero.
#[test]
fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
    let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let mut evaluator = expr.load(&dataset).unwrap();
    // Release the local handle so the evaluator's `Arc` is unique and
    // `Arc::get_mut` below can hand out a mutable dataset.
    drop(dataset);

    let before = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(before.len(), 1);

    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();

    let after = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(after.len(), 1);
    // No events left, so the weighted sum is zero; differing from `before` shows the
    // integral was recomputed rather than served stale.
    assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
    assert_ne!(before[0].weighted_cache_sum, after[0].weighted_cache_sum);
}
/// For the separable product `p * k`, d/dp of the term is `k`. The precomputed weighted
/// gradient term should therefore equal the weighted cache sum of `k`.
#[test]
fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
    let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let dataset = Arc::new(Dataset::new(vec![
        Arc::new(test_event()),
        Arc::new(test_event()),
    ]));
    let evaluator = expr.load(&dataset).unwrap();

    let integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(integrals.len(), 1);

    let terms = evaluator
        .expression_precomputed_cached_integral_gradient_terms(&[1.25])
        .expect("evaluation should succeed");
    assert_eq!(terms.len(), 1);
    assert_eq!(terms[0].weighted_gradient.len(), 1);

    let gradient = terms[0].weighted_gradient[0];
    let cache_sum = integrals[0].weighted_cache_sum;
    assert_relative_eq!(gradient.re, cache_sum.re, epsilon = 1e-6);
    assert_relative_eq!(gradient.im, cache_sum.im, epsilon = 1e-6);
}
/// Without a separable parameter x cache product there is nothing to precompute, so
/// the cached-integral gradient terms should be empty.
#[test]
fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
    let mixed = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let expr = mixed * &CacheOnlyScalar::new("k").unwrap();
    let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let evaluator = expr.load(&dataset).unwrap();

    let gradient_terms = evaluator
        .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
        .expect("evaluation should succeed");
    assert!(gradient_terms.is_empty());
}
/// Checks that the lowered cached-factor programs produce the same weighted value
/// and gradient sums as the IR evaluation of the cached paths.
#[test]
fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
    let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap())
        + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let resources = evaluator.resources.read();
    let state = evaluator
        .ensure_cached_integral_cache_state(&resources)
        .expect("state should be available");
    let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
    let parameters = resources
        .parameter_map
        .assemble(&[0.55, 0.2, -0.15])
        .expect("parameters should assemble");

    // Values: fill only the amplitudes in the cached-parameter execution set, using
    // the first event's cache, then compare the IR and lowered weighted sums.
    let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
    evaluator.fill_amplitude_values(
        &mut amplitude_values,
        &state.execution_sets.cached_parameter_amplitudes,
        &parameters,
        &resources.caches[0],
    );
    let cached_value_ir =
        evaluator.evaluate_cached_weighted_value_sum_ir(&state, &amplitude_values);
    let cached_value_lowered = evaluator
        .evaluate_cached_weighted_value_sum_lowered(
            &state,
            lowered_artifacts.as_ref(),
            &amplitude_values,
        )
        .expect("cached value lowering should succeed");
    assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);

    // Gradients: the same amplitude subset, expressed as an activity mask.
    let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
    for &index in &state.execution_sets.cached_parameter_amplitudes {
        cached_parameter_mask[index] = true;
    }
    let mut amplitude_gradients = (0..evaluator.amplitudes.len())
        .map(|_| DVector::zeros(parameters.len()))
        .collect::<Vec<_>>();
    evaluator.fill_amplitude_gradients(
        &mut amplitude_gradients,
        &cached_parameter_mask,
        &parameters,
        &resources.caches[0],
    );
    let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
        &state,
        &amplitude_values,
        &amplitude_gradients,
        parameters.len(),
    );
    let cached_gradient_lowered = evaluator
        .evaluate_cached_weighted_gradient_sum_lowered(
            &state,
            lowered_artifacts.as_ref(),
            &amplitude_values,
            &amplitude_gradients,
            parameters.len(),
        )
        .expect("cached gradient lowering should succeed");

    // `zip` stops at the shorter input, so pin the lengths first; otherwise a
    // truncated or empty lowered result would pass vacuously.
    assert_eq!(
        cached_gradient_lowered.len(),
        cached_gradient_ir.len(),
        "lowered and IR cached gradient sums should have the same length"
    );
    for (lowered, ir) in cached_gradient_lowered
        .iter()
        .zip(cached_gradient_ir.iter())
    {
        assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
    }
}
/// Checks that the lowered residual runtime (value and gradient programs) agrees with
/// the IR residual evaluation over the residual amplitude execution set.
#[test]
fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
    let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap())
        + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let resources = evaluator.resources.read();
    let state = evaluator
        .ensure_cached_integral_cache_state(&resources)
        .expect("state should be available");
    let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
    let parameters = resources
        .parameter_map
        .assemble(&[0.55, 0.2, -0.15])
        .expect("parameters should assemble");

    // Values: fill only the residual amplitudes for the first event's cache.
    let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
    evaluator.fill_amplitude_values(
        &mut amplitude_values,
        &state.execution_sets.residual_amplitudes,
        &parameters,
        &resources.caches[0],
    );
    let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &amplitude_values);
    let residual_program = lowered_artifacts
        .residual_runtime
        .as_ref()
        .map(|runtime| runtime.value_program())
        .expect("residual value lowering should succeed");
    let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
    let residual_value_lowered =
        residual_program.evaluate_into(&amplitude_values, &mut value_slots);
    assert_relative_eq!(
        residual_value_lowered.re,
        residual_value_ir.re,
        epsilon = 1e-12
    );
    assert_relative_eq!(
        residual_value_lowered.im,
        residual_value_ir.im,
        epsilon = 1e-12
    );

    // Gradients: the same residual subset, expressed as an activity mask.
    let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
    for &index in &state.execution_sets.residual_amplitudes {
        residual_active_mask[index] = true;
    }
    let mut amplitude_gradients = (0..evaluator.amplitudes.len())
        .map(|_| DVector::zeros(parameters.len()))
        .collect::<Vec<_>>();
    evaluator.fill_amplitude_gradients(
        &mut amplitude_gradients,
        &residual_active_mask,
        &parameters,
        &resources.caches[0],
    );
    let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
        &state,
        &amplitude_values,
        &amplitude_gradients,
        parameters.len(),
    );

    let program = lowered_artifacts
        .residual_runtime
        .as_ref()
        .map(|runtime| runtime.gradient_program())
        .expect("gradient lowering should succeed");
    let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
    // Flat buffer sized scratch_slots x n_parameters for the gradient scratch space.
    let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
    let residual_gradient_lowered = program.evaluate_gradient_into_flat(
        &amplitude_values,
        &amplitude_gradients,
        &mut value_slots,
        &mut gradient_slots,
        parameters.len(),
    );

    // `zip` stops at the shorter input; pin the lengths so a truncated lowered
    // result cannot pass vacuously.
    assert_eq!(
        residual_gradient_lowered.len(),
        residual_gradient_ir.len(),
        "lowered and IR residual gradients should have the same length"
    );
    for (lowered, ir) in residual_gradient_lowered
        .iter()
        .zip(residual_gradient_ir.iter())
    {
        assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
        assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
    }
}
/// Changing the dataset (and hence the specialization key) should compile a new
/// specialization while reusing the existing lowered artifacts.
#[test]
fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
    let separable = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = separable + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(test_dataset());
    let mut evaluator = expr.load(&dataset).unwrap();
    // Drop the local handle so `Arc::get_mut` can reach the evaluator's dataset.
    drop(dataset);

    // Loading leaves one specialization and one lowered artifact set in the caches.
    assert_eq!(evaluator.specialization_cache_len(), 1);
    assert_eq!(evaluator.lowered_artifact_cache_len(), 1);

    // Start from clean counters so the checks below only see the dataset change.
    evaluator.reset_expression_compile_metrics();
    evaluator.reset_expression_specialization_metrics();

    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();

    let integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(integrals.len(), 1);
    assert_eq!(integrals[0].weighted_cache_sum, Complex64::ZERO);

    // One new specialization (a miss), but the lowered artifact cache is unchanged.
    assert_eq!(evaluator.specialization_cache_len(), 2);
    assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        ExpressionSpecializationMetrics {
            cache_hits: 0,
            cache_misses: 1,
        }
    );

    // Compile metrics: IR and integrals were rebuilt, lowering was a pure cache hit.
    let metrics = evaluator.expression_compile_metrics();
    assert_eq!(metrics.specialization_cache_hits, 0);
    assert_eq!(metrics.specialization_cache_misses, 1);
    assert_eq!(metrics.specialization_lowering_cache_hits, 1);
    assert_eq!(metrics.specialization_lowering_cache_misses, 0);
    assert!(metrics.specialization_ir_compile_nanos > 0);
    assert!(metrics.specialization_cached_integrals_nanos > 0);
    assert_eq!(metrics.specialization_lowering_nanos, 0);
}
5376
/// The fast weighted-gradient sum (which uses precomputed cached integrals) must
/// equal the event-by-event baseline: Σ_e w_e · Re(∇f_e).
#[test]
fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    // Two of the three terms yield precomputed cached integrals.
    assert_eq!(
        evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .len(),
        2
    );
    let params = vec![0.2, -0.3, 1.1, -0.7];
    let expected = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let actual = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    // Guard against `zip` silently truncating a short result.
    assert_eq!(actual.len(), expected.len());
    for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
        assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
    }
}
5415
/// The fast weighted-value sum must equal the event-by-event baseline
/// Σ_e w_e · Re(f_e).
#[test]
fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();

    let integral_count = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist")
        .len();
    assert_eq!(integral_count, 2);

    let params = vec![0.2, -0.3, 1.1, -0.7];
    let per_event = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed");
    let expected = per_event
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);

    let actual = evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluation should succeed");
    assert_relative_eq!(actual, expected, epsilon = 1e-10);
}
5446
/// Pins the weighted value and gradient sums to hand-computed reference values on a
/// small three-event dataset with signed weights.
#[test]
fn test_weighted_sums_match_hardcoded_reference_values() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);

    let metadata = Arc::new(DatasetMetadata::default());
    // (energy, weight) for each event; the negative weight exercises cancellation.
    let events = [(2.0, 0.5), (3.0, -1.25), (5.0, 2.0)]
        .iter()
        .map(|&(energy, weight)| {
            Arc::new(EventData {
                p4s: vec![Vec4::new(0.0, 0.0, 0.0, energy)],
                aux: vec![],
                weight,
            })
        })
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(events, metadata));
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![0.7, -1.1, 0.9, -0.4];

    // These numbers follow if `CacheOnlyScalar` yields the event energy E and
    // `TestAmplitude` yields (m1r + i*m1i) * E (TODO confirm against their definitions):
    //   Σ w*E = 7.25, Σ w*E^2 = 40.75
    //   value    = (p1 + p2) * 7.25 + m1r * p1 * 40.75   = 22.7725
    //   d/dp1    = 7.25 + m1r * 40.75                    = 43.925
    //   d/dp2    = 7.25
    //   d/dm1r   = p1 * 40.75                            = 28.525
    //   d/dm1i   = 0 (only the real part is summed)
    let value_sum = evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluation should succeed");
    assert_relative_eq!(value_sum, 22.7725, epsilon = 1e-12);

    let gradient_sum = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    let free_names: Vec<String> = evaluator
        .parameters()
        .free()
        .names()
        .into_iter()
        .map(|name| name.to_string())
        .collect();
    assert_eq!(free_names, vec!["p1", "p2", "m1r", "m1i"]);

    let reference_gradient = [43.925, 7.25, 28.525, 0.0];
    assert_eq!(gradient_sum.len(), reference_gradient.len());
    for (got, want) in gradient_sum.iter().zip(reference_gradient.iter()) {
        assert_relative_eq!(*got, *want, epsilon = 1e-9);
    }
}
/// A subtracted separable term must be stored with coefficient -1, and the weighted
/// gradient sum must still match the event-by-event baseline.
#[test]
fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
    let expr = Expression::one()
        - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap());
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();

    let cached_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(cached_integrals.len(), 1);
    assert_eq!(cached_integrals[0].coefficient, -1);

    let params = vec![0.75];
    let expected = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let actual = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    // Guard against `zip` silently truncating a short result.
    assert_eq!(actual.len(), expected.len());
    for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
        assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
    }
}
/// A subtracted separable term (coefficient -1) must give the same weighted value sum
/// as weighting the per-event values directly.
#[test]
fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
    let one = Expression::one();
    let product = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = one - &product;
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();

    let integral_count = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist")
        .len();
    assert_eq!(integral_count, 1);
    let coefficient = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist")[0]
        .coefficient;
    assert_eq!(coefficient, -1);

    let params = vec![0.75];
    let per_event = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed");
    let expected = per_event
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    let actual = evaluator
        .evaluate_weighted_value_sum_local(&params)
        .expect("evaluation should succeed");
    assert_relative_eq!(actual, expected, epsilon = 1e-10);
}
/// The normalization-plan explanation and root dependence must reflect the current
/// activation mask, not the one at load time.
#[test]
fn test_expression_ir_diagnostics_follow_activation_changes() {
    let separable = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = separable + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let evaluator = expr.load(&dataset).unwrap();

    // Everything active: one cached separable node, and a mixed root.
    let plan = evaluator
        .expression_normalization_plan_explain()
        .expect("plan should exist");
    assert_eq!(plan.cached_separable_nodes.len(), 1);
    let root = evaluator
        .expression_root_dependence()
        .expect("root dependence should exist");
    assert_eq!(root, ExpressionDependence::Mixed);

    // Only `p` active: nothing cached is separable and the root is parameter-only.
    evaluator.isolate_many(&["p"]);
    let plan = evaluator
        .expression_normalization_plan_explain()
        .expect("plan should exist");
    assert!(plan.cached_separable_nodes.is_empty());
    let root = evaluator
        .expression_root_dependence()
        .expect("root dependence should exist");
    assert_eq!(root, ExpressionDependence::ParameterOnly);
}
/// Revisiting an activation mask must be served from the specialization cache. The
/// compile metrics must record the initial build, the first miss, and the later hit.
#[test]
fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
    let separable = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    let expr = separable + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
    let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
    let evaluator = expr.load(&dataset).unwrap();

    // Right after loading: the initial compile ran and lowered once; the compile
    // metrics have not recorded any specialization lookups yet.
    let load_metrics = evaluator.expression_compile_metrics();
    assert!(load_metrics.initial_ir_compile_nanos > 0);
    assert!(load_metrics.initial_cached_integrals_nanos > 0);
    assert!(load_metrics.initial_lowering_nanos > 0);
    assert_eq!(load_metrics.specialization_cache_hits, 0);
    assert_eq!(load_metrics.specialization_cache_misses, 0);
    assert_eq!(load_metrics.specialization_lowering_cache_hits, 0);
    assert_eq!(load_metrics.specialization_lowering_cache_misses, 1);

    assert_eq!(evaluator.specialization_cache_len(), 1);
    assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        ExpressionSpecializationMetrics {
            cache_hits: 0,
            cache_misses: 1,
        }
    );
    let all_active_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");

    // New mask (only `p` active): a specialization miss that compiles and lowers
    // a fresh variant with no separable cached integrals.
    evaluator.isolate_many(&["p"]);
    assert_eq!(evaluator.specialization_cache_len(), 2);
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        ExpressionSpecializationMetrics {
            cache_hits: 0,
            cache_misses: 2,
        }
    );
    let miss_metrics = evaluator.expression_compile_metrics();
    assert_eq!(miss_metrics.specialization_cache_hits, 0);
    assert_eq!(miss_metrics.specialization_cache_misses, 1);
    assert_eq!(miss_metrics.specialization_lowering_cache_hits, 0);
    assert_eq!(miss_metrics.specialization_lowering_cache_misses, 2);
    assert!(miss_metrics.specialization_ir_compile_nanos > 0);
    assert!(miss_metrics.specialization_cached_integrals_nanos > 0);
    assert!(miss_metrics.specialization_lowering_nanos > 0);
    let isolated_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert!(isolated_integrals.is_empty());

    // Back to the all-active mask: a specialization hit that restores the original
    // integrals, with no additional lowering recorded.
    evaluator.activate_many(&["k", "m"]);
    assert_eq!(evaluator.specialization_cache_len(), 2);
    assert_eq!(
        evaluator.expression_specialization_metrics(),
        ExpressionSpecializationMetrics {
            cache_hits: 1,
            cache_misses: 2,
        }
    );
    let restored_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(restored_integrals, all_active_integrals);
    let hit_metrics = evaluator.expression_compile_metrics();
    assert_eq!(hit_metrics.specialization_cache_hits, 1);
    assert_eq!(hit_metrics.specialization_cache_misses, 1);
    assert_eq!(hit_metrics.specialization_lowering_cache_hits, 0);
    assert_eq!(hit_metrics.specialization_lowering_cache_misses, 2);
    assert!(hit_metrics.specialization_cache_restore_nanos > 0);
}
5700
/// After the active amplitude set changes, both weighted sums (value and gradient)
/// must still match their event-by-event baselines.
#[test]
fn test_weighted_sums_match_baseline_after_activation_changes() {
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![0.2, -0.3, 1.1, -0.7];

    // Leave only `p1`, `c1`, `m1` and `c3` active, so `p2` and `c2` are switched off.
    evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);

    let expected_value = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    assert_relative_eq!(
        evaluator
            .evaluate_weighted_value_sum_local(&params)
            .expect("evaluation should succeed"),
        expected_value,
        epsilon = 1e-10
    );

    // The gradient sum must honour the same activation mask as the value sum.
    let expected_gradient = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let actual_gradient = evaluator
        .evaluate_weighted_gradient_sum_local(&params)
        .expect("evaluation should succeed");
    assert_eq!(actual_gradient.len(), expected_gradient.len());
    for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
        assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
    }
}
5730
/// `evaluate_local` must work from the precomputed caches alone: clearing the dataset's
/// local events afterwards must not change the number of results.
#[test]
fn test_evaluate_local_does_not_depend_on_dataset_rows() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let events = [7.5, 8.25, 9.0]
        .iter()
        .map(|&energy| {
            let mut event = test_event();
            event.p4s[0].t = energy;
            Arc::new(event)
        })
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let mut evaluator = expr.load(&dataset).unwrap();
    // Make the evaluator the sole owner so `Arc::get_mut` succeeds below.
    drop(dataset);

    let cache_count = evaluator.resources.read().caches.len();
    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();

    let values = evaluator
        .evaluate_local(&[1.25, -0.75])
        .expect("evaluation should succeed");
    assert_eq!(values.len(), cache_count);
}
5757
/// `evaluate_gradient_local` must work from the precomputed caches alone: clearing the
/// dataset's local events afterwards must not change the number of results.
#[test]
fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let events = [7.5, 8.25, 9.0]
        .iter()
        .map(|&energy| {
            let mut event = test_event();
            event.p4s[0].t = energy;
            Arc::new(event)
        })
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let mut evaluator = expr.load(&dataset).unwrap();
    // Make the evaluator the sole owner so `Arc::get_mut` succeeds below.
    drop(dataset);

    let cache_count = evaluator.resources.read().caches.len();
    Arc::get_mut(&mut evaluator.dataset)
        .expect("evaluator should own dataset Arc in this test")
        .clear_events_local();

    let gradients = evaluator
        .evaluate_gradient_local(&[1.25, -0.75])
        .expect("evaluation should succeed");
    assert_eq!(gradients.len(), cache_count);
}
5784
/// The fused value+gradient evaluation must reproduce `evaluate_local` and
/// `evaluate_gradient_local` event by event.
#[test]
fn test_evaluate_with_gradient_local_matches_separate_paths() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let dataset = Arc::new(Dataset::new(vec![
        Arc::new(test_event()),
        Arc::new(test_event()),
        Arc::new(test_event()),
    ]));
    let evaluator = expr.load(&dataset).unwrap();
    let params = [1.25, -0.75];

    let values = evaluator
        .evaluate_local(&params)
        .expect("evaluation should succeed");
    let gradients = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluation should succeed");
    let fused = evaluator
        .evaluate_with_gradient_local(&params)
        .expect("evaluation should succeed");
    assert_eq!(fused.len(), values.len());
    assert_eq!(fused.len(), gradients.len());

    let separate = values.iter().zip(gradients.iter());
    for ((fused_value, fused_gradient), (value, gradient)) in fused.iter().zip(separate) {
        assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
        assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
        assert_eq!(fused_gradient.len(), gradient.len());
        for (lhs, rhs) in fused_gradient.iter().zip(gradient.iter()) {
            assert_relative_eq!(lhs.re, rhs.re, epsilon = 1e-12);
            assert_relative_eq!(lhs.im, rhs.im, epsilon = 1e-12);
        }
    }
}
5821
/// The fused batch value+gradient evaluation over a subset of event indices must
/// reproduce the separate batch value and batch gradient results.
#[test]
fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
    let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
        .unwrap()
        .norm_sqr();
    let dataset = Arc::new(Dataset::new(vec![
        Arc::new(test_event()),
        Arc::new(test_event()),
        Arc::new(test_event()),
        Arc::new(test_event()),
    ]));
    let evaluator = expr.load(&dataset).unwrap();
    let params = [0.5, -1.25];
    // Skip event 1 so the batch path is exercised with a non-contiguous selection.
    let indices = vec![0, 2, 3];

    let values = evaluator
        .evaluate_batch_local(&params, &indices)
        .expect("evaluation should succeed");
    let gradients = evaluator
        .evaluate_gradient_batch_local(&params, &indices)
        .expect("evaluation should succeed");
    let fused = evaluator
        .evaluate_with_gradient_batch_local(&params, &indices)
        .expect("evaluation should succeed");
    assert_eq!(fused.len(), values.len());
    assert_eq!(fused.len(), gradients.len());

    let separate = values.iter().zip(gradients.iter());
    for ((fused_value, fused_gradient), (value, gradient)) in fused.iter().zip(separate) {
        assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
        assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
        assert_eq!(fused_gradient.len(), gradient.len());
        for (lhs, rhs) in fused_gradient.iter().zip(gradient.iter()) {
            assert_relative_eq!(lhs.re, rhs.re, epsilon = 1e-12);
            assert_relative_eq!(lhs.im, rhs.im, epsilon = 1e-12);
        }
    }
}
5860
/// `precompute_all` must write the registered `beam_energy` scalar into the cache of
/// every event in the dataset.
#[test]
fn test_precompute_all_columnar_populates_cache() {
    let mut event1 = test_event();
    event1.p4s[0].t = 7.5;
    let mut event2 = test_event();
    event2.p4s[0].t = 8.25;
    let mut event3 = test_event();
    event3.p4s[0].t = 9.0;
    let dataset = Dataset::new_with_metadata(
        vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
        Arc::new(DatasetMetadata::default()),
    );
    let mut amplitude = TestAmplitude {
        tags: Tags::new(["test"]),
        re: parameter!("real"),
        pid_re: ParameterID::default(),
        im: parameter!("imag"),
        pid_im: ParameterID::default(),
        beam_energy: Default::default(),
    };
    let mut resources = Resources::default();
    amplitude
        .register(&mut resources)
        .expect("test amplitude should register");
    resources.reserve_cache(dataset.n_events());
    amplitude.precompute_all(&dataset, &mut resources);

    // The loop below proves nothing if no caches exist, so pin the count first.
    assert_eq!(resources.caches.len(), dataset.n_events());
    for cache in &resources.caches {
        assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
    }
}
5891
/// Under MPI, loading an expression must allocate caches only for this rank's local
/// events, not for the global dataset.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_load_reserves_local_cache_size_in_mpi() {
    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    assert!(get_world().is_some(), "MPI world should be initialized");

    let expr = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .expect("constant amplitude should construct");
    let events = (0..4)
        .map(|_| Arc::new(test_event()))
        .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");

    let local_events = dataset.n_events_local();
    let cache_len = evaluator.resources.read().caches.len();
    assert_eq!(
        cache_len, local_events,
        "cache length must match local event count under MPI"
    );
    finalize_mpi();
}
5926
/// Under MPI, precomputed cached integrals must cover only this rank's events, and the
/// local weighted value sum must be built from those rank-local integrals.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
    use mpi::{collective::SystemOperation, topology::Communicator, traits::*};

    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
        * &CacheOnlyScalar::new("k").unwrap();
    // (energy, weight) pairs: energies 1..=8 with weight = energy / 2.
    let events = [
        (1.0, 0.5),
        (2.0, 1.0),
        (3.0, 1.5),
        (4.0, 2.0),
        (5.0, 2.5),
        (6.0, 3.0),
        (7.0, 3.5),
        (8.0, 4.0),
    ]
    .iter()
    .map(|&(energy, weight)| {
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, energy)],
            aux: vec![],
            weight,
        })
    })
    .collect::<Vec<_>>();
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");
    let cached_integrals = evaluator
        .expression_precomputed_cached_integrals()
        .expect("integrals should exist");
    assert_eq!(cached_integrals.len(), 1);

    // Reference integral over this rank's events only: Σ w * E.
    let mut local_expected = 0.0;
    for (index, weight) in dataset.weights_local().iter().enumerate() {
        let event = dataset.event_local(index).expect("event should exist");
        local_expected += *weight * event.p4_at(0).e();
    }
    let cached_local = cached_integrals[0].weighted_cache_sum;
    assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
    assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);

    // With p = 2 the single separable term gives 2 * Σ w * E on this rank.
    let weighted_value_sum = evaluator
        .evaluate_weighted_value_sum_local(&[2.0])
        .expect("evaluate should succeed");
    assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);

    // Every w * E term here is positive, so whenever other ranks hold events the
    // global total exceeds this rank's share; a match would mean premature reduction.
    let mut global_expected = 0.0;
    world.all_reduce_into(
        &local_expected,
        &mut global_expected,
        SystemOperation::sum(),
    );
    if world.size() > 1 {
        assert!(
            (cached_local.re - global_expected).abs() > 1e-12,
            "cached integral should remain rank-local before MPI reduction"
        );
    }
    finalize_mpi();
}
6023
// Cross-checks the MPI-reduced weighted value and gradient sums produced by the
// evaluator against a baseline built by hand: evaluate each rank-local event,
// weight it, sum on this rank, then all-reduce the partial sums across ranks.
#[cfg(feature = "mpi")]
#[mpi_test(np = [2])]
fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
    use mpi::{collective::SystemOperation, traits::*};

    use crate::mpi::{finalize_mpi, get_world, use_mpi};

    use_mpi(true);
    let world = get_world().expect("MPI world should be initialized");

    // Combines ParameterOnlyScalar, CacheOnlyScalar and TestAmplitude factors so
    // the summed expression contains several different kinds of terms.
    let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
    let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
    let c1 = CacheOnlyScalar::new("c1").unwrap();
    let c2 = CacheOnlyScalar::new("c2").unwrap();
    let c3 = CacheOnlyScalar::new("c3").unwrap();
    let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
    let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
    // Mixed-sign weights so the reduction is sensitive to every event.
    let events = vec![
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
            aux: vec![],
            weight: 0.5,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
            aux: vec![],
            weight: -1.25,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
            aux: vec![],
            weight: 0.75,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
            aux: vec![],
            weight: 1.5,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
            aux: vec![],
            weight: 2.25,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
            aux: vec![],
            weight: -0.5,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
            aux: vec![],
            weight: 3.5,
        }),
        Arc::new(EventData {
            p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
            aux: vec![],
            weight: 1.25,
        }),
    ];
    let dataset = Arc::new(Dataset::new_with_metadata(
        events,
        Arc::new(DatasetMetadata::default()),
    ));
    let evaluator = expr.load(&dataset).expect("evaluator should load");
    let params = vec![0.2, -0.3, 1.1, -0.7];

    // Baseline value: sum over this rank's events of w_i * Re(f_i), then summed
    // across ranks.
    let local_expected_value = evaluator
        .evaluate_local(&params)
        .expect("evaluate should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(0.0, |accum, (value, weight)| accum + *weight * value.re);
    let mut global_expected_value = 0.0;
    world.all_reduce_into(
        &local_expected_value,
        &mut global_expected_value,
        SystemOperation::sum(),
    );
    let mpi_value = evaluator
        .evaluate_weighted_value_sum_mpi(&params, &world)
        .expect("evaluate should succeed");
    assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);

    // Baseline gradient: sum over this rank's events of w_i * Re(grad f_i),
    // then summed element-wise across ranks.
    let local_expected_gradient = evaluator
        .evaluate_gradient_local(&params)
        .expect("evaluate should succeed")
        .iter()
        .zip(dataset.weights_local().iter())
        .fold(
            DVector::zeros(params.len()),
            |mut accum, (gradient, weight)| {
                accum += gradient.map(|value| value.re).scale(*weight);
                accum
            },
        );
    let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
    world.all_reduce_into(
        local_expected_gradient.as_slice(),
        &mut global_expected_gradient,
        SystemOperation::sum(),
    );
    let mpi_gradient = evaluator
        .evaluate_weighted_gradient_sum_mpi(&params, &world)
        .expect("evaluate should succeed");
    // `zip` stops at the shorter side, so pin the lengths first; otherwise a
    // truncated gradient would pass the element-wise loop vacuously.
    assert_eq!(mpi_gradient.len(), global_expected_gradient.len());
    for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
        assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
    }

    finalize_mpi();
}
6134
// Both rank-local entry points (values and gradients) must succeed for an
// expression with no free parameters, producing one entry per event.
#[test]
fn test_evaluate_local_succeeds_for_constant_amplitude() {
    let constant = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .unwrap();
    let single_event = vec![Arc::new(test_event())];
    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(single_event, metadata));
    let evaluator = constant.load(&dataset).unwrap();

    let value_count = evaluator
        .evaluate_local(&[])
        .expect("evaluation should succeed")
        .len();
    assert_eq!(value_count, 1);

    let gradient_count = evaluator
        .evaluate_gradient_local(&[])
        .expect("evaluation should succeed")
        .len();
    assert_eq!(gradient_count, 1);
}
6157
// A ComplexScalar built from two fixed parameters evaluates to exactly that
// complex constant.
#[test]
fn test_constant_amplitude() {
    let metadata = Arc::new(DatasetMetadata::default());
    let dataset = Arc::new(Dataset::new_with_metadata(
        vec![Arc::new(test_event())],
        metadata,
    ));
    let constant = ComplexScalar::new(
        "constant",
        parameter!("const_re", 2.0),
        parameter!("const_im", 3.0),
    )
    .unwrap();
    let evaluator = constant.load(&dataset).unwrap();

    let values = evaluator.evaluate(&[]).expect("evaluation should succeed");
    assert_eq!(values[0], Complex64::new(2.0, 3.0));
}
6174
// With both parts free, the supplied parameter vector becomes the value.
#[test]
fn test_parametric_amplitude() {
    let dataset = Arc::new(test_dataset());
    let amplitude = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("test_param_im"),
    )
    .unwrap();
    let evaluator = amplitude.load(&dataset).unwrap();

    let values = evaluator
        .evaluate(&[2.0, 3.0])
        .expect("evaluation should succeed");
    assert_eq!(values[0], Complex64::new(2.0, 3.0));
}
6190
// Exercises every arithmetic operator and unary transform on fixed complex
// constants (c1 = 2, c2 = i, c3 = 3 + 4i), both on leaves and on composites.
#[test]
fn test_expression_operations() {
    let dataset = Arc::new(test_dataset());
    // All inputs are fixed, so every expression is evaluated with no parameters
    // and only the first event's value is inspected.
    let first_value = |expression: &Expression| {
        expression
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed")[0]
    };

    let c1 = ComplexScalar::new(
        "const1",
        parameter!("const1_re", 2.0),
        parameter!("const1_im", 0.0),
    )
    .unwrap();
    let c2 = ComplexScalar::new(
        "const2",
        parameter!("const2_re", 0.0),
        parameter!("const2_im", 1.0),
    )
    .unwrap();
    let c3 = ComplexScalar::new(
        "const3",
        parameter!("const3_re", 3.0),
        parameter!("const3_im", 4.0),
    )
    .unwrap();

    // Operators applied directly to the leaf constants.
    let sum = &c1 + &c2;
    assert_eq!(first_value(&sum), Complex64::new(2.0, 1.0));

    let difference = &c1 - &c2;
    assert_eq!(first_value(&difference), Complex64::new(2.0, -1.0));

    let product = &c1 * &c2;
    assert_eq!(first_value(&product), Complex64::new(0.0, 2.0));

    let quotient = &c1 / &c3;
    assert_eq!(first_value(&quotient), Complex64::new(6.0 / 25.0, -8.0 / 25.0));

    let negated = -&c3;
    assert_eq!(first_value(&negated), Complex64::new(-3.0, -4.0));

    // The same operators applied to composite expressions.
    let nested_sum = &sum + &product;
    assert_eq!(first_value(&nested_sum), Complex64::new(2.0, 3.0));

    let nested_difference = &sum - &product;
    assert_eq!(first_value(&nested_difference), Complex64::new(2.0, -1.0));

    let nested_product = &sum * &product;
    assert_eq!(first_value(&nested_product), Complex64::new(-2.0, 4.0));

    let nested_quotient = &sum / &nested_sum;
    assert_eq!(
        first_value(&nested_quotient),
        Complex64::new(7.0 / 13.0, -4.0 / 13.0)
    );

    let nested_negated = -&nested_product;
    assert_eq!(first_value(&nested_negated), Complex64::new(2.0, -4.0));

    // Component extraction, conjugation and squared norm on a leaf and a composite.
    assert_eq!(first_value(&c3.real()), Complex64::new(3.0, 0.0));
    assert_eq!(first_value(&nested_product.real()), Complex64::new(-2.0, 0.0));

    assert_eq!(first_value(&c3.imag()), Complex64::new(4.0, 0.0));
    assert_eq!(first_value(&nested_product.imag()), Complex64::new(4.0, 0.0));

    assert_eq!(first_value(&c3.conj()), Complex64::new(3.0, -4.0));
    assert_eq!(first_value(&nested_product.conj()), Complex64::new(-2.0, -4.0));

    assert_eq!(first_value(&c1.norm_sqr()), Complex64::new(4.0, 0.0));
    assert_eq!(first_value(&nested_product.norm_sqr()), Complex64::new(20.0, 0.0));
}
6376
// Deactivating, isolating and re-activating amplitudes by tag changes which
// terms of `const1 + const2` (1 + 2) contribute to the evaluated value.
#[test]
fn test_amplitude_activation() {
    let dataset = Arc::new(test_dataset());
    let first = ComplexScalar::new(
        "const1",
        parameter!("const1_re_act", 1.0),
        parameter!("const1_im_act", 0.0),
    )
    .unwrap();
    let second = ComplexScalar::new(
        "const2",
        parameter!("const2_re_act", 2.0),
        parameter!("const2_im_act", 0.0),
    )
    .unwrap();
    let combined = &first + &second;
    let evaluator = combined.load(&dataset).unwrap();
    let current = || evaluator.evaluate(&[]).expect("evaluation should succeed")[0];

    // Everything active: 1 + 2.
    assert_eq!(current(), Complex64::new(3.0, 0.0));

    // Without const1 only const2 remains.
    evaluator.deactivate_strict("const1").unwrap();
    assert_eq!(current(), Complex64::new(2.0, 0.0));

    // Isolating const1 leaves only const1.
    evaluator.isolate_strict("const1").unwrap();
    assert_eq!(current(), Complex64::new(1.0, 0.0));

    // Re-activating everything restores the full sum.
    evaluator.activate_all();
    assert_eq!(current(), Complex64::new(3.0, 0.0));
}
6415
// Checks analytic gradients of +, -, *, /, negation, Re, Im, conj and |.|^2 for
// two free complex scalars z1 = p0 + i*p1 and z2 = p2 + i*p3, evaluated at
// z1 = 2 + 3i and z2 = 4 + 5i. Entry k of the gradient is d/dp_k.
#[test]
fn test_gradient() {
    let expr1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());
    let params = vec![2.0, 3.0, 4.0, 5.0];

    // z1 + z2: derivatives are 1, i, 1, i.
    let expr = &expr1 + &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 1.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 0.0);
    assert_relative_eq!(gradient[0][1].im, 1.0);
    assert_relative_eq!(gradient[0][2].re, 1.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 0.0);
    assert_relative_eq!(gradient[0][3].im, 1.0);

    // z1 - z2: derivatives are 1, i, -1, -i.
    let expr = &expr1 - &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 1.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 0.0);
    assert_relative_eq!(gradient[0][1].im, 1.0);
    assert_relative_eq!(gradient[0][2].re, -1.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 0.0);
    assert_relative_eq!(gradient[0][3].im, -1.0);

    // z1 * z2: derivatives are z2, i*z2, z1, i*z1.
    let expr = &expr1 * &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 5.0);
    assert_relative_eq!(gradient[0][1].re, -5.0);
    assert_relative_eq!(gradient[0][1].im, 4.0);
    assert_relative_eq!(gradient[0][2].re, 2.0);
    assert_relative_eq!(gradient[0][2].im, 3.0);
    assert_relative_eq!(gradient[0][3].re, -3.0);
    assert_relative_eq!(gradient[0][3].im, 2.0);

    // z1 / z2: derivatives are 1/z2, i/z2, -z1/z2^2, -i*z1/z2^2 (|z2|^2 = 41).
    let expr = &expr1 / &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
    assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
    assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
    assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
    assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
    assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
    assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
    assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);

    // -(z1 * z2): negated product derivatives.
    let expr = -(&expr1 * &expr2);
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, -4.0);
    assert_relative_eq!(gradient[0][0].im, -5.0);
    assert_relative_eq!(gradient[0][1].re, 5.0);
    assert_relative_eq!(gradient[0][1].im, -4.0);
    assert_relative_eq!(gradient[0][2].re, -2.0);
    assert_relative_eq!(gradient[0][2].im, -3.0);
    assert_relative_eq!(gradient[0][3].re, 3.0);
    assert_relative_eq!(gradient[0][3].im, -2.0);

    // Re(z1 * z2): real parts of the product derivatives.
    let expr = (&expr1 * &expr2).real();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, -5.0);
    assert_relative_eq!(gradient[0][1].im, 0.0);
    assert_relative_eq!(gradient[0][2].re, 2.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, -3.0);
    assert_relative_eq!(gradient[0][3].im, 0.0);

    // Im(z1 * z2): imaginary parts of the product derivatives.
    let expr = (&expr1 * &expr2).imag();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 5.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 4.0);
    assert_relative_eq!(gradient[0][1].im, 0.0);
    assert_relative_eq!(gradient[0][2].re, 3.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 2.0);
    assert_relative_eq!(gradient[0][3].im, 0.0);

    // conj(z1 * z2): conjugates of the product derivatives.
    let expr = (&expr1 * &expr2).conj();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, -5.0);
    assert_relative_eq!(gradient[0][1].re, -5.0);
    assert_relative_eq!(gradient[0][1].im, -4.0);
    assert_relative_eq!(gradient[0][2].re, 2.0);
    assert_relative_eq!(gradient[0][2].im, -3.0);
    assert_relative_eq!(gradient[0][3].re, -3.0);
    assert_relative_eq!(gradient[0][3].im, -2.0);

    // |z1 * z2|^2 = |z1|^2 |z2|^2: derivatives are 2*p_k times the other
    // modulus squared (|z1|^2 = 13, |z2|^2 = 41).
    let expr = (&expr1 * &expr2).norm_sqr();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 164.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 246.0);
    assert_relative_eq!(gradient[0][1].im, 0.0);
    assert_relative_eq!(gradient[0][2].re, 104.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 130.0);
    assert_relative_eq!(gradient[0][3].im, 0.0);
}
6578
// Compares the analytic gradient of an expression built from sqrt, exp, powi,
// powf, sin*cos, log, cis and pow against central finite differences, for
// every free parameter.
#[test]
fn test_expression_function_gradients() {
    let expr1 = ComplexScalar::new(
        "function_parametric_1",
        parameter!("function_test_param_re_1"),
        parameter!("function_test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "function_parametric_2",
        parameter!("function_test_param_re_2"),
        parameter!("function_test_param_im_2"),
    )
    .unwrap();

    let sin = expr1.sin();
    let cos = expr1.cos();
    let trig = &sin * &cos;
    let pow = expr1.pow(&expr2);
    let mut expr = expr1.sqrt();
    expr = &expr + &expr1.exp();
    expr = &expr + &expr1.powi(2);
    expr = &expr + &expr1.powf(1.7);
    expr = &expr + &trig;
    expr = &expr + &expr1.log();
    expr = &expr + &expr1.cis();
    expr = &expr + &pow;

    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    let params = vec![2.0, 0.5, 1.2, -0.3];
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");
    let eps = 1e-6;

    for param_index in 0..params.len() {
        let mut plus = params.clone();
        plus[param_index] += eps;
        let mut minus = params.clone();
        minus[param_index] -= eps;
        // Central difference: (f(p + eps) - f(p - eps)) / (2 * eps).
        let finite_diff = (evaluator
            .evaluate(&plus)
            .expect("evaluation should succeed")[0]
            - evaluator
                .evaluate(&minus)
                .expect("evaluation should succeed")[0])
            / Complex64::new(2.0 * eps, 0.0);

        assert_relative_eq!(
            gradient[0][param_index].re,
            finite_diff.re,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
        assert_relative_eq!(
            gradient[0][param_index].im,
            finite_diff.im,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
    }
}
6642
// `Expression::one()` and `Expression::zero()` must behave as multiplicative
// and additive identities for both the value and the gradient.
#[test]
fn test_zeros_and_ones() {
    let amp = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();
    let dataset = Arc::new(test_dataset());
    let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
    let evaluator = expr.load(&dataset).unwrap();

    // Only the real part is free; with re = 2 the amplitude is 2 + 2i.
    let params = vec![2.0];
    let value = evaluator
        .evaluate(&params)
        .expect("evaluation should succeed");
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    // |2 + 2i|^2 = 8.
    assert_relative_eq!(value[0].re, 8.0);
    assert_relative_eq!(value[0].im, 0.0);

    // d/d(re) (re^2 + 4) = 2 * re = 4.
    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
}
// A default build should plan through the IR and carry all three lowered
// programs, and still evaluate |2 + 0i|^2 = 4 correctly.
#[test]
fn test_default_build_uses_lowered_expression_runtime() {
    let dataset = Arc::new(test_dataset());
    let squared_norm = ComplexScalar::new(
        "opt_in_gate",
        parameter!("opt_in_gate_re", 2.0),
        parameter!("opt_in_gate_im", 0.0),
    )
    .unwrap()
    .norm_sqr();
    let evaluator = squared_norm.load(&dataset).unwrap();

    let runtime = evaluator.expression_runtime_diagnostics();
    assert!(runtime.ir_planning_enabled);
    assert!(runtime.lowered_value_program_present);
    assert!(runtime.lowered_gradient_program_present);
    assert!(runtime.lowered_value_gradient_program_present);

    let values = evaluator.evaluate(&[]).expect("evaluation should succeed");
    assert_eq!(values[0], Complex64::new(4.0, 0.0));
}
6693
// `parameter!(name)` yields a free parameter with every optional field unset.
#[test]
fn parameter_name_only_creates_free_parameter() {
    let mass = parameter!("mass");

    assert_eq!(mass.name(), "mass");
    assert!(mass.is_free());
    assert!(!mass.is_fixed());

    assert_eq!(mass.fixed(), None);
    assert_eq!(mass.initial(), None);
    assert_eq!(mass.bounds(), (None, None));
    assert_eq!(mass.unit(), None);
    assert_eq!(mass.latex(), None);
    assert_eq!(mass.description(), None);
}
6708
// `parameter!(name, value)` yields a fixed parameter whose initial value is the
// fixed value.
#[test]
fn parameter_name_and_value_creates_fixed_parameter() {
    let width = parameter!("width", 0.15);

    assert_eq!(width.name(), "width");
    assert!(width.is_fixed());
    assert!(!width.is_free());
    assert_eq!(width.fixed(), Some(0.15));
    assert_eq!(width.initial(), Some(0.15));
}
6719
// The `initial:` keyword sets a starting value without fixing the parameter or
// adding bounds.
#[test]
fn keyword_initial_sets_initial_only() {
    let alpha = parameter!("alpha", initial: 1.25);

    assert_eq!(alpha.name(), "alpha");
    assert!(alpha.is_free());
    assert_eq!(alpha.initial(), Some(1.25));
    assert_eq!(alpha.fixed(), None);
    assert_eq!(alpha.bounds(), (None, None));
}
6730
// The `fixed:` keyword fixes the parameter and mirrors the value into `initial`.
#[test]
fn keyword_fixed_sets_fixed_and_initial() {
    let beta = parameter!("beta", fixed: 2.5);

    assert_eq!(beta.name(), "beta");
    assert!(beta.is_fixed());
    assert_eq!(beta.fixed(), Some(2.5));
    assert_eq!(beta.initial(), Some(2.5));
}
6740
// Two numeric bounds become `Some` on both sides.
#[test]
fn bounds_accept_plain_numbers() {
    let param = parameter!("x", bounds: (0.0, 10.0));
    let (lower, upper) = param.bounds();

    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));
}
6747
// A `None` lower bound leaves the parameter unbounded below.
#[test]
fn bounds_accept_none_and_number() {
    let param = parameter!("x", bounds: (None, 10.0));
    let (lower, upper) = param.bounds();

    assert!(lower.is_none());
    assert_eq!(upper, Some(10.0));
}
6754
// A `None` upper bound leaves the parameter unbounded above.
#[test]
fn bounds_accept_number_and_none() {
    let param = parameter!("x", bounds: (-1.0, None));
    let (lower, upper) = param.bounds();

    assert_eq!(lower, Some(-1.0));
    assert!(upper.is_none());
}
6761
// Explicit `(None, None)` bounds leave the parameter unbounded on both sides.
#[test]
fn bounds_accept_both_none() {
    let param = parameter!("x", bounds: (None, None));
    let (lower, upper) = param.bounds();

    assert!(lower.is_none());
    assert!(upper.is_none());
}
6768
// Bounds may be arbitrary expressions, evaluated at macro-expansion site.
#[test]
fn bounds_accept_arbitrary_expressions() {
    let base = 1.0;
    let upper_limit = 2.0 * 3.0;
    let param = parameter!("x", bounds: (base - 0.5, upper_limit));
    let (lower, upper) = param.bounds();

    assert_eq!(lower, Some(0.5));
    assert_eq!(upper, Some(6.0));
}
6777
// Several keyword arguments can be combined in one invocation, and each one
// lands in its own field.
#[test]
fn multiple_keyword_arguments_work_together() {
    let gamma = parameter!(
        "gamma",
        initial: 1.0,
        bounds: (0.0, 5.0),
        unit: "GeV",
        latex: r"\gamma",
        description: "test parameter",
    );

    // Numeric settings.
    assert_eq!(gamma.name(), "gamma");
    assert_eq!(gamma.fixed(), None);
    assert_eq!(gamma.initial(), Some(1.0));
    let (lower, upper) = gamma.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(5.0));

    // Text metadata.
    assert_eq!(gamma.unit().as_deref(), Some("GeV"));
    assert_eq!(gamma.latex().as_deref(), Some(r"\gamma"));
    assert_eq!(gamma.description().as_deref(), Some("test parameter"));
}
6797
// `fixed:` can be mixed with other keyword fields; the fixed value also becomes
// the initial value.
#[test]
fn fixed_can_be_combined_with_other_fields() {
    let delta = parameter!(
        "delta",
        fixed: 3.0,
        bounds: (0.0, 10.0),
        unit: "rad",
    );

    assert_eq!(delta.name(), "delta");
    assert_eq!(delta.fixed(), Some(3.0));
    assert_eq!(delta.initial(), Some(3.0));
    let (lower, upper) = delta.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));
    assert_eq!(delta.unit().as_deref(), Some("rad"));
}
6813
// A trailing comma after the last keyword argument must still parse.
#[test]
fn trailing_comma_is_accepted() {
    let eps = parameter!(
        "eps",
        initial: 0.5,
        bounds: (None, 1.0),
        unit: "arb",
    );

    assert_eq!(eps.initial(), Some(0.5));
    let (lower, upper) = eps.bounds();
    assert!(lower.is_none());
    assert_eq!(upper, Some(1.0));
    assert_eq!(eps.unit().as_deref(), Some("arb"));
}
6827
// Only the unfixed parameter of an amplitude is reported among the free names.
#[test]
fn test_parameter_registration() {
    let amplitude = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();

    let free_names = amplitude.parameters().free().names();
    assert_eq!(free_names.len(), 1);
    assert_eq!(free_names[0], "test_param_re");
}
6840
// Two amplitudes sharing the tag "same_name" can be combined, and all four of
// their fixed parameters are kept in declaration order.
#[test]
fn test_duplicate_amplitude_tag_registration_is_allowed() {
    let first = ComplexScalar::new(
        "same_name",
        parameter!("dup_re1", 1.0),
        parameter!("dup_im1", 0.0),
    )
    .unwrap();
    let second = ComplexScalar::new(
        "same_name",
        parameter!("dup_re2", 2.0),
        parameter!("dup_im2", 0.0),
    )
    .unwrap();
    let combined = first + second;

    let expected_fixed = vec!["dup_re1", "dup_im1", "dup_re2", "dup_im2"];
    assert_eq!(combined.parameters().fixed().names(), expected_fixed);
}
6861
// The `Display` impl renders an expression as a box-drawn tree. Amplitudes
// print with their id, and exact constants print with an "(exact)" marker.
#[test]
fn test_tree_printing() {
    let amp1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let amp2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();
    // `&amp1.real()` parses as `&(amp1.real())`; the left-associative chain of
    // + and - yields the nesting shown in the expected tree below.
    let expr =
        &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
            - Expression::zero() / 1.0
            + (&amp1 * &amp2).norm_sqr();
    assert_eq!(
        expr.to_string(),
        concat!(
            "+\n",
            "├─ -\n",
            "│ ├─ +\n",
            "│ │ ├─ +\n",
            "│ │ │ ├─ Re\n",
            "│ │ │ │ └─ parametric_1(id=0)\n",
            "│ │ │ └─ Im\n",
            "│ │ │ └─ *\n",
            "│ │ │ └─ parametric_2(id=1)\n",
            "│ │ └─ ×\n",
            "│ │ ├─ 1 (exact)\n",
            "│ │ └─ -1.4+2i\n",
            "│ └─ ÷\n",
            "│ ├─ 0 (exact)\n",
            "│ └─ 1 (exact)\n",
            "└─ NormSqr\n",
            " └─ ×\n",
            " ├─ parametric_1(id=0)\n",
            " └─ parametric_2(id=1)\n",
        )
    );
}
6905}