1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 sync::{
5 atomic::{AtomicU64, Ordering},
6 Arc,
7 },
8 time::Instant,
9};
10
11use auto_ops::*;
12use nalgebra::DVector;
13use num::complex::Complex64;
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
19use super::{ir, lowered};
20
/// Process-wide counter backing `next_compute_amplitude_id`.
static COMPUTE_AMPLITUDE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Process-wide counter backing `next_amplitude_use_site_id`.
static AMPLITUDE_USE_SITE_COUNTER: AtomicU64 = AtomicU64::new(0);
23
/// Identity of one registered amplitude computation.
///
/// `ExpressionRegistry::merge` uses it to recognize the same amplitude arriving from both
/// operands, so that amplitude is registered only once in the merged registry.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct ComputeAmplitudeId(u64);
26
/// Identity of one use site of an amplitude within an expression.
///
/// `ExpressionRegistry::merge` uses it to collapse use sites that both operands share.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct AmplitudeUseSiteId(u64);
29
30fn next_compute_amplitude_id() -> ComputeAmplitudeId {
31 ComputeAmplitudeId(COMPUTE_AMPLITUDE_COUNTER.fetch_add(1, Ordering::Relaxed))
32}
33
34fn next_amplitude_use_site_id() -> AmplitudeUseSiteId {
35 AmplitudeUseSiteId(AMPLITUDE_USE_SITE_COUNTER.fetch_add(1, Ordering::Relaxed))
36}
/// Public mirror of the IR's `DependenceClass`.
///
/// Converts one-to-one to and from `ir::DependenceClass` through the `From` impls in this
/// module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionDependence {
    /// Corresponds to `ir::DependenceClass::ParameterOnly` (by name: depends only on
    /// parameters).
    ParameterOnly,
    /// Corresponds to `ir::DependenceClass::CacheOnly` (by name: depends only on cached
    /// data).
    CacheOnly,
    /// Corresponds to `ir::DependenceClass::Mixed` (by name: depends on both).
    Mixed,
}
47impl From<ir::DependenceClass> for ExpressionDependence {
48 fn from(value: ir::DependenceClass) -> Self {
49 match value {
50 ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
51 ir::DependenceClass::CacheOnly => Self::CacheOnly,
52 ir::DependenceClass::Mixed => Self::Mixed,
53 }
54 }
55}
/// Public summary of an IR normalization plan, built from `ir::NormalizationPlanExplain`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationPlanExplain {
    /// Dependence class of the expression root, as reported by the IR plan.
    pub root_dependence: ExpressionDependence,
    /// Warnings copied verbatim from the IR plan.
    pub warnings: Vec<String>,
    /// Node index of each separable multiplication candidate in the IR plan.
    pub separable_mul_candidate_nodes: Vec<usize>,
    /// Node indices the IR plan lists as cached separable nodes.
    pub cached_separable_nodes: Vec<usize>,
    /// Residual term indices copied from the IR plan.
    pub residual_terms: Vec<usize>,
}
70#[derive(Clone, Debug, PartialEq, Eq)]
71pub struct NormalizationExecutionSetsExplain {
73 pub cached_parameter_amplitudes: Vec<usize>,
75 pub cached_cache_amplitudes: Vec<usize>,
77 pub residual_amplitudes: Vec<usize>,
79}
/// One precomputed cached-integral term.
///
/// NOTE(review): the field descriptions below are inferred from the field names. Confirm
/// them against the code that builds these values.
#[derive(Clone, Debug, PartialEq)]
pub struct PrecomputedCachedIntegral {
    /// Index of the multiplication node this term belongs to.
    pub mul_node_index: usize,
    /// Index of the parameter-dependent factor of that product.
    pub parameter_node_index: usize,
    /// Index of the cache-dependent factor of that product.
    pub cache_node_index: usize,
    /// Integer coefficient applied to the term (presumably a multiplicity or sign).
    pub coefficient: i32,
    /// Weighted sum of the cache-dependent factor (presumably over events).
    pub weighted_cache_sum: Complex64,
}
/// Gradient counterpart of `PrecomputedCachedIntegral`.
///
/// NOTE(review): the field descriptions below are inferred from the field names. Confirm
/// them against the code that builds these values.
#[derive(Clone, Debug, PartialEq)]
pub struct PrecomputedCachedIntegralGradientTerm {
    /// Index of the multiplication node this term belongs to.
    pub mul_node_index: usize,
    /// Index of the parameter-dependent factor of that product.
    pub parameter_node_index: usize,
    /// Index of the cache-dependent factor of that product.
    pub cache_node_index: usize,
    /// Integer coefficient applied to the term (presumably a multiplicity or sign).
    pub coefficient: i32,
    /// Weighted gradient vector, with one complex component per gradient direction.
    pub weighted_gradient: DVector<Complex64>,
}
/// Key describing the inputs that a cached-integral computation was made for.
///
/// NOTE(review): the field names suggest what the key records:
/// - the active-amplitude mask;
/// - the local event and weight counts;
/// - the weighted sum stored as raw `u64` bits (presumably `f64::to_bits`, so the key can
///   derive `Eq`/`Hash`);
/// - the address of the weights buffer.
///
/// Confirm this against the site where the key is built.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CachedIntegralCacheKey {
    active_mask: Vec<bool>,
    n_events_local: usize,
    weights_local_len: usize,
    weighted_sum_bits: u64,
    weights_ptr: usize,
}
/// Cached-integral results stored together with the key and IR they were computed from.
#[derive(Clone, Debug)]
struct CachedIntegralCacheState {
    /// Inputs these results are valid for.
    key: CachedIntegralCacheKey,
    /// IR the results were derived from.
    expression_ir: ir::ExpressionIR,
    /// Precomputed integral terms.
    values: Vec<PrecomputedCachedIntegral>,
    /// Amplitude execution sets associated with these results.
    execution_sets: ir::NormalizationExecutionSets,
}
/// Lowered runtime artifacts kept alongside a cached specialization.
///
/// NOTE(review): the relationships described below are inferred from the field names.
#[derive(Clone, Debug)]
struct LoweredArtifactCacheState {
    /// Node indices of the parameter-dependent factors.
    parameter_node_indices: Vec<usize>,
    /// Node indices of the associated multiplication nodes.
    mul_node_indices: Vec<usize>,
    /// Lowered runtime per parameter factor; `None` where no lowered factor is available.
    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
    /// Lowered runtime for the residual part of the expression, if there is one.
    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
    /// Lowered runtime for the whole expression.
    lowered_runtime: lowered::LoweredExpressionRuntime,
}
/// Bundle of cached-integral state and lowered artifacts for one specialization.
///
/// Both parts are held in `Arc`s, so cloning this struct shares them instead of copying.
#[derive(Clone)]
struct ExpressionSpecializationState {
    cached_integrals: Arc<CachedIntegralCacheState>,
    lowered_artifacts: Arc<LoweredArtifactCacheState>,
}
/// Hit and miss counters for the expression specialization cache.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionSpecializationMetrics {
    /// Number of specialization cache hits.
    pub cache_hits: usize,
    /// Number of specialization cache misses.
    pub cache_misses: usize,
}
/// Timing and cache counters for expression compilation.
///
/// The `*_nanos` fields are durations in nanoseconds (per their names). The `initial_*`
/// fields cover the first compile; the `specialization_*` fields cover later
/// specializations.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionCompileMetrics {
    /// Time spent compiling the initial IR.
    pub initial_ir_compile_nanos: u64,
    /// Time spent computing the initial cached integrals.
    pub initial_cached_integrals_nanos: u64,
    /// Time spent on the initial lowering.
    pub initial_lowering_nanos: u64,
    /// Specialization cache hits.
    pub specialization_cache_hits: usize,
    /// Specialization cache misses.
    pub specialization_cache_misses: usize,
    /// Time spent compiling IR for specializations.
    pub specialization_ir_compile_nanos: u64,
    /// Time spent computing cached integrals for specializations.
    pub specialization_cached_integrals_nanos: u64,
    /// Time spent lowering specializations.
    pub specialization_lowering_nanos: u64,
    /// Lowering-artifact cache hits during specialization.
    pub specialization_lowering_cache_hits: usize,
    /// Lowering-artifact cache misses during specialization.
    pub specialization_lowering_cache_misses: usize,
    /// Time spent restoring state from the specialization cache.
    pub specialization_cache_restore_nanos: u64,
}
170impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
171 fn from(value: ir::NormalizationPlanExplain) -> Self {
172 Self {
173 root_dependence: value.root_dependence.into(),
174 warnings: value.warnings,
175 separable_mul_candidate_nodes: value
176 .separable_mul_candidates
177 .into_iter()
178 .map(|candidate| candidate.node_index)
179 .collect(),
180 cached_separable_nodes: value.cached_separable_nodes,
181 residual_terms: value.residual_terms,
182 }
183 }
184}
185impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
186 fn from(value: ir::NormalizationExecutionSets) -> Self {
187 Self {
188 cached_parameter_amplitudes: value.cached_parameter_amplitudes,
189 cached_cache_amplitudes: value.cached_cache_amplitudes,
190 residual_amplitudes: value.residual_amplitudes,
191 }
192 }
193}
194impl From<ExpressionDependence> for ir::DependenceClass {
195 fn from(value: ExpressionDependence) -> Self {
196 match value {
197 ExpressionDependence::ParameterOnly => Self::ParameterOnly,
198 ExpressionDependence::CacheOnly => Self::CacheOnly,
199 ExpressionDependence::Mixed => Self::Mixed,
200 }
201 }
202}
203
204#[cfg(feature = "mpi")]
205use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
206
207#[cfg(feature = "mpi")]
208use crate::mpi::LadduMPI;
209#[cfg(feature = "execution-context-prototype")]
210use crate::ExecutionContext;
211#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
212use crate::ThreadPolicy;
213use crate::{
214 amplitude::{Amplitude, AmplitudeUseSite},
215 data::Dataset,
216 parameters::ParameterMap,
217 resources::{Cache, Parameters, Resources},
218 LadduError, LadduResult,
219};
220
/// An algebraic expression over registered amplitudes.
///
/// `registry` owns the amplitudes, their use sites and the shared resources. `tree`
/// refers to use sites by index through `ExpressionNode::Amp`, and binary operations
/// remap those indices whenever two registries are merged.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct Expression {
    registry: ExpressionRegistry,
    tree: ExpressionNode,
}
228
/// The amplitudes an expression refers to, together with their use sites and the
/// `Resources` they were registered against.
///
/// The fields form two parallel groups:
/// - `amplitudes` and `amplitude_ids` hold one entry per registered amplitude.
/// - `amplitude_use_sites`, `amplitude_use_site_ids` and `amplitude_names` hold one entry
///   per use site.
///
/// Each use site points at its amplitude through `AmplitudeUseSite::amplitude_index`.
#[derive(Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub struct ExpressionRegistry {
    /// Registered amplitude implementations.
    amplitudes: Vec<Box<dyn Amplitude>>,
    /// Display label of each use site's tags.
    amplitude_names: Vec<String>,
    /// Compute identity of each amplitude; used to deduplicate amplitudes when merging.
    amplitude_ids: Vec<ComputeAmplitudeId>,
    /// Use sites; `ExpressionNode::Amp(i)` refers to entry `i`.
    amplitude_use_sites: Vec<AmplitudeUseSite>,
    /// Identity of each use site; used to deduplicate use sites when merging.
    amplitude_use_site_ids: Vec<AmplitudeUseSiteId>,
    /// Resources the amplitudes were registered against.
    resources: Resources,
}
240
241impl ExpressionRegistry {
242 fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
243 let mut resources = Resources::default();
244 let aid = amplitude.register(&mut resources)?;
245 let compute_id = next_compute_amplitude_id();
246 let use_site_id = next_amplitude_use_site_id();
247 resources.configure_amplitude_tags(std::slice::from_ref(&aid.0));
248 Ok(Self {
249 amplitudes: vec![amplitude],
250 amplitude_names: vec![aid.0.display_label()],
251 amplitude_ids: vec![compute_id],
252 amplitude_use_sites: vec![AmplitudeUseSite {
253 amplitude_index: 0,
254 tags: aid.0,
255 }],
256 amplitude_use_site_ids: vec![use_site_id],
257 resources,
258 })
259 }
260
261 fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
262 let mut resources = Resources::default();
263 let mut amplitudes = Vec::new();
264 let mut amplitude_ids = Vec::new();
265 let mut compute_id_to_index = HashMap::new();
266 let mut semantic_key_to_index = HashMap::new();
267
268 let mut left_compute_map = Vec::with_capacity(self.amplitudes.len());
269 for (amp_index, (amp, amp_id)) in
270 self.amplitudes.iter().zip(&self.amplitude_ids).enumerate()
271 {
272 let semantic_key = amp.semantic_key();
273 let mut cloned_amp = dyn_clone::clone_box(&**amp);
274 let aid = cloned_amp.register(&mut resources)?;
275 if let Some(key) = semantic_key.clone() {
276 semantic_key_to_index.insert(key, aid.1);
277 }
278 amplitudes.push(cloned_amp);
279 amplitude_ids.push(*amp_id);
280 compute_id_to_index.insert(*amp_id, aid.1);
281 left_compute_map.push(aid.1);
282 debug_assert_eq!(amp_index, aid.1);
283 }
284
285 let mut right_compute_map = Vec::with_capacity(other.amplitudes.len());
286 for (amp, amp_id) in other.amplitudes.iter().zip(&other.amplitude_ids) {
287 if let Some(existing) = compute_id_to_index.get(amp_id) {
288 right_compute_map.push(*existing);
289 continue;
290 }
291 let incoming_semantic_key = amp.semantic_key();
292 if let Some(existing) = incoming_semantic_key
293 .as_ref()
294 .and_then(|key| semantic_key_to_index.get(key))
295 {
296 right_compute_map.push(*existing);
297 continue;
298 }
299 let mut cloned_amp = dyn_clone::clone_box(&**amp);
300 let aid = cloned_amp.register(&mut resources)?;
301 if let Some(key) = incoming_semantic_key.clone() {
302 semantic_key_to_index.insert(key, aid.1);
303 }
304 amplitudes.push(cloned_amp);
305 amplitude_ids.push(*amp_id);
306 compute_id_to_index.insert(*amp_id, aid.1);
307 right_compute_map.push(aid.1);
308 }
309
310 let mut amplitude_use_sites = Vec::new();
311 let mut amplitude_use_site_ids = Vec::new();
312 let mut amplitude_names = Vec::new();
313 let mut use_site_id_to_index = HashMap::new();
314
315 let mut left_map = Vec::with_capacity(self.amplitude_use_sites.len());
316 for (use_site, use_site_id) in self
317 .amplitude_use_sites
318 .iter()
319 .zip(&self.amplitude_use_site_ids)
320 {
321 let mapped_index = left_compute_map[use_site.amplitude_index];
322 let new_index = amplitude_use_sites.len();
323 left_map.push(new_index);
324 use_site_id_to_index.insert(*use_site_id, new_index);
325 amplitude_use_site_ids.push(*use_site_id);
326 amplitude_names.push(use_site.tags.display_label());
327 amplitude_use_sites.push(AmplitudeUseSite {
328 amplitude_index: mapped_index,
329 tags: use_site.tags.clone(),
330 });
331 }
332
333 let mut right_map = Vec::with_capacity(other.amplitude_use_sites.len());
334 for (use_site, use_site_id) in other
335 .amplitude_use_sites
336 .iter()
337 .zip(&other.amplitude_use_site_ids)
338 {
339 if let Some(existing) = use_site_id_to_index.get(use_site_id) {
340 right_map.push(*existing);
341 continue;
342 }
343 let mapped_index = right_compute_map[use_site.amplitude_index];
344 let new_index = amplitude_use_sites.len();
345 right_map.push(new_index);
346 use_site_id_to_index.insert(*use_site_id, new_index);
347 amplitude_use_site_ids.push(*use_site_id);
348 amplitude_names.push(use_site.tags.display_label());
349 amplitude_use_sites.push(AmplitudeUseSite {
350 amplitude_index: mapped_index,
351 tags: use_site.tags.clone(),
352 });
353 }
354 let use_site_tags = amplitude_use_sites
355 .iter()
356 .map(|use_site| use_site.tags.clone())
357 .collect::<Vec<_>>();
358 resources.configure_amplitude_tags(&use_site_tags);
359
360 Ok((
361 Self {
362 amplitudes,
363 amplitude_names,
364 amplitude_ids,
365 amplitude_use_sites,
366 amplitude_use_site_ids,
367 resources,
368 },
369 left_map,
370 right_map,
371 ))
372 }
373}
374
/// A node in an expression tree.
///
/// Evaluation happens by compiling the tree into an `ExpressionProgram`. The per-variant
/// notes below describe what that program computes for each node.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default, Debug)]
pub enum ExpressionNode {
    /// The constant `0`.
    #[default]
    Zero,
    /// The constant `1`.
    One,
    /// A real constant, promoted to a complex value with zero imaginary part.
    Constant(f64),
    /// A complex constant.
    ComplexConstant(Complex64),
    /// The value of amplitude use site `i` of the owning registry. An index with no
    /// supplied value evaluates to zero.
    Amp(usize),
    /// Sum of both operands.
    Add(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Left operand minus right operand.
    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Product of both operands.
    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Left operand divided by right operand.
    Div(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Negation.
    Neg(Box<ExpressionNode>),
    /// Real part, as a complex number with zero imaginary part.
    Real(Box<ExpressionNode>),
    /// Imaginary part, placed in the real component (zero imaginary part).
    Imag(Box<ExpressionNode>),
    /// Complex conjugate.
    Conj(Box<ExpressionNode>),
    /// Squared modulus `|z|^2`, as a real-valued complex number.
    NormSqr(Box<ExpressionNode>),
    /// Complex square root (`Complex64::sqrt`).
    Sqrt(Box<ExpressionNode>),
    /// First operand raised to the complex power of the second (`powc`).
    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Integer power (`powi`).
    PowI(Box<ExpressionNode>, i32),
    /// Real power, evaluated through `powc` with a zero imaginary exponent.
    PowF(Box<ExpressionNode>, f64),
    /// Complex exponential.
    Exp(Box<ExpressionNode>),
    /// Complex sine.
    Sin(Box<ExpressionNode>),
    /// Complex cosine.
    Cos(Box<ExpressionNode>),
    /// Complex natural logarithm (`ln`).
    Log(Box<ExpressionNode>),
    /// `exp(i * z)`.
    Cis(Box<ExpressionNode>),
}
418
/// Flat, slot-based form of an `ExpressionNode` tree.
///
/// The `ops` run in order, and each one writes exactly one slot. The builder allocates a
/// node's slot only after its operands' slots, so every operand index is lower than the
/// `dst` of the op that reads it; the gradient pass relies on this ordering.
#[derive(Clone, Debug)]
struct ExpressionProgram {
    /// Operations in execution order.
    ops: Vec<ExpressionOp>,
    /// Number of slots the program uses (slot indices are `0..slot_count`).
    slot_count: usize,
    /// Slot holding the value of the tree's root.
    root_slot: usize,
}
431
/// One instruction of an `ExpressionProgram`; every variant writes `slots[dst]`.
#[derive(Clone, Debug)]
enum ExpressionOp {
    /// `slots[dst] = 0`.
    LoadZero {
        dst: usize,
    },
    /// `slots[dst] = 1`.
    LoadOne {
        dst: usize,
    },
    /// `slots[dst] = value` (real constant promoted to complex).
    LoadConstant {
        dst: usize,
        value: f64,
    },
    /// `slots[dst] = value`.
    LoadComplexConstant {
        dst: usize,
        value: Complex64,
    },
    /// `slots[dst] = amplitude_values[amp_idx]`, or zero when the index is out of range.
    LoadAmp {
        dst: usize,
        amp_idx: usize,
    },
    /// `slots[dst] = slots[left] + slots[right]`.
    Add {
        dst: usize,
        left: usize,
        right: usize,
    },
    /// `slots[dst] = slots[left] - slots[right]`.
    Sub {
        dst: usize,
        left: usize,
        right: usize,
    },
    /// `slots[dst] = slots[left] * slots[right]`.
    Mul {
        dst: usize,
        left: usize,
        right: usize,
    },
    /// `slots[dst] = slots[left] / slots[right]`.
    Div {
        dst: usize,
        left: usize,
        right: usize,
    },
    /// `slots[dst] = -slots[input]`.
    Neg {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = Re(slots[input]) + 0i`.
    Real {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = Im(slots[input]) + 0i`.
    Imag {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = conj(slots[input])`.
    Conj {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = |slots[input]|^2 + 0i`.
    NormSqr {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = slots[input].sqrt()`.
    Sqrt {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = slots[value].powc(slots[power])`.
    Pow {
        dst: usize,
        value: usize,
        power: usize,
    },
    /// `slots[dst] = slots[input].powi(power)`.
    PowI {
        dst: usize,
        input: usize,
        power: i32,
    },
    /// `slots[dst] = slots[input].powc(power + 0i)`.
    PowF {
        dst: usize,
        input: usize,
        power: f64,
    },
    /// `slots[dst] = exp(slots[input])`.
    Exp {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = sin(slots[input])`.
    Sin {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = cos(slots[input])`.
    Cos {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = ln(slots[input])`.
    Log {
        dst: usize,
        input: usize,
    },
    /// `slots[dst] = exp(i * slots[input])`.
    Cis {
        dst: usize,
        input: usize,
    },
}
532
/// Accumulates ops while walking an `ExpressionNode` tree, handing out slots sequentially.
#[derive(Default)]
struct ExpressionProgramBuilder {
    /// Ops emitted so far, in execution order.
    ops: Vec<ExpressionOp>,
    /// Next unused slot index, which equals the number of slots allocated so far.
    next_slot: usize,
}
538
539impl ExpressionProgramBuilder {
540 fn alloc_slot(&mut self) -> usize {
541 let slot = self.next_slot;
542 self.next_slot += 1;
543 slot
544 }
545
546 fn build(self, root: usize) -> ExpressionProgram {
547 ExpressionProgram {
548 ops: self.ops,
549 slot_count: self.next_slot,
550 root_slot: root,
551 }
552 }
553
554 fn emit(&mut self, op: ExpressionOp) {
555 self.ops.push(op);
556 }
557
558 fn compile(&mut self, node: &ExpressionNode) -> usize {
559 match node {
560 ExpressionNode::Zero => {
561 let dst = self.alloc_slot();
562 self.emit(ExpressionOp::LoadZero { dst });
563 dst
564 }
565 ExpressionNode::One => {
566 let dst = self.alloc_slot();
567 self.emit(ExpressionOp::LoadOne { dst });
568 dst
569 }
570 ExpressionNode::Constant(value) => {
571 let dst = self.alloc_slot();
572 self.emit(ExpressionOp::LoadConstant { dst, value: *value });
573 dst
574 }
575 ExpressionNode::ComplexConstant(value) => {
576 let dst = self.alloc_slot();
577 self.emit(ExpressionOp::LoadComplexConstant { dst, value: *value });
578 dst
579 }
580 ExpressionNode::Amp(idx) => {
581 let dst = self.alloc_slot();
582 self.emit(ExpressionOp::LoadAmp { dst, amp_idx: *idx });
583 dst
584 }
585 ExpressionNode::Add(a, b) => {
586 let left = self.compile(a);
587 let right = self.compile(b);
588 let dst = self.alloc_slot();
589 self.emit(ExpressionOp::Add { dst, left, right });
590 dst
591 }
592 ExpressionNode::Sub(a, b) => {
593 let left = self.compile(a);
594 let right = self.compile(b);
595 let dst = self.alloc_slot();
596 self.emit(ExpressionOp::Sub { dst, left, right });
597 dst
598 }
599 ExpressionNode::Mul(a, b) => {
600 let left = self.compile(a);
601 let right = self.compile(b);
602 let dst = self.alloc_slot();
603 self.emit(ExpressionOp::Mul { dst, left, right });
604 dst
605 }
606 ExpressionNode::Div(a, b) => {
607 let left = self.compile(a);
608 let right = self.compile(b);
609 let dst = self.alloc_slot();
610 self.emit(ExpressionOp::Div { dst, left, right });
611 dst
612 }
613 ExpressionNode::Neg(a) => {
614 let input = self.compile(a);
615 let dst = self.alloc_slot();
616 self.emit(ExpressionOp::Neg { dst, input });
617 dst
618 }
619 ExpressionNode::Real(a) => {
620 let input = self.compile(a);
621 let dst = self.alloc_slot();
622 self.emit(ExpressionOp::Real { dst, input });
623 dst
624 }
625 ExpressionNode::Imag(a) => {
626 let input = self.compile(a);
627 let dst = self.alloc_slot();
628 self.emit(ExpressionOp::Imag { dst, input });
629 dst
630 }
631 ExpressionNode::Conj(a) => {
632 let input = self.compile(a);
633 let dst = self.alloc_slot();
634 self.emit(ExpressionOp::Conj { dst, input });
635 dst
636 }
637 ExpressionNode::NormSqr(a) => {
638 let input = self.compile(a);
639 let dst = self.alloc_slot();
640 self.emit(ExpressionOp::NormSqr { dst, input });
641 dst
642 }
643 ExpressionNode::Sqrt(a) => {
644 let input = self.compile(a);
645 let dst = self.alloc_slot();
646 self.emit(ExpressionOp::Sqrt { dst, input });
647 dst
648 }
649 ExpressionNode::Pow(a, b) => {
650 let value = self.compile(a);
651 let power = self.compile(b);
652 let dst = self.alloc_slot();
653 self.emit(ExpressionOp::Pow { dst, value, power });
654 dst
655 }
656 ExpressionNode::PowI(a, power) => {
657 let input = self.compile(a);
658 let dst = self.alloc_slot();
659 self.emit(ExpressionOp::PowI {
660 dst,
661 input,
662 power: *power,
663 });
664 dst
665 }
666 ExpressionNode::PowF(a, power) => {
667 let input = self.compile(a);
668 let dst = self.alloc_slot();
669 self.emit(ExpressionOp::PowF {
670 dst,
671 input,
672 power: *power,
673 });
674 dst
675 }
676 ExpressionNode::Exp(a) => {
677 let input = self.compile(a);
678 let dst = self.alloc_slot();
679 self.emit(ExpressionOp::Exp { dst, input });
680 dst
681 }
682 ExpressionNode::Sin(a) => {
683 let input = self.compile(a);
684 let dst = self.alloc_slot();
685 self.emit(ExpressionOp::Sin { dst, input });
686 dst
687 }
688 ExpressionNode::Cos(a) => {
689 let input = self.compile(a);
690 let dst = self.alloc_slot();
691 self.emit(ExpressionOp::Cos { dst, input });
692 dst
693 }
694 ExpressionNode::Log(a) => {
695 let input = self.compile(a);
696 let dst = self.alloc_slot();
697 self.emit(ExpressionOp::Log { dst, input });
698 dst
699 }
700 ExpressionNode::Cis(a) => {
701 let input = self.compile(a);
702 let dst = self.alloc_slot();
703 self.emit(ExpressionOp::Cis { dst, input });
704 dst
705 }
706 }
707 }
708}
709
710impl ExpressionProgram {
711 fn from_node(node: &ExpressionNode) -> Self {
712 let mut builder = ExpressionProgramBuilder::default();
713 let root = builder.compile(node);
714 builder.build(root)
715 }
716
717 fn fill_values(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) {
718 debug_assert!(slots.len() >= self.slot_count);
719 for op in &self.ops {
720 match *op {
721 ExpressionOp::LoadZero { dst } => slots[dst] = Complex64::ZERO,
722 ExpressionOp::LoadOne { dst } => slots[dst] = Complex64::ONE,
723 ExpressionOp::LoadConstant { dst, value } => slots[dst] = Complex64::from(value),
724 ExpressionOp::LoadComplexConstant { dst, value } => slots[dst] = value,
725 ExpressionOp::LoadAmp { dst, amp_idx } => {
726 slots[dst] = amplitude_values.get(amp_idx).copied().unwrap_or_default();
727 }
728 ExpressionOp::Add { dst, left, right } => {
729 slots[dst] = slots[left] + slots[right];
730 }
731 ExpressionOp::Sub { dst, left, right } => {
732 slots[dst] = slots[left] - slots[right];
733 }
734 ExpressionOp::Mul { dst, left, right } => {
735 slots[dst] = slots[left] * slots[right];
736 }
737 ExpressionOp::Div { dst, left, right } => {
738 slots[dst] = slots[left] / slots[right];
739 }
740 ExpressionOp::Neg { dst, input } => {
741 slots[dst] = -slots[input];
742 }
743 ExpressionOp::Real { dst, input } => {
744 slots[dst] = Complex64::new(slots[input].re, 0.0);
745 }
746 ExpressionOp::Imag { dst, input } => {
747 slots[dst] = Complex64::new(slots[input].im, 0.0);
748 }
749 ExpressionOp::Conj { dst, input } => {
750 slots[dst] = slots[input].conj();
751 }
752 ExpressionOp::NormSqr { dst, input } => {
753 slots[dst] = Complex64::new(slots[input].norm_sqr(), 0.0);
754 }
755 ExpressionOp::Sqrt { dst, input } => {
756 slots[dst] = slots[input].sqrt();
757 }
758 ExpressionOp::Pow { dst, value, power } => {
759 slots[dst] = slots[value].powc(slots[power]);
760 }
761 ExpressionOp::PowI { dst, input, power } => {
762 slots[dst] = slots[input].powi(power);
763 }
764 ExpressionOp::PowF { dst, input, power } => {
765 slots[dst] = slots[input].powc(Complex64::new(power, 0.0));
766 }
767 ExpressionOp::Exp { dst, input } => {
768 slots[dst] = slots[input].exp();
769 }
770 ExpressionOp::Sin { dst, input } => {
771 slots[dst] = slots[input].sin();
772 }
773 ExpressionOp::Cos { dst, input } => {
774 slots[dst] = slots[input].cos();
775 }
776 ExpressionOp::Log { dst, input } => {
777 slots[dst] = slots[input].ln();
778 }
779 ExpressionOp::Cis { dst, input } => {
780 slots[dst] = (Complex64::new(0.0, 1.0) * slots[input]).exp();
781 }
782 }
783 }
784 }
785
786 fn evaluate_into(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) -> Complex64 {
787 if self.slot_count == 0 {
788 return Complex64::ZERO;
789 }
790 self.fill_values(amplitude_values, slots);
791 slots[self.root_slot]
792 }
793
794 pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
795 if self.slot_count == 0 {
796 return Complex64::ZERO;
797 }
798 let mut slots = vec![Complex64::ZERO; self.slot_count];
799 self.evaluate_into(amplitude_values, &mut slots)
800 }
801
802 pub fn evaluate_gradient_into(
803 &self,
804 amplitude_values: &[Complex64],
805 gradient_values: &[DVector<Complex64>],
806 value_slots: &mut [Complex64],
807 gradient_slots: &mut [DVector<Complex64>],
808 ) -> DVector<Complex64> {
809 if self.slot_count == 0 {
810 let dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
811 return DVector::zeros(dim);
812 }
813 self.fill_values(amplitude_values, value_slots);
814 self.fill_gradients(gradient_values, value_slots, gradient_slots);
815 gradient_slots[self.root_slot].clone()
816 }
817
818 pub fn evaluate_gradient(
819 &self,
820 amplitude_values: &[Complex64],
821 gradient_values: &[DVector<Complex64>],
822 ) -> DVector<Complex64> {
823 let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
824 let mut value_slots = vec![Complex64::ZERO; self.slot_count];
825 let mut gradient_slots: Vec<DVector<Complex64>> = (0..self.slot_count)
826 .map(|_| DVector::zeros(grad_dim))
827 .collect();
828 self.evaluate_gradient_into(
829 amplitude_values,
830 gradient_values,
831 &mut value_slots,
832 &mut gradient_slots,
833 )
834 }
835
836 fn fill_gradients(
837 &self,
838 amplitude_gradients: &[DVector<Complex64>],
839 values: &[Complex64],
840 gradients: &mut [DVector<Complex64>],
841 ) {
842 debug_assert!(gradients.len() >= self.slot_count);
843 debug_assert!(values.len() >= self.slot_count);
844 fn borrow_dst(
845 gradients: &mut [DVector<Complex64>],
846 dst: usize,
847 ) -> (&[DVector<Complex64>], &mut DVector<Complex64>) {
848 let (before, tail) = gradients.split_at_mut(dst);
849 let (dst_ref, _) = tail.split_first_mut().expect("dst slot should exist");
850 (before, dst_ref)
851 }
852 for op in &self.ops {
853 match *op {
854 ExpressionOp::LoadZero { dst }
855 | ExpressionOp::LoadOne { dst }
856 | ExpressionOp::LoadConstant { dst, .. }
857 | ExpressionOp::LoadComplexConstant { dst, .. } => {
858 let (_, dst_grad) = borrow_dst(gradients, dst);
859 for item in dst_grad.iter_mut() {
860 *item = Complex64::ZERO;
861 }
862 }
863 ExpressionOp::LoadAmp { dst, amp_idx } => {
864 let (_, dst_grad) = borrow_dst(gradients, dst);
865 if let Some(source) = amplitude_gradients.get(amp_idx) {
866 dst_grad.clone_from(source);
867 } else {
868 for item in dst_grad.iter_mut() {
869 *item = Complex64::ZERO;
870 }
871 }
872 }
873 ExpressionOp::Add { dst, left, right } => {
874 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
875 dst_grad.clone_from(&before_dst[left]);
876 for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
877 {
878 *dst_item += *right_item;
879 }
880 }
881 ExpressionOp::Sub { dst, left, right } => {
882 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
883 dst_grad.clone_from(&before_dst[left]);
884 for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
885 {
886 *dst_item -= *right_item;
887 }
888 }
889 ExpressionOp::Mul { dst, left, right } => {
890 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
891 let f_left = values[left];
892 let f_right = values[right];
893 dst_grad.clone_from(&before_dst[right]);
894 for item in dst_grad.iter_mut() {
895 *item *= f_left;
896 }
897 for (dst_item, left_item) in dst_grad.iter_mut().zip(before_dst[left].iter()) {
898 *dst_item += *left_item * f_right;
899 }
900 }
901 ExpressionOp::Div { dst, left, right } => {
902 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
903 let f_left = values[left];
904 let f_right = values[right];
905 let denom = f_right * f_right;
906 dst_grad.clone_from(&before_dst[left]);
907 for item in dst_grad.iter_mut() {
908 *item *= f_right;
909 }
910 for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
911 {
912 *dst_item -= *right_item * f_left;
913 }
914 for item in dst_grad.iter_mut() {
915 *item /= denom;
916 }
917 }
918 ExpressionOp::Neg { dst, input } => {
919 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
920 dst_grad.clone_from(&before_dst[input]);
921 for item in dst_grad.iter_mut() {
922 *item = -*item;
923 }
924 }
925 ExpressionOp::Real { dst, input } => {
926 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
927 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
928 {
929 *dst_item = Complex64::new(input_item.re, 0.0);
930 }
931 }
932 ExpressionOp::Imag { dst, input } => {
933 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
934 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
935 {
936 *dst_item = Complex64::new(input_item.im, 0.0);
937 }
938 }
939 ExpressionOp::Conj { dst, input } => {
940 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
941 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
942 {
943 *dst_item = input_item.conj();
944 }
945 }
946 ExpressionOp::NormSqr { dst, input } => {
947 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
948 let conj_value = values[input].conj();
949 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
950 {
951 *dst_item = Complex64::new(2.0 * (*input_item * conj_value).re, 0.0);
952 }
953 }
954 ExpressionOp::Sqrt { dst, input } => {
955 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
956 let factor = Complex64::new(0.5, 0.0) / values[input].sqrt();
957 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
958 {
959 *dst_item = *input_item * factor;
960 }
961 }
962 ExpressionOp::Pow { dst, value, power } => {
963 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
964 let base = values[value];
965 let exponent = values[power];
966 let output = values[dst];
967 for ((dst_item, value_item), power_item) in dst_grad
968 .iter_mut()
969 .zip(before_dst[value].iter())
970 .zip(before_dst[power].iter())
971 {
972 *dst_item =
973 output * (*power_item * base.ln() + exponent * *value_item / base);
974 }
975 }
976 ExpressionOp::PowI { dst, input, power } => {
977 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
978 let factor = match power {
979 0 => Complex64::ZERO,
980 1 => Complex64::ONE,
981 _ => {
982 let base = values[input];
983 let multiplier = Complex64::new(power as f64, 0.0);
984 if let Some(derivative_power) = power.checked_sub(1) {
985 multiplier * base.powi(derivative_power)
986 } else {
987 multiplier * base.powi(power) / base
988 }
989 }
990 };
991 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
992 {
993 *dst_item = *input_item * factor;
994 }
995 }
996 ExpressionOp::PowF { dst, input, power } => {
997 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
998 let factor = if power == 0.0 {
999 Complex64::ZERO
1000 } else {
1001 Complex64::new(power, 0.0)
1002 * values[input].powc(Complex64::new(power - 1.0, 0.0))
1003 };
1004 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1005 {
1006 *dst_item = *input_item * factor;
1007 }
1008 }
1009 ExpressionOp::Exp { dst, input } => {
1010 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1011 let output = values[dst];
1012 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1013 {
1014 *dst_item = *input_item * output;
1015 }
1016 }
1017 ExpressionOp::Sin { dst, input } => {
1018 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1019 let factor = values[input].cos();
1020 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1021 {
1022 *dst_item = *input_item * factor;
1023 }
1024 }
1025 ExpressionOp::Cos { dst, input } => {
1026 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1027 let factor = -values[input].sin();
1028 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1029 {
1030 *dst_item = *input_item * factor;
1031 }
1032 }
1033 ExpressionOp::Log { dst, input } => {
1034 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1035 let factor = Complex64::ONE / values[input];
1036 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1037 {
1038 *dst_item = *input_item * factor;
1039 }
1040 }
1041 ExpressionOp::Cis { dst, input } => {
1042 let (before_dst, dst_grad) = borrow_dst(gradients, dst);
1043 let factor = Complex64::new(0.0, 1.0) * values[dst];
1044 for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
1045 {
1046 *dst_item = *input_item * factor;
1047 }
1048 }
1049 }
1050 }
1051 }
1052}
1053
1054impl ExpressionNode {
1055 fn remap(&self, mapping: &[usize]) -> Self {
1056 match self {
1057 Self::Amp(idx) => Self::Amp(mapping[*idx]),
1058 Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1059 Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1060 Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1061 Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1062 Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
1063 Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
1064 Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
1065 Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
1066 Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
1067 Self::Zero => Self::Zero,
1068 Self::One => Self::One,
1069 Self::Constant(v) => Self::Constant(*v),
1070 Self::ComplexConstant(v) => Self::ComplexConstant(*v),
1071 Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
1072 Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
1073 Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
1074 Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
1075 Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
1076 Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
1077 Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
1078 Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
1079 Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
1080 }
1081 }
1082
1083 fn program(&self) -> ExpressionProgram {
1084 ExpressionProgram::from_node(self)
1085 }
1086
1087 pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
1089 self.program().evaluate(amplitude_values)
1090 }
1091
1092 pub fn evaluate_gradient(
1094 &self,
1095 amplitude_values: &[Complex64],
1096 gradient_values: &[DVector<Complex64>],
1097 ) -> DVector<Complex64> {
1098 self.program()
1099 .evaluate_gradient(amplitude_values, gradient_values)
1100 }
1101}
1102
1103impl From<f64> for Expression {
1104 fn from(value: f64) -> Self {
1105 if value == 0.0 {
1106 Self {
1107 registry: ExpressionRegistry::default(),
1108 tree: ExpressionNode::Zero,
1109 }
1110 } else if value == 1.0 {
1111 Self {
1112 registry: ExpressionRegistry::default(),
1113 tree: ExpressionNode::One,
1114 }
1115 } else {
1116 Self {
1117 registry: ExpressionRegistry::default(),
1118 tree: ExpressionNode::Constant(value),
1119 }
1120 }
1121 }
1122}
1123impl From<&f64> for Expression {
1124 fn from(value: &f64) -> Self {
1125 (*value).into()
1126 }
1127}
1128impl From<Complex64> for Expression {
1129 fn from(value: Complex64) -> Self {
1130 if value == Complex64::ZERO {
1131 Self {
1132 registry: ExpressionRegistry::default(),
1133 tree: ExpressionNode::Zero,
1134 }
1135 } else if value == Complex64::ONE {
1136 Self {
1137 registry: ExpressionRegistry::default(),
1138 tree: ExpressionNode::One,
1139 }
1140 } else {
1141 Self {
1142 registry: ExpressionRegistry::default(),
1143 tree: ExpressionNode::ComplexConstant(value),
1144 }
1145 }
1146 }
1147}
1148impl From<&Complex64> for Expression {
1149 fn from(value: &Complex64) -> Self {
1150 (*value).into()
1151 }
1152}
1153
1154impl Expression {
1155 pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
1157 let registry = ExpressionRegistry::singleton(amplitude)?;
1158 Ok(Self {
1159 tree: ExpressionNode::Amp(0),
1160 registry,
1161 })
1162 }
1163
1164 pub fn zero() -> Self {
1166 Self {
1167 registry: ExpressionRegistry::default(),
1168 tree: ExpressionNode::Zero,
1169 }
1170 }
1171
1172 pub fn one() -> Self {
1174 Self {
1175 registry: ExpressionRegistry::default(),
1176 tree: ExpressionNode::One,
1177 }
1178 }
1179
1180 fn binary_op(
1181 a: &Expression,
1182 b: &Expression,
1183 build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
1184 ) -> Expression {
1185 let (registry, left_map, right_map) = a
1186 .registry
1187 .merge(&b.registry)
1188 .expect("merging expression registries should not fail");
1189 let left_tree = a.tree.remap(&left_map);
1190 let right_tree = b.tree.remap(&right_map);
1191 Expression {
1192 registry,
1193 tree: build(Box::new(left_tree), Box::new(right_tree)),
1194 }
1195 }
1196
1197 fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
1198 Expression {
1199 registry: a.registry.clone(),
1200 tree: build(Box::new(a.tree.clone())),
1201 }
1202 }
1203
1204 pub fn parameters(&self) -> ParameterMap {
1206 self.registry.resources.parameters()
1207 }
1208
1209 pub fn n_free(&self) -> usize {
1211 self.registry.resources.n_free_parameters()
1212 }
1213
1214 pub fn n_fixed(&self) -> usize {
1216 self.registry.resources.n_fixed_parameters()
1217 }
1218
1219 pub fn n_parameters(&self) -> usize {
1221 self.registry.resources.n_parameters()
1222 }
1223
1224 pub fn compiled_expression(&self) -> CompiledExpression {
1230 let active_amplitudes = vec![true; self.registry.amplitude_use_sites.len()];
1231 let amplitude_dependencies = self
1232 .registry
1233 .amplitude_use_sites
1234 .iter()
1235 .map(|use_site| {
1236 ir::DependenceClass::from(
1237 self.registry.amplitudes[use_site.amplitude_index].dependence_hint(),
1238 )
1239 })
1240 .collect::<Vec<_>>();
1241 let amplitude_realness = self
1242 .registry
1243 .amplitude_use_sites
1244 .iter()
1245 .map(|use_site| self.registry.amplitudes[use_site.amplitude_index].real_valued_hint())
1246 .collect::<Vec<_>>();
1247 let expression_ir = ir::compile_expression_ir_with_real_hints(
1248 &self.tree,
1249 &active_amplitudes,
1250 &litude_dependencies,
1251 &litude_realness,
1252 );
1253 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
1254 }
1255
1256 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
1258 self.registry.resources.fix_parameter(name, value)
1259 }
1260
1261 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
1263 self.registry.resources.free_parameter(name)
1264 }
1265
1266 pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
1268 self.registry.resources.rename_parameter(old, new)
1269 }
1270
1271 pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
1273 self.registry.resources.rename_parameters(mapping)
1274 }
1275
1276 pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
1278 let mut resources = self.registry.resources.clone();
1279 let metadata = dataset.metadata();
1280 resources.reserve_cache(dataset.n_events_local());
1281 resources.refresh_active_indices();
1282 let parameter_map = resources.parameter_map.clone();
1283 let mut amplitudes: Vec<Box<dyn Amplitude>> = self
1284 .registry
1285 .amplitudes
1286 .iter()
1287 .map(|amp| dyn_clone::clone_box(&**amp))
1288 .collect();
1289 {
1290 for amplitude in amplitudes.iter_mut() {
1291 amplitude.bind(metadata)?;
1292 amplitude.precompute_all(dataset, &mut resources);
1293 }
1294 }
1295 let ir_compile_start = Instant::now();
1296 let expression_ir = {
1297 let mut active_amplitudes = vec![false; self.registry.amplitude_use_sites.len()];
1298 for &index in resources.active_indices() {
1299 active_amplitudes[index] = true;
1300 }
1301 let amplitude_dependencies = self
1302 .registry
1303 .amplitude_use_sites
1304 .iter()
1305 .map(|use_site| {
1306 ir::DependenceClass::from(
1307 amplitudes[use_site.amplitude_index].dependence_hint(),
1308 )
1309 })
1310 .collect::<Vec<_>>();
1311 let amplitude_realness = self
1312 .registry
1313 .amplitude_use_sites
1314 .iter()
1315 .map(|use_site| amplitudes[use_site.amplitude_index].real_valued_hint())
1316 .collect::<Vec<_>>();
1317 ir::compile_expression_ir_with_real_hints(
1318 &self.tree,
1319 &active_amplitudes,
1320 &litude_dependencies,
1321 &litude_realness,
1322 )
1323 };
1324 let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
1325 let cached_integrals_start = Instant::now();
1326 let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
1327 &expression_ir,
1328 &litudes,
1329 &self.registry.amplitude_use_sites,
1330 &resources,
1331 dataset,
1332 parameter_map.free().len(),
1333 )?;
1334 let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
1335 let lowering_start = Instant::now();
1336 let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
1337 &expression_ir,
1338 &cached_integrals,
1339 )?);
1340 let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
1341 let execution_sets = expression_ir.normalization_execution_sets().clone();
1342 let cached_integral_key =
1343 Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
1344 let cached_integral_state = Arc::new(CachedIntegralCacheState {
1345 key: cached_integral_key.clone(),
1346 expression_ir,
1347 values: cached_integrals,
1348 execution_sets,
1349 });
1350 let specialization_state = ExpressionSpecializationState {
1351 cached_integrals: cached_integral_state.clone(),
1352 lowered_artifacts: lowered_artifacts.clone(),
1353 };
1354 let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
1355 let lowered_artifact_cache =
1356 HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
1357 Ok(Evaluator {
1358 amplitudes,
1359 amplitude_use_sites: self.registry.amplitude_use_sites.clone(),
1360 resources: Arc::new(RwLock::new(resources)),
1361 dataset: dataset.clone(),
1362 expression: self.tree.clone(),
1363 ir_planning: ExpressionIrPlanningState {
1364 cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
1365 specialization_cache: Arc::new(RwLock::new(specialization_cache)),
1366 specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
1367 cache_hits: 0,
1368 cache_misses: 1,
1369 })),
1370 lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
1371 active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
1372 specialization_status: Arc::new(RwLock::new(Some(
1373 ExpressionSpecializationStatus {
1374 origin: ExpressionSpecializationOrigin::InitialLoad,
1375 },
1376 ))),
1377 compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
1378 initial_ir_compile_nanos,
1379 initial_cached_integrals_nanos,
1380 initial_lowering_nanos,
1381 specialization_lowering_cache_misses: 1,
1382 ..Default::default()
1383 })),
1384 },
1385 registry: self.registry.clone(),
1386 })
1387 }
1388
1389 pub fn real(&self) -> Self {
1391 Self::unary_op(self, ExpressionNode::Real)
1392 }
1393 pub fn imag(&self) -> Self {
1395 Self::unary_op(self, ExpressionNode::Imag)
1396 }
1397 pub fn conj(&self) -> Self {
1399 Self::unary_op(self, ExpressionNode::Conj)
1400 }
1401 pub fn norm_sqr(&self) -> Self {
1403 Self::unary_op(self, ExpressionNode::NormSqr)
1404 }
1405 pub fn sqrt(&self) -> Self {
1407 Self::unary_op(self, ExpressionNode::Sqrt)
1408 }
1409 pub fn pow(&self, power: &Expression) -> Self {
1411 Self::binary_op(self, power, ExpressionNode::Pow)
1412 }
1413 pub fn powi(&self, power: i32) -> Self {
1415 Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
1416 }
1417 pub fn powf(&self, power: f64) -> Self {
1419 Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
1420 }
1421 pub fn exp(&self) -> Self {
1423 Self::unary_op(self, ExpressionNode::Exp)
1424 }
1425 pub fn sin(&self) -> Self {
1427 Self::unary_op(self, ExpressionNode::Sin)
1428 }
1429 pub fn cos(&self) -> Self {
1431 Self::unary_op(self, ExpressionNode::Cos)
1432 }
1433 pub fn log(&self) -> Self {
1435 Self::unary_op(self, ExpressionNode::Log)
1436 }
1437 pub fn cis(&self) -> Self {
1439 Self::unary_op(self, ExpressionNode::Cis)
1440 }
1441
1442 fn write_tree(
1444 &self,
1445 t: &ExpressionNode,
1446 f: &mut std::fmt::Formatter<'_>,
1447 parent_prefix: &str,
1448 immediate_prefix: &str,
1449 parent_suffix: &str,
1450 ) -> std::fmt::Result {
1451 let display_string = match t {
1452 ExpressionNode::Amp(idx) => {
1453 let name = self
1454 .registry
1455 .amplitude_names
1456 .get(*idx)
1457 .cloned()
1458 .unwrap_or_else(|| "<unregistered>".to_string());
1459 format!("{name}(id={idx})")
1460 }
1461 ExpressionNode::Add(_, _) => "+".to_string(),
1462 ExpressionNode::Sub(_, _) => "-".to_string(),
1463 ExpressionNode::Mul(_, _) => "×".to_string(),
1464 ExpressionNode::Div(_, _) => "÷".to_string(),
1465 ExpressionNode::Neg(_) => "-".to_string(),
1466 ExpressionNode::Real(_) => "Re".to_string(),
1467 ExpressionNode::Imag(_) => "Im".to_string(),
1468 ExpressionNode::Conj(_) => "*".to_string(),
1469 ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
1470 ExpressionNode::Zero => "0 (exact)".to_string(),
1471 ExpressionNode::One => "1 (exact)".to_string(),
1472 ExpressionNode::Constant(v) => v.to_string(),
1473 ExpressionNode::ComplexConstant(v) => v.to_string(),
1474 ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
1475 ExpressionNode::Pow(_, _) => "Pow".to_string(),
1476 ExpressionNode::PowI(_, power) => format!("PowI({power})"),
1477 ExpressionNode::PowF(_, power) => format!("PowF({power})"),
1478 ExpressionNode::Exp(_) => "Exp".to_string(),
1479 ExpressionNode::Sin(_) => "Sin".to_string(),
1480 ExpressionNode::Cos(_) => "Cos".to_string(),
1481 ExpressionNode::Log(_) => "Log".to_string(),
1482 ExpressionNode::Cis(_) => "Cis".to_string(),
1483 };
1484 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
1485 match t {
1486 ExpressionNode::Amp(_)
1487 | ExpressionNode::Zero
1488 | ExpressionNode::One
1489 | ExpressionNode::Constant(_)
1490 | ExpressionNode::ComplexConstant(_) => {}
1491 ExpressionNode::Add(a, b)
1492 | ExpressionNode::Sub(a, b)
1493 | ExpressionNode::Mul(a, b)
1494 | ExpressionNode::Div(a, b)
1495 | ExpressionNode::Pow(a, b) => {
1496 let terms = [a, b];
1497 let mut it = terms.iter().peekable();
1498 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1499 while let Some(child) = it.next() {
1500 match it.peek() {
1501 Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│ "),
1502 None => self.write_tree(child, f, &child_prefix, "└─ ", " "),
1503 }?;
1504 }
1505 }
1506 ExpressionNode::Neg(a)
1507 | ExpressionNode::Real(a)
1508 | ExpressionNode::Imag(a)
1509 | ExpressionNode::Conj(a)
1510 | ExpressionNode::NormSqr(a)
1511 | ExpressionNode::Sqrt(a)
1512 | ExpressionNode::PowI(a, _)
1513 | ExpressionNode::PowF(a, _)
1514 | ExpressionNode::Exp(a)
1515 | ExpressionNode::Sin(a)
1516 | ExpressionNode::Cos(a)
1517 | ExpressionNode::Log(a)
1518 | ExpressionNode::Cis(a) => {
1519 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1520 self.write_tree(a, f, &child_prefix, "└─ ", " ")?;
1521 }
1522 }
1523 Ok(())
1524 }
1525}
1526
1527impl Debug for Expression {
1528 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1529 self.write_tree(&self.tree, f, "", "", "")
1530 }
1531}
1532
1533impl Display for Expression {
1534 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1535 self.write_tree(&self.tree, f, "", "", "")
1536 }
1537}
1538
1539#[rustfmt::skip]
1540impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
1541 Expression::binary_op(a, b, ExpressionNode::Add)
1542});
1543#[rustfmt::skip]
1544impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
1545 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
1546});
1547#[rustfmt::skip]
1548impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
1549 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
1550});
1551#[rustfmt::skip]
1552impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
1553 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
1554});
1555#[rustfmt::skip]
1556impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
1557 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
1558});
1559
1560#[rustfmt::skip]
1561impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
1562 Expression::binary_op(a, b, ExpressionNode::Sub)
1563});
1564#[rustfmt::skip]
1565impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
1566 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
1567});
1568#[rustfmt::skip]
1569impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
1570 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
1571});
1572#[rustfmt::skip]
1573impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
1574 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
1575});
1576#[rustfmt::skip]
1577impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
1578 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
1579});
1580
1581#[rustfmt::skip]
1582impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
1583 Expression::binary_op(a, b, ExpressionNode::Mul)
1584});
1585#[rustfmt::skip]
1586impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
1587 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
1588});
1589#[rustfmt::skip]
1590impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
1591 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
1592});
1593#[rustfmt::skip]
1594impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
1595 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
1596});
1597#[rustfmt::skip]
1598impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
1599 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
1600});
1601
1602#[rustfmt::skip]
1603impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
1604 Expression::binary_op(a, b, ExpressionNode::Div)
1605});
1606#[rustfmt::skip]
1607impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
1608 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
1609});
1610#[rustfmt::skip]
1611impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
1612 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
1613});
1614#[rustfmt::skip]
1615impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
1616 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
1617});
1618#[rustfmt::skip]
1619impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
1620 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
1621});
1622
1623#[rustfmt::skip]
1624impl_op_ex!(- |a: &Expression| -> Expression {
1625 Expression::unary_op(a, ExpressionNode::Neg)
1626});
1627#[derive(Clone, Debug)]
1630#[doc(hidden)]
1631pub struct ExpressionValueProgramSnapshot {
1632 lowered_program: lowered::LoweredProgram,
1633}
1634
1635#[derive(Clone, Debug, PartialEq)]
1636pub enum CompiledExpressionNode {
1638 Constant(Complex64),
1640 Amplitude {
1642 index: usize,
1644 name: String,
1646 },
1647 Unary {
1649 op: String,
1651 input: usize,
1653 },
1654 Binary {
1656 op: String,
1658 left: usize,
1660 right: usize,
1662 },
1663}
1664
1665#[derive(Clone, Debug, PartialEq)]
1666pub struct CompiledExpression {
1672 nodes: Vec<CompiledExpressionNode>,
1673 root: usize,
1674}
1675
1676impl CompiledExpression {
1677 fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
1678 let nodes = ir
1679 .nodes()
1680 .iter()
1681 .map(|node| match node {
1682 ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
1683 ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
1684 index: *index,
1685 name: amplitude_names
1686 .get(*index)
1687 .cloned()
1688 .unwrap_or_else(|| "<unregistered>".to_string()),
1689 },
1690 ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
1691 op: compiled_unary_op_label(*op),
1692 input: *input,
1693 },
1694 ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
1695 op: compiled_binary_op_label(*op),
1696 left: *left,
1697 right: *right,
1698 },
1699 })
1700 .collect();
1701 Self {
1702 nodes,
1703 root: ir.root(),
1704 }
1705 }
1706
1707 pub fn nodes(&self) -> &[CompiledExpressionNode] {
1709 &self.nodes
1710 }
1711
1712 pub fn root(&self) -> usize {
1714 self.root
1715 }
1716
1717 fn node_label(&self, index: usize) -> String {
1718 let Some(node) = self.nodes.get(index) else {
1719 return format!("#{index} <missing>");
1720 };
1721 let label = match node {
1722 CompiledExpressionNode::Constant(value) => format!("const {value}"),
1723 CompiledExpressionNode::Amplitude { index, name } => {
1724 format!("{name}(id={index})")
1725 }
1726 CompiledExpressionNode::Unary { op, .. }
1727 | CompiledExpressionNode::Binary { op, .. } => op.clone(),
1728 };
1729 format!("#{index} {label}")
1730 }
1731
1732 fn write_tree(
1734 &self,
1735 index: usize,
1736 f: &mut std::fmt::Formatter<'_>,
1737 parent_prefix: &str,
1738 immediate_prefix: &str,
1739 parent_suffix: &str,
1740 expanded: &mut [bool],
1741 ) -> std::fmt::Result {
1742 let already_expanded = expanded.get(index).copied().unwrap_or(false);
1743 if let Some(slot) = expanded.get_mut(index) {
1744 *slot = true;
1745 }
1746 let ref_suffix = if already_expanded { " (ref)" } else { "" };
1747 writeln!(
1748 f,
1749 "{}{}{}{}",
1750 parent_prefix,
1751 immediate_prefix,
1752 self.node_label(index),
1753 ref_suffix
1754 )?;
1755 if already_expanded {
1756 return Ok(());
1757 }
1758 let Some(node) = self.nodes.get(index) else {
1759 return Ok(());
1760 };
1761 match node {
1762 CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
1763 CompiledExpressionNode::Unary { input, .. } => {
1764 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1765 self.write_tree(*input, f, &child_prefix, "└─ ", " ", expanded)?;
1766 }
1767 CompiledExpressionNode::Binary { left, right, .. } => {
1768 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
1769 self.write_tree(*left, f, &child_prefix, "├─ ", "│ ", expanded)?;
1770 self.write_tree(*right, f, &child_prefix, "└─ ", " ", expanded)?;
1771 }
1772 }
1773 Ok(())
1774 }
1775}
1776
1777impl Display for CompiledExpression {
1778 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1779 if self.nodes.is_empty() {
1780 return writeln!(f, "<empty>");
1781 }
1782 let mut expanded = vec![false; self.nodes.len()];
1783 self.write_tree(self.root, f, "", "", "", &mut expanded)
1784 }
1785}
1786
1787fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
1788 match op {
1789 ir::IrUnaryOp::Neg => "-".to_string(),
1790 ir::IrUnaryOp::Real => "Re".to_string(),
1791 ir::IrUnaryOp::Imag => "Im".to_string(),
1792 ir::IrUnaryOp::Conj => "*".to_string(),
1793 ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
1794 ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
1795 ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
1796 ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
1797 ir::IrUnaryOp::Exp => "Exp".to_string(),
1798 ir::IrUnaryOp::Sin => "Sin".to_string(),
1799 ir::IrUnaryOp::Cos => "Cos".to_string(),
1800 ir::IrUnaryOp::Log => "Log".to_string(),
1801 ir::IrUnaryOp::Cis => "Cis".to_string(),
1802 }
1803}
1804
1805fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
1806 match op {
1807 ir::IrBinaryOp::Add => "+".to_string(),
1808 ir::IrBinaryOp::Sub => "-".to_string(),
1809 ir::IrBinaryOp::Mul => "×".to_string(),
1810 ir::IrBinaryOp::Div => "÷".to_string(),
1811 ir::IrBinaryOp::Pow => "Pow".to_string(),
1812 }
1813}
/// How the currently installed expression specialization was produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpressionSpecializationOrigin {
    /// Built by `Expression::load` when the evaluator was created.
    InitialLoad,
    /// Presumably: rebuilt because no cached specialization matched (name-based; confirm
    /// against the code that installs it).
    CacheMissRebuild,
    /// Presumably: restored from the specialization cache (name-based; confirm against the code
    /// that installs it).
    CacheHitRestore,
}
1824#[derive(Clone, Copy, Debug, PartialEq, Eq)]
1825pub struct ExpressionSpecializationStatus {
1827 pub origin: ExpressionSpecializationOrigin,
1829}
1830#[derive(Clone, Debug, PartialEq, Eq)]
1831pub struct ExpressionRuntimeDiagnostics {
1833 pub ir_planning_enabled: bool,
1835 pub lowered_value_program_present: bool,
1837 pub lowered_gradient_program_present: bool,
1839 pub lowered_value_gradient_program_present: bool,
1841 pub cached_parameter_factor_count: usize,
1843 pub lowered_cached_parameter_factor_count: usize,
1845 pub residual_runtime_present: bool,
1847 pub specialization_cache_entries: usize,
1849 pub lowered_artifact_cache_entries: usize,
1851 pub specialization_status: Option<ExpressionSpecializationStatus>,
1853}
1854#[derive(Clone)]
1855struct ExpressionIrPlanningState {
1862 cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
1863 specialization_cache:
1864 Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
1865 specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
1866 lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
1867 active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
1868 specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
1869 compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
1870}
1871#[allow(missing_docs)]
1873#[derive(Clone)]
1874pub struct Evaluator {
1875 pub amplitudes: Vec<Box<dyn Amplitude>>,
1876 amplitude_use_sites: Vec<AmplitudeUseSite>,
1877 pub resources: Arc<RwLock<Resources>>,
1878 pub dataset: Arc<Dataset>,
1879 pub expression: ExpressionNode,
1880 ir_planning: ExpressionIrPlanningState,
1881 registry: ExpressionRegistry,
1882}
1883
1884#[allow(missing_docs)]
1885impl Evaluator {
1886 pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
1888 *self.ir_planning.specialization_metrics.read()
1889 }
1890 pub fn reset_expression_specialization_metrics(&self) {
1892 *self.ir_planning.specialization_metrics.write() =
1893 ExpressionSpecializationMetrics::default();
1894 }
1895 pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
1897 *self.ir_planning.compile_metrics.read()
1898 }
1899 pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
1901 let active_artifacts = self.active_lowered_artifacts();
1902 let cached_parameter_factor_count = self
1903 .ir_planning
1904 .cached_integrals
1905 .read()
1906 .as_ref()
1907 .map(|state| state.values.len())
1908 .unwrap_or(0);
1909 let lowered_cached_parameter_factor_count = active_artifacts
1910 .as_ref()
1911 .map(|artifacts| {
1912 artifacts
1913 .lowered_parameter_factors
1914 .iter()
1915 .filter(|factor| factor.is_some())
1916 .count()
1917 })
1918 .unwrap_or(0);
1919 let residual_runtime_present = active_artifacts
1920 .as_ref()
1921 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
1922 .is_some();
1923 ExpressionRuntimeDiagnostics {
1924 ir_planning_enabled: true,
1925 lowered_value_program_present: true,
1926 lowered_gradient_program_present: true,
1927 lowered_value_gradient_program_present: true,
1928 cached_parameter_factor_count,
1929 lowered_cached_parameter_factor_count,
1930 residual_runtime_present,
1931 specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
1932 lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
1933 specialization_status: *self.ir_planning.specialization_status.read(),
1934 }
1935 }
1936 pub fn reset_expression_compile_metrics(&self) {
1938 let mut metrics = self.ir_planning.compile_metrics.write();
1939 metrics.specialization_cache_hits = 0;
1940 metrics.specialization_cache_misses = 0;
1941 metrics.specialization_ir_compile_nanos = 0;
1942 metrics.specialization_cached_integrals_nanos = 0;
1943 metrics.specialization_lowering_nanos = 0;
1944 metrics.specialization_lowering_cache_hits = 0;
1945 metrics.specialization_lowering_cache_misses = 0;
1946 metrics.specialization_cache_restore_nanos = 0;
1947 }
1948 #[cfg(test)]
1949 fn expression_ir(&self) -> ir::ExpressionIR {
1950 self.ir_planning
1951 .cached_integrals
1952 .read()
1953 .as_ref()
1954 .map(|state| state.expression_ir.clone())
1955 .expect("cached integral state should exist for evaluator IR access")
1956 }
1957 fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
1958 self.active_lowered_artifacts()
1959 .expect("active lowered artifacts should exist for the current specialization")
1960 .lowered_runtime
1961 .clone()
1962 }
1963 fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
1964 self.ir_planning.active_lowered_artifacts.read().clone()
1965 }
1966 fn lowered_runtime_slot_count(&self) -> usize {
1967 let runtime = self.lowered_runtime();
1968 [
1969 runtime.value_program().scratch_slots(),
1970 runtime.gradient_program().scratch_slots(),
1971 runtime.value_gradient_program().scratch_slots(),
1972 ]
1973 .into_iter()
1974 .max()
1975 .unwrap_or(0)
1976 }
1977 fn lowered_value_runtime_slot_count(&self) -> usize {
1978 self.lowered_runtime().value_program().scratch_slots()
1979 }
1980
1981 #[doc(hidden)]
1982 pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
1983 ExpressionValueProgramSnapshot {
1984 lowered_program: self.lowered_runtime().value_program().clone(),
1985 }
1986 }
1987
1988 #[doc(hidden)]
1989 pub fn expression_value_program_snapshot_for_active_mask(
1990 &self,
1991 active_mask: &[bool],
1992 ) -> LadduResult<ExpressionValueProgramSnapshot> {
1993 let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
1994 let lowered_program =
1995 lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
1996 LadduError::Custom(format!(
1997 "Failed to lower value-only active-mask runtime: {error:?}"
1998 ))
1999 })?;
2000 Ok(ExpressionValueProgramSnapshot { lowered_program })
2001 }
2002
2003 #[doc(hidden)]
2004 pub fn expression_value_program_snapshot_slot_count(
2005 &self,
2006 snapshot: &ExpressionValueProgramSnapshot,
2007 ) -> usize {
2008 let _ = self;
2009 snapshot.lowered_program.scratch_slots()
2010 }
2011
2012 pub fn compiled_expression(&self) -> CompiledExpression {
2015 let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
2016 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
2017 }
2018
2019 pub fn expression(&self) -> Expression {
2021 Expression {
2022 tree: self.expression.clone(),
2023 registry: self.registry.clone(),
2024 }
2025 }
2026 fn lowered_gradient_runtime_slot_count(&self) -> usize {
2027 self.lowered_runtime().gradient_program().scratch_slots()
2028 }
2029 fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
2030 self.lowered_runtime()
2031 .value_gradient_program()
2032 .scratch_slots()
2033 }
2034
2035 fn expression_value_slot_count(&self) -> usize {
2036 self.lowered_value_runtime_slot_count()
2037 }
2038 fn expression_gradient_slot_count(&self) -> usize {
2039 self.lowered_gradient_runtime_slot_count()
2040 }
2041 fn expression_value_gradient_slot_count(&self) -> usize {
2042 self.lowered_value_gradient_runtime_slot_count()
2043 }
2044
2045 #[doc(hidden)]
2046 pub fn expression_value_gradient_slot_count_public(&self) -> usize {
2047 self.expression_value_gradient_slot_count()
2048 }
2049 #[cfg(test)]
2050 fn specialization_cache_len(&self) -> usize {
2051 self.ir_planning.specialization_cache.read().len()
2052 }
2053 #[cfg(test)]
2054 fn lowered_artifact_cache_len(&self) -> usize {
2055 self.ir_planning.lowered_artifact_cache.read().len()
2056 }
2057 fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
2058 debug_assert!(Self::lowered_artifact_signature_matches(
2059 &specialization.lowered_artifacts,
2060 &specialization.cached_integrals.values,
2061 ));
2062 *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
2063 *self.ir_planning.active_lowered_artifacts.write() =
2064 Some(specialization.lowered_artifacts.clone());
2065 debug_assert_eq!(
2066 self.active_lowered_artifacts()
2067 .as_ref()
2068 .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
2069 Some(true)
2070 );
2071 debug_assert_eq!(
2072 self.lowered_runtime().value_program().scratch_slots(),
2073 specialization
2074 .lowered_artifacts
2075 .lowered_runtime
2076 .value_program()
2077 .scratch_slots()
2078 );
2079 }
2080 fn lower_expression_runtime_artifacts(
2081 expression_ir: &ir::ExpressionIR,
2082 values: &[PrecomputedCachedIntegral],
2083 ) -> LadduResult<LoweredArtifactCacheState> {
2084 let parameter_node_indices = values
2085 .iter()
2086 .map(|value| value.parameter_node_index)
2087 .collect();
2088 let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
2089 let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
2090 let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
2091 let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
2092 expression_ir,
2093 )
2094 .map_err(|error| {
2095 LadduError::Custom(format!(
2096 "Failed to lower expression runtime for specialized IR: {error:?}"
2097 ))
2098 })?;
2099 Ok(LoweredArtifactCacheState {
2100 parameter_node_indices,
2101 mul_node_indices,
2102 lowered_parameter_factors,
2103 residual_runtime,
2104 lowered_runtime,
2105 })
2106 }
2107 fn lowered_artifact_signature_matches(
2108 artifacts: &LoweredArtifactCacheState,
2109 values: &[PrecomputedCachedIntegral],
2110 ) -> bool {
2111 artifacts.parameter_node_indices.len() == values.len()
2112 && artifacts.mul_node_indices.len() == values.len()
2113 && artifacts
2114 .parameter_node_indices
2115 .iter()
2116 .copied()
2117 .eq(values.iter().map(|value| value.parameter_node_index))
2118 && artifacts
2119 .mul_node_indices
2120 .iter()
2121 .copied()
2122 .eq(values.iter().map(|value| value.mul_node_index))
2123 }
2124 fn build_expression_specialization(
2125 &self,
2126 resources: &Resources,
2127 key: CachedIntegralCacheKey,
2128 ) -> LadduResult<ExpressionSpecializationState> {
2129 let ir_compile_start = Instant::now();
2130 let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
2131 let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
2132 let cached_integrals_start = Instant::now();
2133 let values = Self::precompute_cached_integrals_at_load(
2134 &expression_ir,
2135 &self.amplitudes,
2136 &self.amplitude_use_sites,
2137 resources,
2138 &self.dataset,
2139 self.resources.read().n_free_parameters(),
2140 )?;
2141 let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
2142 let execution_sets = expression_ir.normalization_execution_sets().clone();
2143 let active_mask_key = resources.active.clone();
2144 let cached_lowered_artifacts = {
2145 let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
2146 lowered_artifact_cache
2147 .get(&active_mask_key)
2148 .cloned()
2149 .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
2150 };
2151 let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
2152 self.ir_planning
2153 .compile_metrics
2154 .write()
2155 .specialization_lowering_cache_hits += 1;
2156 artifacts
2157 } else {
2158 let lowering_start = Instant::now();
2159 let artifacts = Arc::new(
2160 Self::lower_expression_runtime_artifacts(&expression_ir, &values)
2161 .expect("specialized lowered runtime should build"),
2162 );
2163 let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
2164 self.ir_planning
2165 .lowered_artifact_cache
2166 .write()
2167 .insert(active_mask_key, artifacts.clone());
2168 let mut compile_metrics = self.ir_planning.compile_metrics.write();
2169 compile_metrics.specialization_lowering_cache_misses += 1;
2170 compile_metrics.specialization_lowering_nanos += lowering_nanos;
2171 artifacts
2172 };
2173 let mut compile_metrics = self.ir_planning.compile_metrics.write();
2174 compile_metrics.specialization_cache_misses += 1;
2175 compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
2176 compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
2177 Ok(ExpressionSpecializationState {
2178 cached_integrals: Arc::new(CachedIntegralCacheState {
2179 key,
2180 expression_ir,
2181 values,
2182 execution_sets,
2183 }),
2184 lowered_artifacts,
2185 })
2186 }
2187 fn ensure_expression_specialization(
2188 &self,
2189 resources: &Resources,
2190 ) -> LadduResult<ExpressionSpecializationState> {
2191 let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
2192 if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
2193 if state.key == key {
2194 return Ok(ExpressionSpecializationState {
2195 cached_integrals: state.clone(),
2196 lowered_artifacts: self
2197 .active_lowered_artifacts()
2198 .expect("active lowered artifacts should exist for cached specialization"),
2199 });
2200 }
2201 }
2202 let cached_specialization = {
2203 let specialization_cache = self.ir_planning.specialization_cache.read();
2204 specialization_cache.get(&key).cloned()
2205 };
2206 if let Some(specialization) = cached_specialization {
2207 let restore_start = Instant::now();
2208 self.ir_planning.specialization_metrics.write().cache_hits += 1;
2209 self.install_expression_specialization(&specialization);
2210 *self.ir_planning.specialization_status.write() =
2211 Some(ExpressionSpecializationStatus {
2212 origin: ExpressionSpecializationOrigin::CacheHitRestore,
2213 });
2214 let restore_nanos = restore_start.elapsed().as_nanos() as u64;
2215 let mut compile_metrics = self.ir_planning.compile_metrics.write();
2216 compile_metrics.specialization_cache_hits += 1;
2217 compile_metrics.specialization_cache_restore_nanos += restore_nanos;
2218 return Ok(specialization);
2219 }
2220 let specialization = self.build_expression_specialization(resources, key.clone())?;
2221 self.ir_planning.specialization_metrics.write().cache_misses += 1;
2222 self.ir_planning
2223 .specialization_cache
2224 .write()
2225 .insert(key, specialization.clone());
2226 self.install_expression_specialization(&specialization);
2227 let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
2228 ExpressionSpecializationOrigin::InitialLoad
2229 } else {
2230 ExpressionSpecializationOrigin::CacheMissRebuild
2231 };
2232 *self.ir_planning.specialization_status.write() =
2233 Some(ExpressionSpecializationStatus { origin });
2234 Ok(specialization)
2235 }
2236 fn rebuild_runtime_specializations(&self, resources: &Resources) {
2237 let _ = self.ensure_expression_specialization(resources);
2238 }
2239 fn refresh_runtime_specializations(&self) {
2240 let resources = self.resources.read();
2241 self.rebuild_runtime_specializations(&resources);
2242 }
2243 fn cached_integral_cache_key(
2244 active_mask: Vec<bool>,
2245 dataset: &Dataset,
2246 ) -> CachedIntegralCacheKey {
2247 let (weights_ptr, weights_local_len) = dataset.local_weight_cache_key();
2248 CachedIntegralCacheKey {
2249 active_mask,
2250 n_events_local: dataset.n_events_local(),
2251 weights_local_len,
2252 weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
2253 weights_ptr,
2254 }
2255 }
2256 fn precompute_cached_integrals_at_load(
2257 expression_ir: &ir::ExpressionIR,
2258 amplitudes: &[Box<dyn Amplitude>],
2259 amplitude_use_sites: &[AmplitudeUseSite],
2260 resources: &Resources,
2261 dataset: &Dataset,
2262 n_free_parameters: usize,
2263 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2264 let descriptors = expression_ir.cached_integral_descriptors();
2265 if descriptors.is_empty() {
2266 return Ok(Vec::new());
2267 }
2268 let execution_sets = expression_ir.normalization_execution_sets();
2269 let seed_parameters = vec![0.0; n_free_parameters];
2270 let parameters = resources.parameter_map.assemble(&seed_parameters)?;
2271 let mut amplitude_values = vec![Complex64::ZERO; amplitude_use_sites.len()];
2272 let mut compute_values = vec![Complex64::ZERO; amplitudes.len()];
2273 let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
2274 let active_set = resources.active_indices();
2275 let cache_active_indices = execution_sets
2276 .cached_cache_amplitudes
2277 .iter()
2278 .copied()
2279 .filter(|index| active_set.binary_search(index).is_ok())
2280 .collect::<Vec<_>>();
2281 let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
2282 for (cache, event) in resources.caches.iter().zip(dataset.weights_local().iter()) {
2283 amplitude_values.fill(Complex64::ZERO);
2284 compute_values.fill(Complex64::ZERO);
2285 let mut computed = vec![false; amplitudes.len()];
2286 for &use_site_idx in &cache_active_indices {
2287 let amp_idx = amplitude_use_sites[use_site_idx].amplitude_index;
2288 if !computed[amp_idx] {
2289 compute_values[amp_idx] = amplitudes[amp_idx].compute(¶meters, cache);
2290 computed[amp_idx] = true;
2291 }
2292 amplitude_values[use_site_idx] = compute_values[amp_idx];
2293 }
2294 expression_ir.evaluate_into(&litude_values, &mut value_slots);
2295 let weight = *event;
2296 for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
2297 weighted_cache_sums[descriptor_index] +=
2298 value_slots[descriptor.cache_node_index] * weight;
2299 }
2300 }
2301 Ok(descriptors
2302 .iter()
2303 .zip(weighted_cache_sums)
2304 .map(
2305 |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
2306 mul_node_index: descriptor.mul_node_index,
2307 parameter_node_index: descriptor.parameter_node_index,
2308 cache_node_index: descriptor.cache_node_index,
2309 coefficient: descriptor.coefficient,
2310 weighted_cache_sum,
2311 },
2312 )
2313 .collect())
2314 }
2315 fn lower_cached_parameter_factors(
2316 expression_ir: &ir::ExpressionIR,
2317 ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
2318 expression_ir
2319 .cached_integral_descriptors()
2320 .iter()
2321 .map(|descriptor| {
2322 lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
2323 expression_ir,
2324 descriptor.parameter_node_index,
2325 )
2326 .ok()
2327 })
2328 .collect()
2329 }
2330 fn lower_residual_runtime(
2331 expression_ir: &ir::ExpressionIR,
2332 descriptors: &[PrecomputedCachedIntegral],
2333 ) -> Option<lowered::LoweredExpressionRuntime> {
2334 let mut zeroed_nodes = vec![false; expression_ir.node_count()];
2335 for descriptor in descriptors {
2336 if descriptor.mul_node_index < zeroed_nodes.len() {
2337 zeroed_nodes[descriptor.mul_node_index] = true;
2338 }
2339 }
2340 lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
2341 expression_ir,
2342 &zeroed_nodes,
2343 )
2344 .ok()
2345 }
2346
2347 #[inline]
2348 fn fill_amplitude_values(
2349 &self,
2350 amplitude_values: &mut [Complex64],
2351 active_indices: &[usize],
2352 parameters: &Parameters,
2353 cache: &Cache,
2354 ) {
2355 amplitude_values.fill(Complex64::ZERO);
2356 let mut compute_values = vec![Complex64::ZERO; self.amplitudes.len()];
2357 let mut computed = vec![false; self.amplitudes.len()];
2358 for &use_site_idx in active_indices {
2359 let amp_idx = self.amplitude_use_sites[use_site_idx].amplitude_index;
2360 if !computed[amp_idx] {
2361 compute_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
2362 computed[amp_idx] = true;
2363 }
2364 amplitude_values[use_site_idx] = compute_values[amp_idx];
2365 }
2366 }
2367
2368 #[inline]
2369 fn fill_amplitude_gradients(
2370 &self,
2371 gradient_values: &mut [DVector<Complex64>],
2372 active_mask: &[bool],
2373 parameters: &Parameters,
2374 cache: &Cache,
2375 ) {
2376 let mut compute_gradients = vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
2377 let mut computed = vec![false; self.amplitudes.len()];
2378 for ((use_site, active), grad) in self
2379 .amplitude_use_sites
2380 .iter()
2381 .zip(active_mask.iter())
2382 .zip(gradient_values.iter_mut())
2383 {
2384 grad.fill(Complex64::ZERO);
2385 if *active {
2386 let amp_idx = use_site.amplitude_index;
2387 if !computed[amp_idx] {
2388 self.amplitudes[amp_idx].compute_gradient(
2389 parameters,
2390 cache,
2391 &mut compute_gradients[amp_idx],
2392 );
2393 computed[amp_idx] = true;
2394 }
2395 grad.copy_from(&compute_gradients[amp_idx]);
2396 }
2397 }
2398 }
2399
2400 #[inline]
2401 fn fill_amplitude_values_and_gradients(
2402 &self,
2403 amplitude_values: &mut [Complex64],
2404 gradient_values: &mut [DVector<Complex64>],
2405 active_indices: &[usize],
2406 active_mask: &[bool],
2407 parameters: &Parameters,
2408 cache: &Cache,
2409 ) {
2410 self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
2411 self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
2412 }
2413
2414 #[doc(hidden)]
2415 pub fn fill_amplitude_values_and_gradients_public(
2416 &self,
2417 amplitude_values: &mut [Complex64],
2418 gradient_values: &mut [DVector<Complex64>],
2419 active_indices: &[usize],
2420 active_mask: &[bool],
2421 parameters: &Parameters,
2422 cache: &Cache,
2423 ) {
2424 self.fill_amplitude_values_and_gradients(
2425 amplitude_values,
2426 gradient_values,
2427 active_indices,
2428 active_mask,
2429 parameters,
2430 cache,
2431 );
2432 }
2433
/// Evaluates the expression gradient for one event's `cache`, reusing caller-provided
/// scratch buffers.
///
/// Amplitude values and gradients are filled first; the expression gradient is then
/// evaluated from them.
#[cfg(feature = "execution-context-prototype")]
#[inline]
fn evaluate_cache_gradient_with_scratch(
    &self,
    amplitude_values: &mut [Complex64],
    gradient_values: &mut [DVector<Complex64>],
    value_slots: &mut [Complex64],
    gradient_slots: &mut [DVector<Complex64>],
    active_indices: &[usize],
    active_mask: &[bool],
    parameters: &Parameters,
    cache: &Cache,
) -> DVector<Complex64> {
    Self::fill_amplitude_values_and_gradients(
        self,
        amplitude_values,
        gradient_values,
        active_indices,
        active_mask,
        parameters,
        cache,
    );
    Self::evaluate_expression_gradient_with_scratch(
        self,
        amplitude_values,
        gradient_values,
        value_slots,
        gradient_slots,
    )
}
2462
/// Evaluates both the expression value and its gradient for one event's `cache`, reusing
/// caller-provided scratch buffers.
///
/// Amplitude values and gradients are filled first; the value and gradient are then
/// evaluated from them.
#[cfg(feature = "execution-context-prototype")]
#[allow(dead_code)]
#[inline]
fn evaluate_cache_value_gradient_with_scratch(
    &self,
    amplitude_values: &mut [Complex64],
    gradient_values: &mut [DVector<Complex64>],
    value_slots: &mut [Complex64],
    gradient_slots: &mut [DVector<Complex64>],
    active_indices: &[usize],
    active_mask: &[bool],
    parameters: &Parameters,
    cache: &Cache,
) -> (Complex64, DVector<Complex64>) {
    Self::fill_amplitude_values_and_gradients(
        self,
        amplitude_values,
        gradient_values,
        active_indices,
        active_mask,
        parameters,
        cache,
    );
    Self::evaluate_expression_value_gradient_with_scratch(
        self,
        amplitude_values,
        gradient_values,
        value_slots,
        gradient_slots,
    )
}
2492
2493 pub fn expression_slot_count(&self) -> usize {
2494 self.lowered_runtime_slot_count()
2495 }
2496 fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
2497 let amplitude_dependencies = self
2498 .amplitude_use_sites
2499 .iter()
2500 .map(|use_site| {
2501 ir::DependenceClass::from(
2502 self.amplitudes[use_site.amplitude_index].dependence_hint(),
2503 )
2504 })
2505 .collect::<Vec<_>>();
2506 let amplitude_realness = self
2507 .amplitude_use_sites
2508 .iter()
2509 .map(|use_site| self.amplitudes[use_site.amplitude_index].real_valued_hint())
2510 .collect::<Vec<_>>();
2511 ir::compile_expression_ir_with_real_hints(
2512 &self.expression,
2513 active_mask,
2514 &litude_dependencies,
2515 &litude_realness,
2516 )
2517 }
2518 fn lower_expression_runtime_for_active_mask(
2519 &self,
2520 active_mask: &[bool],
2521 ) -> LadduResult<lowered::LoweredExpressionRuntime> {
2522 let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
2523 lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
2524 LadduError::Custom(format!(
2525 "Failed to lower active-mask runtime specialization: {error:?}"
2526 ))
2527 })
2528 }
2529 fn ensure_cached_integral_cache_state(
2530 &self,
2531 resources: &Resources,
2532 ) -> LadduResult<Arc<CachedIntegralCacheState>> {
2533 Ok(self
2534 .ensure_expression_specialization(resources)?
2535 .cached_integrals)
2536 }
2537
2538 fn evaluate_expression_runtime_value_with_scratch(
2539 &self,
2540 amplitude_values: &[Complex64],
2541 scratch: &mut [Complex64],
2542 ) -> Complex64 {
2543 let lowered_runtime = self.lowered_runtime();
2544 lowered_runtime
2545 .value_program()
2546 .evaluate_into(amplitude_values, scratch)
2547 }
2548
2549 #[doc(hidden)]
2550 pub fn evaluate_expression_value_with_program_snapshot(
2551 &self,
2552 program_snapshot: &ExpressionValueProgramSnapshot,
2553 amplitude_values: &[Complex64],
2554 scratch: &mut [Complex64],
2555 ) -> Complex64 {
2556 program_snapshot
2557 .lowered_program
2558 .evaluate_into(amplitude_values, scratch)
2559 }
2560
2561 fn evaluate_expression_runtime_gradient_with_scratch(
2562 &self,
2563 amplitude_values: &[Complex64],
2564 gradient_values: &[DVector<Complex64>],
2565 value_scratch: &mut [Complex64],
2566 gradient_scratch: &mut [DVector<Complex64>],
2567 ) -> DVector<Complex64> {
2568 let lowered_runtime = self.lowered_runtime();
2569 lowered_runtime.gradient_program().evaluate_gradient_into(
2570 amplitude_values,
2571 gradient_values,
2572 value_scratch,
2573 gradient_scratch,
2574 )
2575 }
2576
2577 fn evaluate_expression_runtime_value_gradient_with_scratch(
2578 &self,
2579 amplitude_values: &[Complex64],
2580 gradient_values: &[DVector<Complex64>],
2581 value_scratch: &mut [Complex64],
2582 gradient_scratch: &mut [DVector<Complex64>],
2583 ) -> (Complex64, DVector<Complex64>) {
2584 let lowered_runtime = self.lowered_runtime();
2585 lowered_runtime
2586 .value_gradient_program()
2587 .evaluate_value_gradient_into(
2588 amplitude_values,
2589 gradient_values,
2590 value_scratch,
2591 gradient_scratch,
2592 )
2593 }
2594
2595 fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
2596 let lowered_runtime = self.lowered_runtime();
2597 let program = lowered_runtime.value_program();
2598 let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
2599 program.evaluate_into(amplitude_values, &mut scratch)
2600 }
2601
2602 fn evaluate_expression_runtime_gradient(
2603 &self,
2604 amplitude_values: &[Complex64],
2605 gradient_values: &[DVector<Complex64>],
2606 ) -> DVector<Complex64> {
2607 let lowered_runtime = self.lowered_runtime();
2608 let program = lowered_runtime.gradient_program();
2609 let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
2610 let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
2611 let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
2612 program.evaluate_gradient_into_flat(
2613 amplitude_values,
2614 gradient_values,
2615 &mut value_scratch,
2616 &mut gradient_scratch,
2617 grad_dim,
2618 )
2619 }
2620 pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
2622 let resources = self.resources.read();
2623 Ok(self
2624 .ensure_cached_integral_cache_state(&resources)?
2625 .expression_ir
2626 .root_dependence()
2627 .into())
2628 }
2629 pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
2631 let resources = self.resources.read();
2632 Ok(self
2633 .ensure_cached_integral_cache_state(&resources)?
2634 .expression_ir
2635 .node_dependence_annotations()
2636 .iter()
2637 .copied()
2638 .map(Into::into)
2639 .collect())
2640 }
2641 pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
2643 let resources = self.resources.read();
2644 Ok(self
2645 .ensure_cached_integral_cache_state(&resources)?
2646 .expression_ir
2647 .dependence_warnings()
2648 .to_vec())
2649 }
2650 pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
2652 let resources = self.resources.read();
2653 Ok(self
2654 .ensure_cached_integral_cache_state(&resources)?
2655 .expression_ir
2656 .normalization_plan_explain()
2657 .into())
2658 }
2659 pub fn expression_normalization_execution_sets(
2661 &self,
2662 ) -> LadduResult<NormalizationExecutionSetsExplain> {
2663 let resources = self.resources.read();
2664 Ok(self
2665 .ensure_cached_integral_cache_state(&resources)?
2666 .execution_sets
2667 .clone()
2668 .into())
2669 }
2670 pub fn expression_precomputed_cached_integrals(
2672 &self,
2673 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
2674 let resources = self.resources.read();
2675 Ok(self
2676 .ensure_cached_integral_cache_state(&resources)?
2677 .values
2678 .clone())
2679 }
2680 pub fn expression_precomputed_cached_integral_gradient_terms(
2685 &self,
2686 parameters: &[f64],
2687 ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
2688 let resources = self.resources.read();
2689 let state = self.ensure_cached_integral_cache_state(&resources)?;
2690 if state.values.is_empty() {
2691 return Ok(Vec::new());
2692 }
2693
2694 let Some(cache) = resources.caches.first() else {
2695 return Ok(state
2696 .values
2697 .iter()
2698 .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
2699 mul_node_index: descriptor.mul_node_index,
2700 parameter_node_index: descriptor.parameter_node_index,
2701 cache_node_index: descriptor.cache_node_index,
2702 coefficient: descriptor.coefficient,
2703 weighted_gradient: DVector::zeros(parameters.len()),
2704 })
2705 .collect());
2706 };
2707
2708 let parameter_values = resources.parameter_map.assemble(parameters)?;
2709 let mut amplitude_values = vec![Complex64::ZERO; self.amplitude_use_sites.len()];
2710 self.fill_amplitude_values(
2711 &mut amplitude_values,
2712 resources.active_indices(),
2713 ¶meter_values,
2714 cache,
2715 );
2716 let mut amplitude_gradients = (0..self.amplitude_use_sites.len())
2717 .map(|_| DVector::zeros(parameters.len()))
2718 .collect::<Vec<_>>();
2719 self.fill_amplitude_gradients(
2720 &mut amplitude_gradients,
2721 &resources.active,
2722 ¶meter_values,
2723 cache,
2724 );
2725 let lowered_artifacts = self.active_lowered_artifacts();
2726 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2727 let mut gradient_slots = (0..state.expression_ir.node_count())
2728 .map(|_| DVector::zeros(parameters.len()))
2729 .collect::<Vec<_>>();
2730 let max_lowered_slots = lowered_artifacts
2731 .as_ref()
2732 .map(|artifacts| {
2733 artifacts
2734 .lowered_parameter_factors
2735 .iter()
2736 .filter_map(|runtime| {
2737 runtime
2738 .as_ref()
2739 .and_then(|runtime| runtime.gradient_program())
2740 .map(|program| program.scratch_slots())
2741 })
2742 .max()
2743 .unwrap_or(0)
2744 })
2745 .unwrap_or(0);
2746 let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
2747 let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
2748 let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
2749 artifacts.lowered_parameter_factors.len() == state.values.len()
2750 && artifacts.lowered_parameter_factors.iter().all(|runtime| {
2751 runtime
2752 .as_ref()
2753 .and_then(|runtime| runtime.gradient_program())
2754 .is_some()
2755 })
2756 });
2757
2758 if !use_lowered {
2759 let _ = state.expression_ir.evaluate_gradient_into(
2760 &litude_values,
2761 &litude_gradients,
2762 &mut value_slots,
2763 &mut gradient_slots,
2764 );
2765 }
2766
2767 if use_lowered {
2768 let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
2769 Ok(state
2770 .values
2771 .iter()
2772 .cloned()
2773 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2774 .map(|(descriptor, runtime)| {
2775 let parameter_gradient = runtime
2776 .as_ref()
2777 .and_then(|runtime| runtime.gradient_program())
2778 .map(|program| {
2779 program.evaluate_gradient_into(
2780 &litude_values,
2781 &litude_gradients,
2782 &mut lowered_value_slots[..program.scratch_slots()],
2783 &mut lowered_gradient_slots[..program.scratch_slots()],
2784 )
2785 })
2786 .unwrap_or_else(|| DVector::zeros(parameters.len()));
2787 let weighted_gradient = parameter_gradient.map(|value| {
2788 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2789 });
2790 PrecomputedCachedIntegralGradientTerm {
2791 mul_node_index: descriptor.mul_node_index,
2792 parameter_node_index: descriptor.parameter_node_index,
2793 cache_node_index: descriptor.cache_node_index,
2794 coefficient: descriptor.coefficient,
2795 weighted_gradient,
2796 }
2797 })
2798 .collect())
2799 } else {
2800 Ok(state
2801 .values
2802 .iter()
2803 .map(|descriptor| {
2804 let parameter_gradient = gradient_slots
2805 .get(descriptor.parameter_node_index)
2806 .cloned()
2807 .unwrap_or_else(|| DVector::zeros(parameters.len()));
2808 let weighted_gradient = parameter_gradient.map(|value| {
2809 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
2810 });
2811 PrecomputedCachedIntegralGradientTerm {
2812 mul_node_index: descriptor.mul_node_index,
2813 parameter_node_index: descriptor.parameter_node_index,
2814 cache_node_index: descriptor.cache_node_index,
2815 coefficient: descriptor.coefficient,
2816 weighted_gradient,
2817 }
2818 })
2819 .collect())
2820 }
2821 }
2822 fn evaluate_cached_weighted_value_sum_ir(
2823 &self,
2824 state: &CachedIntegralCacheState,
2825 amplitude_values: &[Complex64],
2826 ) -> f64 {
2827 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2828 let _ = state
2829 .expression_ir
2830 .evaluate_into(amplitude_values, &mut value_slots);
2831 state
2832 .values
2833 .iter()
2834 .map(|descriptor| {
2835 let parameter_factor = value_slots[descriptor.parameter_node_index];
2836 (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2837 .re
2838 })
2839 .sum()
2840 }
2841 fn evaluate_cached_weighted_value_sum_lowered(
2842 &self,
2843 state: &CachedIntegralCacheState,
2844 lowered_artifacts: &LoweredArtifactCacheState,
2845 amplitude_values: &[Complex64],
2846 ) -> Option<f64> {
2847 let max_slots = lowered_artifacts
2848 .lowered_parameter_factors
2849 .iter()
2850 .filter_map(|runtime| {
2851 runtime
2852 .as_ref()
2853 .and_then(|runtime| runtime.value_program())
2854 .map(|program| program.scratch_slots())
2855 })
2856 .max()
2857 .unwrap_or(0);
2858 let mut value_slots = vec![Complex64::ZERO; max_slots];
2859 let mut total = 0.0;
2860 for (descriptor, runtime) in state
2861 .values
2862 .iter()
2863 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2864 {
2865 let parameter_factor = runtime
2866 .as_ref()
2867 .and_then(|runtime| runtime.value_program())
2868 .map(|program| {
2869 program.evaluate_into(
2870 amplitude_values,
2871 &mut value_slots[..program.scratch_slots()],
2872 )
2873 })?;
2874 total +=
2875 (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
2876 .re;
2877 }
2878 Some(total)
2879 }
2880 fn evaluate_cached_weighted_gradient_sum_ir(
2881 &self,
2882 state: &CachedIntegralCacheState,
2883 amplitude_values: &[Complex64],
2884 amplitude_gradients: &[DVector<Complex64>],
2885 grad_dim: usize,
2886 ) -> DVector<f64> {
2887 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2888 let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2889 let _ = state.expression_ir.evaluate_gradient_into(
2890 amplitude_values,
2891 amplitude_gradients,
2892 &mut value_slots,
2893 &mut gradient_slots,
2894 );
2895 state
2896 .values
2897 .iter()
2898 .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
2899 let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
2900 let coefficient = descriptor.coefficient as f64;
2901 for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
2902 *accum_item +=
2903 (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2904 }
2905 accum
2906 })
2907 }
2908 fn evaluate_cached_weighted_gradient_sum_lowered(
2909 &self,
2910 state: &CachedIntegralCacheState,
2911 lowered_artifacts: &LoweredArtifactCacheState,
2912 amplitude_values: &[Complex64],
2913 amplitude_gradients: &[DVector<Complex64>],
2914 grad_dim: usize,
2915 ) -> Option<DVector<f64>> {
2916 let max_value_slots = lowered_artifacts
2917 .lowered_parameter_factors
2918 .iter()
2919 .filter_map(|runtime| {
2920 runtime
2921 .as_ref()
2922 .and_then(|runtime| runtime.gradient_program())
2923 .map(|program| program.scratch_slots())
2924 })
2925 .max()
2926 .unwrap_or(0);
2927 let mut value_slots = vec![Complex64::ZERO; max_value_slots];
2928 let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
2929 let mut total = DVector::zeros(grad_dim);
2930 for (descriptor, runtime) in state
2931 .values
2932 .iter()
2933 .zip(lowered_artifacts.lowered_parameter_factors.iter())
2934 {
2935 let parameter_gradient = runtime
2936 .as_ref()
2937 .and_then(|runtime| runtime.gradient_program())
2938 .map(|program| {
2939 program.evaluate_gradient_into_flat(
2940 amplitude_values,
2941 amplitude_gradients,
2942 &mut value_slots[..program.scratch_slots()],
2943 &mut gradient_slots[..program.scratch_slots() * grad_dim],
2944 grad_dim,
2945 )
2946 })?;
2947 let coefficient = descriptor.coefficient as f64;
2948 for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
2949 *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
2950 }
2951 }
2952 Some(total)
2953 }
2954 fn evaluate_residual_value_ir(
2955 &self,
2956 state: &CachedIntegralCacheState,
2957 amplitude_values: &[Complex64],
2958 ) -> Complex64 {
2959 let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2960 for descriptor in &state.values {
2961 if descriptor.mul_node_index < zeroed_nodes.len() {
2962 zeroed_nodes[descriptor.mul_node_index] = true;
2963 }
2964 }
2965 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2966 state.expression_ir.evaluate_into_with_zeroed_nodes(
2967 amplitude_values,
2968 &mut value_slots,
2969 &zeroed_nodes,
2970 )
2971 }
2972 fn evaluate_residual_gradient_ir(
2973 &self,
2974 state: &CachedIntegralCacheState,
2975 amplitude_values: &[Complex64],
2976 amplitude_gradients: &[DVector<Complex64>],
2977 grad_dim: usize,
2978 ) -> DVector<Complex64> {
2979 let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
2980 for descriptor in &state.values {
2981 if descriptor.mul_node_index < zeroed_nodes.len() {
2982 zeroed_nodes[descriptor.mul_node_index] = true;
2983 }
2984 }
2985 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
2986 let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
2987 state
2988 .expression_ir
2989 .evaluate_gradient_into_with_zeroed_nodes(
2990 amplitude_values,
2991 amplitude_gradients,
2992 &mut value_slots,
2993 &mut gradient_slots,
2994 &zeroed_nodes,
2995 )
2996 }
2997
2998 fn evaluate_weighted_value_sum_local_components(
2999 &self,
3000 parameters: &[f64],
3001 ) -> LadduResult<(f64, f64)> {
3002 let resources = self.resources.read();
3003 let parameters = resources.parameter_map.assemble(parameters)?;
3004 let amplitude_len = self.amplitude_use_sites.len();
3005 let state = self.ensure_cached_integral_cache_state(&resources)?;
3006 let lowered_artifacts = self.active_lowered_artifacts();
3007 let residual_value_slot_count = lowered_artifacts
3008 .as_ref()
3009 .and_then(|artifacts| {
3010 artifacts
3011 .residual_runtime
3012 .as_ref()
3013 .map(|runtime| runtime.value_program())
3014 .map(|program| program.scratch_slots())
3015 })
3016 .unwrap_or_else(|| self.expression_slot_count());
3017 let residual_value_program = lowered_artifacts
3018 .as_ref()
3019 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
3020 .map(|runtime| runtime.value_program());
3021 let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
3022 let residual_active_indices = &state.execution_sets.residual_amplitudes;
3023 debug_assert!(cached_parameter_indices.iter().all(|&index| resources
3024 .active
3025 .get(index)
3026 .copied()
3027 .unwrap_or(false)));
3028 debug_assert!(residual_active_indices.iter().all(|&index| resources
3029 .active
3030 .get(index)
3031 .copied()
3032 .unwrap_or(false)));
3033 let cached_value_sum = {
3034 if let Some(cache) = resources.caches.first() {
3035 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3036 self.fill_amplitude_values(
3037 &mut amplitude_values,
3038 cached_parameter_indices,
3039 ¶meters,
3040 cache,
3041 );
3042 lowered_artifacts
3043 .as_ref()
3044 .and_then(|artifacts| {
3045 self.evaluate_cached_weighted_value_sum_lowered(
3046 &state,
3047 artifacts,
3048 &litude_values,
3049 )
3050 })
3051 .unwrap_or_else(|| {
3052 self.evaluate_cached_weighted_value_sum_ir(&state, &litude_values)
3053 })
3054 } else {
3055 0.0
3056 }
3057 };
3058
3059 #[cfg(feature = "rayon")]
3060 let residual_sum: f64 = {
3061 resources
3062 .caches
3063 .par_iter()
3064 .zip(self.dataset.weights_local().par_iter())
3065 .map_init(
3066 || {
3067 (
3068 vec![Complex64::ZERO; amplitude_len],
3069 vec![Complex64::ZERO; residual_value_slot_count],
3070 )
3071 },
3072 |(amplitude_values, value_slots), (cache, event)| {
3073 self.fill_amplitude_values(
3074 amplitude_values,
3075 residual_active_indices,
3076 ¶meters,
3077 cache,
3078 );
3079 {
3080 let value = residual_value_program
3081 .as_ref()
3082 .map(|program| {
3083 program.evaluate_into(
3084 amplitude_values,
3085 &mut value_slots[..program.scratch_slots()],
3086 )
3087 })
3088 .unwrap_or_else(|| {
3089 self.evaluate_residual_value_ir(&state, amplitude_values)
3090 });
3091 *event * value.re
3092 }
3093 },
3094 )
3095 .sum()
3096 };
3097
3098 #[cfg(not(feature = "rayon"))]
3099 let residual_sum: f64 = {
3100 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3101 let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
3102 resources
3103 .caches
3104 .iter()
3105 .zip(self.dataset.weights_local().iter())
3106 .map(|(cache, event)| {
3107 self.fill_amplitude_values(
3108 &mut amplitude_values,
3109 &residual_active_indices,
3110 ¶meters,
3111 cache,
3112 );
3113 {
3114 let value = residual_value_program
3115 .as_ref()
3116 .map(|program| {
3117 program.evaluate_into(
3118 &litude_values,
3119 &mut value_slots[..program.scratch_slots()],
3120 )
3121 })
3122 .unwrap_or_else(|| {
3123 self.evaluate_residual_value_ir(&state, &litude_values)
3124 });
3125 *event * value.re
3126 }
3127 })
3128 .sum()
3129 };
3130 Ok((residual_sum, cached_value_sum))
3131 }
3132
3133 pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
3137 let (residual_sum, cached_value_sum) =
3138 self.evaluate_weighted_value_sum_local_components(parameters)?;
3139 Ok(residual_sum + cached_value_sum)
3140 }
3141
3142 #[cfg(feature = "mpi")]
3143 pub fn evaluate_weighted_value_sum_mpi(
3147 &self,
3148 parameters: &[f64],
3149 world: &SimpleCommunicator,
3150 ) -> LadduResult<f64> {
3151 let (residual_sum_local, cached_value_sum_local) =
3152 self.evaluate_weighted_value_sum_local_components(parameters)?;
3153 let mut residual_sum = 0.0;
3154 world.all_reduce_into(
3155 &residual_sum_local,
3156 &mut residual_sum,
3157 mpi::collective::SystemOperation::sum(),
3158 );
3159 let mut cached_value_sum = 0.0;
3160 world.all_reduce_into(
3161 &cached_value_sum_local,
3162 &mut cached_value_sum,
3163 mpi::collective::SystemOperation::sum(),
3164 );
3165 Ok(residual_sum + cached_value_sum)
3166 }
3167
3168 fn evaluate_weighted_gradient_sum_local_components(
3172 &self,
3173 parameters: &[f64],
3174 ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
3175 let resources = self.resources.read();
3176 let parameters = resources.parameter_map.assemble(parameters)?;
3177 let amplitude_len = self.amplitude_use_sites.len();
3178 let grad_dim = parameters.len();
3179 let state = self.ensure_cached_integral_cache_state(&resources)?;
3180 let lowered_artifacts = self.active_lowered_artifacts();
3181 let active_index_set = resources.active_indices();
3182 let cached_parameter_indices = state
3183 .execution_sets
3184 .cached_parameter_amplitudes
3185 .iter()
3186 .copied()
3187 .filter(|index| active_index_set.binary_search(index).is_ok())
3188 .collect::<Vec<_>>();
3189 let residual_active_indices = state
3190 .execution_sets
3191 .residual_amplitudes
3192 .iter()
3193 .copied()
3194 .filter(|index| active_index_set.binary_search(index).is_ok())
3195 .collect::<Vec<_>>();
3196 let mut cached_parameter_mask = vec![false; amplitude_len];
3197 for &index in &cached_parameter_indices {
3198 cached_parameter_mask[index] = true;
3199 }
3200 let mut residual_active_mask = vec![false; amplitude_len];
3201 for &index in &residual_active_indices {
3202 residual_active_mask[index] = true;
3203 }
3204 let residual_gradient_program = lowered_artifacts
3205 .as_ref()
3206 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
3207 .map(|runtime| runtime.gradient_program());
3208 let residual_gradient_slot_count = residual_gradient_program
3209 .as_ref()
3210 .map(|program| program.scratch_slots())
3211 .unwrap_or_else(|| state.expression_ir.node_count());
3212 let cached_term_sum = {
3213 if let Some(cache) = resources.caches.first() {
3214 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3215 self.fill_amplitude_values(
3216 &mut amplitude_values,
3217 &cached_parameter_indices,
3218 ¶meters,
3219 cache,
3220 );
3221 let mut amplitude_gradients = (0..amplitude_len)
3222 .map(|_| DVector::zeros(grad_dim))
3223 .collect::<Vec<_>>();
3224 self.fill_amplitude_gradients(
3225 &mut amplitude_gradients,
3226 &cached_parameter_mask,
3227 ¶meters,
3228 cache,
3229 );
3230 lowered_artifacts
3231 .as_ref()
3232 .and_then(|artifacts| {
3233 self.evaluate_cached_weighted_gradient_sum_lowered(
3234 &state,
3235 artifacts,
3236 &litude_values,
3237 &litude_gradients,
3238 grad_dim,
3239 )
3240 })
3241 .unwrap_or_else(|| {
3242 self.evaluate_cached_weighted_gradient_sum_ir(
3243 &state,
3244 &litude_values,
3245 &litude_gradients,
3246 grad_dim,
3247 )
3248 })
3249 } else {
3250 DVector::zeros(grad_dim)
3251 }
3252 };
3253
3254 #[cfg(feature = "rayon")]
3255 let residual_sum = {
3256 resources
3257 .caches
3258 .par_iter()
3259 .zip(self.dataset.weights_local().par_iter())
3260 .map_init(
3261 || {
3262 (
3263 vec![Complex64::ZERO; amplitude_len],
3264 vec![DVector::zeros(grad_dim); amplitude_len],
3265 vec![Complex64::ZERO; residual_gradient_slot_count],
3266 vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
3267 )
3268 },
3269 |(amplitude_values, gradient_values, value_slots, gradient_slots),
3270 (cache, event)| {
3271 self.fill_amplitude_values_and_gradients(
3272 amplitude_values,
3273 gradient_values,
3274 &residual_active_indices,
3275 &residual_active_mask,
3276 ¶meters,
3277 cache,
3278 );
3279 let gradient = residual_gradient_program
3280 .as_ref()
3281 .map(|program| {
3282 program.evaluate_gradient_into_flat(
3283 amplitude_values,
3284 gradient_values,
3285 value_slots,
3286 gradient_slots,
3287 grad_dim,
3288 )
3289 })
3290 .unwrap_or_else(|| {
3291 self.evaluate_residual_gradient_ir(
3292 &state,
3293 amplitude_values,
3294 gradient_values,
3295 grad_dim,
3296 )
3297 });
3298 gradient.map(|value| value.re).scale(*event)
3299 },
3300 )
3301 .reduce(
3302 || DVector::zeros(grad_dim),
3303 |mut accum, value| {
3304 accum += value;
3305 accum
3306 },
3307 )
3308 };
3309
3310 #[cfg(not(feature = "rayon"))]
3311 let residual_sum = {
3312 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3313 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
3314 let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
3315 let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
3316 resources
3317 .caches
3318 .iter()
3319 .zip(self.dataset.weights_local().iter())
3320 .map(|(cache, event)| {
3321 self.fill_amplitude_values_and_gradients(
3322 &mut amplitude_values,
3323 &mut gradient_values,
3324 &residual_active_indices,
3325 &residual_active_mask,
3326 ¶meters,
3327 cache,
3328 );
3329 let gradient = residual_gradient_program
3330 .as_ref()
3331 .map(|program| {
3332 program.evaluate_gradient_into_flat(
3333 &litude_values,
3334 &gradient_values,
3335 &mut value_slots,
3336 &mut gradient_slots,
3337 grad_dim,
3338 )
3339 })
3340 .unwrap_or_else(|| {
3341 self.evaluate_residual_gradient_ir(
3342 &state,
3343 &litude_values,
3344 &gradient_values,
3345 grad_dim,
3346 )
3347 });
3348 gradient.map(|value| value.re).scale(*event)
3349 })
3350 .sum()
3351 };
3352 Ok((residual_sum, cached_term_sum))
3353 }
3354
3355 pub fn evaluate_weighted_gradient_sum_local(
3359 &self,
3360 parameters: &[f64],
3361 ) -> LadduResult<DVector<f64>> {
3362 let (residual_sum, cached_term_sum) =
3363 self.evaluate_weighted_gradient_sum_local_components(parameters)?;
3364 Ok(residual_sum + cached_term_sum)
3365 }
3366
3367 #[cfg(feature = "mpi")]
3368 pub fn evaluate_weighted_gradient_sum_mpi(
3372 &self,
3373 parameters: &[f64],
3374 world: &SimpleCommunicator,
3375 ) -> LadduResult<DVector<f64>> {
3376 let (residual_sum_local, cached_term_sum_local) =
3377 self.evaluate_weighted_gradient_sum_local_components(parameters)?;
3378 let mut residual_sum = vec![0.0; residual_sum_local.len()];
3379 world.all_reduce_into(
3380 residual_sum_local.as_slice(),
3381 &mut residual_sum,
3382 mpi::collective::SystemOperation::sum(),
3383 );
3384 let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
3385 world.all_reduce_into(
3386 cached_term_sum_local.as_slice(),
3387 &mut cached_term_sum,
3388 mpi::collective::SystemOperation::sum(),
3389 );
3390 let mut total = DVector::from_vec(residual_sum);
3391 total += DVector::from_vec(cached_term_sum);
3392 Ok(total)
3393 }
3394
3395 pub fn evaluate_expression_value_with_scratch(
3396 &self,
3397 amplitude_values: &[Complex64],
3398 scratch: &mut [Complex64],
3399 ) -> Complex64 {
3400 self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
3401 }
3402
3403 pub fn evaluate_expression_gradient_with_scratch(
3404 &self,
3405 amplitude_values: &[Complex64],
3406 gradient_values: &[DVector<Complex64>],
3407 value_scratch: &mut [Complex64],
3408 gradient_scratch: &mut [DVector<Complex64>],
3409 ) -> DVector<Complex64> {
3410 self.evaluate_expression_runtime_gradient_with_scratch(
3411 amplitude_values,
3412 gradient_values,
3413 value_scratch,
3414 gradient_scratch,
3415 )
3416 }
3417
3418 pub fn evaluate_expression_value_gradient_with_scratch(
3419 &self,
3420 amplitude_values: &[Complex64],
3421 gradient_values: &[DVector<Complex64>],
3422 value_scratch: &mut [Complex64],
3423 gradient_scratch: &mut [DVector<Complex64>],
3424 ) -> (Complex64, DVector<Complex64>) {
3425 self.evaluate_expression_runtime_value_gradient_with_scratch(
3426 amplitude_values,
3427 gradient_values,
3428 value_scratch,
3429 gradient_scratch,
3430 )
3431 }
3432
3433 pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
3434 self.evaluate_expression_runtime_value(amplitude_values)
3435 }
3436
3437 pub fn evaluate_expression_gradient(
3438 &self,
3439 amplitude_values: &[Complex64],
3440 gradient_values: &[DVector<Complex64>],
3441 ) -> DVector<Complex64> {
3442 self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
3443 }
3444
3445 pub fn parameters(&self) -> ParameterMap {
3447 self.resources.read().parameters()
3448 }
3449
3450 pub fn n_free(&self) -> usize {
3452 self.resources.read().n_free_parameters()
3453 }
3454
3455 pub fn n_fixed(&self) -> usize {
3457 self.resources.read().n_fixed_parameters()
3458 }
3459
3460 pub fn n_parameters(&self) -> usize {
3462 self.resources.read().n_parameters()
3463 }
3464
3465 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
3466 self.resources.read().fix_parameter(name, value)
3467 }
3468
3469 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
3470 self.resources.read().free_parameter(name)
3471 }
3472
3473 pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
3474 self.resources.write().rename_parameter(old, new)
3475 }
3476
3477 pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
3478 self.resources.write().rename_parameters(mapping)
3479 }
3480
3481 pub fn activate<T: AsRef<str>>(&self, name: T) {
3483 self.resources.write().activate(name);
3484 self.refresh_runtime_specializations();
3485 }
3486 pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
3488 self.resources.write().activate_strict(name)?;
3489 self.refresh_runtime_specializations();
3490 Ok(())
3491 }
3492
3493 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
3495 self.resources.write().activate_many(names);
3496 self.refresh_runtime_specializations();
3497 }
3498 pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
3500 self.resources.write().activate_many_strict(names)?;
3501 self.refresh_runtime_specializations();
3502 Ok(())
3503 }
3504
3505 pub fn activate_all(&self) {
3507 self.resources.write().activate_all();
3508 self.refresh_runtime_specializations();
3509 }
3510
3511 pub fn deactivate<T: AsRef<str>>(&self, name: T) {
3513 self.resources.write().deactivate(name);
3514 self.refresh_runtime_specializations();
3515 }
3516
3517 pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
3519 self.resources.write().deactivate_strict(name)?;
3520 self.refresh_runtime_specializations();
3521 Ok(())
3522 }
3523
3524 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
3526 self.resources.write().deactivate_many(names);
3527 self.refresh_runtime_specializations();
3528 }
3529 pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
3531 self.resources.write().deactivate_many_strict(names)?;
3532 self.refresh_runtime_specializations();
3533 Ok(())
3534 }
3535
3536 pub fn deactivate_all(&self) {
3538 self.resources.write().deactivate_all();
3539 self.refresh_runtime_specializations();
3540 }
3541
3542 pub fn isolate<T: AsRef<str>>(&self, name: T) {
3544 self.resources.write().isolate(name);
3545 self.refresh_runtime_specializations();
3546 }
3547
3548 pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
3550 self.resources.write().isolate_strict(name)?;
3551 self.refresh_runtime_specializations();
3552 Ok(())
3553 }
3554
3555 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
3557 self.resources.write().isolate_many(names);
3558 self.refresh_runtime_specializations();
3559 }
3560
3561 pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
3563 self.resources.write().isolate_many_strict(names)?;
3564 self.refresh_runtime_specializations();
3565 Ok(())
3566 }
3567
3568 pub fn active_mask(&self) -> Vec<bool> {
3570 self.resources.read().active.clone()
3571 }
3572
3573 pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
3575 let resources = {
3576 let mut resources = self.resources.write();
3577 if mask.len() != resources.active.len() {
3578 return Err(LadduError::LengthMismatch {
3579 context: "active amplitude mask".to_string(),
3580 expected: resources.active.len(),
3581 actual: mask.len(),
3582 });
3583 }
3584 resources.apply_active_mask(mask)?;
3585 resources.clone()
3586 };
3587 self.rebuild_runtime_specializations(&resources);
3588 Ok(())
3589 }
3590
3591 pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3599 let resources = self.resources.read();
3600 let parameters = resources.parameter_map.assemble(parameters)?;
3601 let amplitude_len = self.amplitude_use_sites.len();
3602 let active_indices = resources.active_indices().to_vec();
3603 let slot_count = self.expression_value_slot_count();
3604 let program_snapshot = self.expression_value_program_snapshot();
3605 #[cfg(feature = "rayon")]
3606 {
3607 Ok(resources
3608 .caches
3609 .par_iter()
3610 .map_init(
3611 || {
3612 (
3613 vec![Complex64::ZERO; amplitude_len],
3614 vec![Complex64::ZERO; slot_count],
3615 )
3616 },
3617 |(amplitude_values, expr_slots), cache| {
3618 self.fill_amplitude_values(
3619 amplitude_values,
3620 &active_indices,
3621 ¶meters,
3622 cache,
3623 );
3624 self.evaluate_expression_value_with_program_snapshot(
3625 &program_snapshot,
3626 amplitude_values,
3627 expr_slots,
3628 )
3629 },
3630 )
3631 .collect())
3632 }
3633 #[cfg(not(feature = "rayon"))]
3634 {
3635 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3636 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3637 Ok(resources
3638 .caches
3639 .iter()
3640 .map(|cache| {
3641 self.fill_amplitude_values(
3642 &mut amplitude_values,
3643 &active_indices,
3644 ¶meters,
3645 cache,
3646 );
3647 self.evaluate_expression_value_with_program_snapshot(
3648 &program_snapshot,
3649 &litude_values,
3650 &mut expr_slots,
3651 )
3652 })
3653 .collect())
3654 }
3655 }
3656
3657 pub fn evaluate_local_with_active_mask(
3659 &self,
3660 parameters: &[f64],
3661 active_mask: &[bool],
3662 ) -> LadduResult<Vec<Complex64>> {
3663 let resources = self.resources.read();
3664 if active_mask.len() != resources.active.len() {
3665 return Err(LadduError::LengthMismatch {
3666 context: "active amplitude mask".to_string(),
3667 expected: resources.active.len(),
3668 actual: active_mask.len(),
3669 });
3670 }
3671 let parameters = resources.parameter_map.assemble(parameters)?;
3672 let amplitude_len = self.amplitude_use_sites.len();
3673 let active_indices = active_mask
3674 .iter()
3675 .enumerate()
3676 .filter_map(|(index, &active)| if active { Some(index) } else { None })
3677 .collect::<Vec<_>>();
3678 let program_snapshot =
3679 self.expression_value_program_snapshot_for_active_mask(active_mask)?;
3680 let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
3681 #[cfg(feature = "rayon")]
3682 {
3683 Ok(resources
3684 .caches
3685 .par_iter()
3686 .map_init(
3687 || {
3688 (
3689 vec![Complex64::ZERO; amplitude_len],
3690 vec![Complex64::ZERO; slot_count],
3691 )
3692 },
3693 |(amplitude_values, expr_slots), cache| {
3694 self.fill_amplitude_values(
3695 amplitude_values,
3696 &active_indices,
3697 ¶meters,
3698 cache,
3699 );
3700 self.evaluate_expression_value_with_program_snapshot(
3701 &program_snapshot,
3702 amplitude_values,
3703 expr_slots,
3704 )
3705 },
3706 )
3707 .collect())
3708 }
3709 #[cfg(not(feature = "rayon"))]
3710 {
3711 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3712 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3713 Ok(resources
3714 .caches
3715 .iter()
3716 .map(|cache| {
3717 self.fill_amplitude_values(
3718 &mut amplitude_values,
3719 &active_indices,
3720 ¶meters,
3721 cache,
3722 );
3723 self.evaluate_expression_value_with_program_snapshot(
3724 &program_snapshot,
3725 &litude_values,
3726 &mut expr_slots,
3727 )
3728 })
3729 .collect())
3730 }
3731 }
3732
3733 #[cfg(feature = "execution-context-prototype")]
3735 pub fn evaluate_local_with_ctx(
3736 &self,
3737 parameters: &[f64],
3738 execution_context: &ExecutionContext,
3739 ) -> Vec<Complex64> {
3740 let resources = self.resources.read();
3741 let parameters = resources
3742 .parameter_map
3743 .assemble(parameters)
3744 .expect("parameter slice must match evaluator resources");
3745 let amplitude_len = self.amplitude_use_sites.len();
3746 let active_indices = resources.active_indices().to_vec();
3747 let slot_count = self.expression_value_slot_count();
3748 let program_snapshot = self.expression_value_program_snapshot();
3749 #[cfg(feature = "rayon")]
3750 {
3751 if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
3752 return execution_context.install(|| {
3753 resources
3754 .caches
3755 .par_iter()
3756 .map_init(
3757 || {
3758 (
3759 vec![Complex64::ZERO; amplitude_len],
3760 vec![Complex64::ZERO; slot_count],
3761 )
3762 },
3763 |(amplitude_values, expr_slots), cache| {
3764 self.fill_amplitude_values(
3765 amplitude_values,
3766 &active_indices,
3767 ¶meters,
3768 cache,
3769 );
3770 self.evaluate_expression_value_with_program_snapshot(
3771 &program_snapshot,
3772 amplitude_values,
3773 expr_slots,
3774 )
3775 },
3776 )
3777 .collect()
3778 });
3779 }
3780 }
3781 execution_context.with_scratch(|scratch| {
3782 let (amplitude_values, expr_slots) =
3783 scratch.reserve_value_workspaces(amplitude_len, slot_count);
3784 resources
3785 .caches
3786 .iter()
3787 .map(|cache| {
3788 self.fill_amplitude_values(
3789 amplitude_values,
3790 &active_indices,
3791 ¶meters,
3792 cache,
3793 );
3794 self.evaluate_expression_value_with_program_snapshot(
3795 &program_snapshot,
3796 amplitude_values,
3797 expr_slots,
3798 )
3799 })
3800 .collect()
3801 })
3802 }
3803
3804 #[cfg(feature = "mpi")]
3812 fn evaluate_mpi(
3813 &self,
3814 parameters: &[f64],
3815 world: &SimpleCommunicator,
3816 ) -> LadduResult<Vec<Complex64>> {
3817 let local_evaluation = self.evaluate_local(parameters)?;
3818 let n_events = self.dataset.n_events();
3819 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3820 let (counts, displs) = world.get_counts_displs(n_events);
3821 {
3822 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3825 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3826 }
3827 Ok(buffer)
3828 }
3829
3830 #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
3831 fn evaluate_mpi_with_ctx(
3832 &self,
3833 parameters: &[f64],
3834 world: &SimpleCommunicator,
3835 execution_context: &ExecutionContext,
3836 ) -> Vec<Complex64> {
3837 let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
3838 let n_events = self.dataset.n_events();
3839 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
3840 let (counts, displs) = world.get_counts_displs(n_events);
3841 {
3842 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
3845 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
3846 }
3847 buffer
3848 }
3849
3850 pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
3853 #[cfg(feature = "mpi")]
3854 {
3855 if let Some(world) = crate::mpi::get_world() {
3856 return self.evaluate_mpi(parameters, &world);
3857 }
3858 }
3859 self.evaluate_local(parameters)
3860 }
3861
3862 #[cfg(feature = "execution-context-prototype")]
3868 pub fn evaluate_with_ctx(
3869 &self,
3870 parameters: &[f64],
3871 execution_context: &ExecutionContext,
3872 ) -> Vec<Complex64> {
3873 #[cfg(feature = "mpi")]
3874 {
3875 if let Some(world) = crate::mpi::get_world() {
3876 return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
3877 }
3878 }
3879 self.evaluate_local_with_ctx(parameters, execution_context)
3880 }
3881
3882 pub fn evaluate_batch_local(
3885 &self,
3886 parameters: &[f64],
3887 indices: &[usize],
3888 ) -> LadduResult<Vec<Complex64>> {
3889 let resources = self.resources.read();
3890 let parameters = resources.parameter_map.assemble(parameters)?;
3891 let amplitude_len = self.amplitude_use_sites.len();
3892 let active_indices = resources.active_indices().to_vec();
3893 let slot_count = self.expression_value_slot_count();
3894 let program_snapshot = self.expression_value_program_snapshot();
3895 #[cfg(feature = "rayon")]
3896 {
3897 Ok(indices
3898 .par_iter()
3899 .map_init(
3900 || {
3901 (
3902 vec![Complex64::ZERO; amplitude_len],
3903 vec![Complex64::ZERO; slot_count],
3904 )
3905 },
3906 |(amplitude_values, expr_slots), &idx| {
3907 let cache = &resources.caches[idx];
3908 self.fill_amplitude_values(
3909 amplitude_values,
3910 &active_indices,
3911 ¶meters,
3912 cache,
3913 );
3914 self.evaluate_expression_value_with_program_snapshot(
3915 &program_snapshot,
3916 amplitude_values,
3917 expr_slots,
3918 )
3919 },
3920 )
3921 .collect())
3922 }
3923 #[cfg(not(feature = "rayon"))]
3924 {
3925 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3926 let mut expr_slots = vec![Complex64::ZERO; slot_count];
3927 Ok(indices
3928 .iter()
3929 .map(|&idx| {
3930 let cache = &resources.caches[idx];
3931 self.fill_amplitude_values(
3932 &mut amplitude_values,
3933 &active_indices,
3934 ¶meters,
3935 cache,
3936 );
3937 self.evaluate_expression_value_with_program_snapshot(
3938 &program_snapshot,
3939 &litude_values,
3940 &mut expr_slots,
3941 )
3942 })
3943 .collect())
3944 }
3945 }
3946
3947 #[cfg(feature = "mpi")]
3950 fn evaluate_batch_mpi(
3951 &self,
3952 parameters: &[f64],
3953 indices: &[usize],
3954 world: &SimpleCommunicator,
3955 ) -> LadduResult<Vec<Complex64>> {
3956 let total = self.dataset.n_events();
3957 let locals = world.locals_from_globals(indices, total);
3958 let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
3959 Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
3960 }
3961
3962 pub fn evaluate_batch(
3965 &self,
3966 parameters: &[f64],
3967 indices: &[usize],
3968 ) -> LadduResult<Vec<Complex64>> {
3969 #[cfg(feature = "mpi")]
3970 {
3971 if let Some(world) = crate::mpi::get_world() {
3972 return self.evaluate_batch_mpi(parameters, indices, &world);
3973 }
3974 }
3975 self.evaluate_batch_local(parameters, indices)
3976 }
3977
3978 pub fn evaluate_gradient_local(
3986 &self,
3987 parameters: &[f64],
3988 ) -> LadduResult<Vec<DVector<Complex64>>> {
3989 let resources = self.resources.read();
3990 let parameters = resources.parameter_map.assemble(parameters)?;
3991 let amplitude_len = self.amplitude_use_sites.len();
3992 let grad_dim = parameters.len();
3993 let active_indices = resources.active_indices().to_vec();
3994 let lowered_runtime = self.lowered_runtime();
3995 let gradient_program = lowered_runtime.gradient_program();
3996 let slot_count = self.expression_gradient_slot_count();
3997 #[cfg(feature = "rayon")]
3998 {
3999 Ok(resources
4000 .caches
4001 .par_iter()
4002 .map_init(
4003 || {
4004 (
4005 vec![Complex64::ZERO; amplitude_len],
4006 vec![DVector::zeros(grad_dim); amplitude_len],
4007 vec![Complex64::ZERO; slot_count],
4008 vec![Complex64::ZERO; slot_count * grad_dim],
4009 )
4010 },
4011 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4012 self.fill_amplitude_values_and_gradients(
4013 amplitude_values,
4014 gradient_values,
4015 &active_indices,
4016 &resources.active,
4017 ¶meters,
4018 cache,
4019 );
4020 gradient_program.evaluate_gradient_into_flat(
4021 amplitude_values,
4022 gradient_values,
4023 value_slots,
4024 gradient_slots,
4025 grad_dim,
4026 )
4027 },
4028 )
4029 .collect())
4030 }
4031 #[cfg(not(feature = "rayon"))]
4032 {
4033 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4034 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4035 let mut value_slots = vec![Complex64::ZERO; slot_count];
4036 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4037 Ok(resources
4038 .caches
4039 .iter()
4040 .map(|cache| {
4041 self.fill_amplitude_values_and_gradients(
4042 &mut amplitude_values,
4043 &mut gradient_values,
4044 &active_indices,
4045 &resources.active,
4046 ¶meters,
4047 cache,
4048 );
4049 gradient_program.evaluate_gradient_into_flat(
4050 &litude_values,
4051 &gradient_values,
4052 &mut value_slots,
4053 &mut gradient_slots,
4054 grad_dim,
4055 )
4056 })
4057 .collect())
4058 }
4059 }
4060
4061 #[cfg(feature = "execution-context-prototype")]
4063 pub fn evaluate_gradient_local_with_ctx(
4064 &self,
4065 parameters: &[f64],
4066 execution_context: &ExecutionContext,
4067 ) -> Vec<DVector<Complex64>> {
4068 let resources = self.resources.read();
4069 let parameters = resources
4070 .parameter_map
4071 .assemble(parameters)
4072 .expect("parameter slice must match evaluator resources");
4073 let amplitude_len = self.amplitude_use_sites.len();
4074 let grad_dim = parameters.len();
4075 let active_indices = resources.active_indices().to_vec();
4076 let slot_count = self.expression_slot_count();
4077 #[cfg(feature = "rayon")]
4078 {
4079 if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
4080 return execution_context.install(|| {
4081 resources
4082 .caches
4083 .par_iter()
4084 .map_init(
4085 || {
4086 (
4087 vec![Complex64::ZERO; amplitude_len],
4088 vec![DVector::zeros(grad_dim); amplitude_len],
4089 vec![Complex64::ZERO; slot_count],
4090 vec![DVector::zeros(grad_dim); slot_count],
4091 )
4092 },
4093 |(amplitude_values, gradient_values, value_slots, gradient_slots),
4094 cache| {
4095 self.evaluate_cache_gradient_with_scratch(
4096 amplitude_values,
4097 gradient_values,
4098 value_slots,
4099 gradient_slots,
4100 &active_indices,
4101 &resources.active,
4102 ¶meters,
4103 cache,
4104 )
4105 },
4106 )
4107 .collect()
4108 });
4109 }
4110 }
4111 execution_context.with_scratch(|scratch| {
4112 let (amplitude_values, value_slots, gradient_values, gradient_slots) =
4113 scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
4114 resources
4115 .caches
4116 .iter()
4117 .map(|cache| {
4118 self.evaluate_cache_gradient_with_scratch(
4119 amplitude_values,
4120 gradient_values,
4121 value_slots,
4122 gradient_slots,
4123 &active_indices,
4124 &resources.active,
4125 ¶meters,
4126 cache,
4127 )
4128 })
4129 .collect()
4130 })
4131 }
4132
4133 #[cfg(feature = "mpi")]
4141 fn evaluate_gradient_mpi(
4142 &self,
4143 parameters: &[f64],
4144 world: &SimpleCommunicator,
4145 ) -> LadduResult<Vec<DVector<Complex64>>> {
4146 let local_evaluation = self.evaluate_gradient_local(parameters)?;
4147 let n_events = self.dataset.n_events();
4148 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
4149 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
4150 {
4151 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4154 world.all_gather_varcount_into(
4155 &local_evaluation
4156 .iter()
4157 .flat_map(|v| v.data.as_vec())
4158 .copied()
4159 .collect::<Vec<_>>(),
4160 &mut partitioned_buffer,
4161 );
4162 }
4163 Ok(buffer
4164 .chunks(parameters.len())
4165 .map(DVector::from_row_slice)
4166 .collect())
4167 }
4168
4169 #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
4170 fn evaluate_gradient_mpi_with_ctx(
4171 &self,
4172 parameters: &[f64],
4173 world: &SimpleCommunicator,
4174 execution_context: &ExecutionContext,
4175 ) -> Vec<DVector<Complex64>> {
4176 let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
4177 let n_events = self.dataset.n_events();
4178 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
4179 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
4180 {
4181 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
4184 world.all_gather_varcount_into(
4185 &local_evaluation
4186 .iter()
4187 .flat_map(|v| v.data.as_vec())
4188 .copied()
4189 .collect::<Vec<_>>(),
4190 &mut partitioned_buffer,
4191 );
4192 }
4193 buffer
4194 .chunks(parameters.len())
4195 .map(DVector::from_row_slice)
4196 .collect()
4197 }
4198
4199 pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
4202 #[cfg(feature = "mpi")]
4203 {
4204 if let Some(world) = crate::mpi::get_world() {
4205 return self.evaluate_gradient_mpi(parameters, &world);
4206 }
4207 }
4208 self.evaluate_gradient_local(parameters)
4209 }
4210
4211 #[cfg(feature = "execution-context-prototype")]
4217 pub fn evaluate_gradient_with_ctx(
4218 &self,
4219 parameters: &[f64],
4220 execution_context: &ExecutionContext,
4221 ) -> Vec<DVector<Complex64>> {
4222 #[cfg(feature = "mpi")]
4223 {
4224 if let Some(world) = crate::mpi::get_world() {
4225 return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
4226 }
4227 }
4228 self.evaluate_gradient_local_with_ctx(parameters, execution_context)
4229 }
4230
4231 pub fn evaluate_gradient_batch_local(
4234 &self,
4235 parameters: &[f64],
4236 indices: &[usize],
4237 ) -> LadduResult<Vec<DVector<Complex64>>> {
4238 let resources = self.resources.read();
4239 let parameters = resources.parameter_map.assemble(parameters)?;
4240 let amplitude_len = self.amplitude_use_sites.len();
4241 let grad_dim = parameters.len();
4242 let active_indices = resources.active_indices().to_vec();
4243 let lowered_runtime = self.lowered_runtime();
4244 let gradient_program = lowered_runtime.gradient_program();
4245 let slot_count = self.expression_gradient_slot_count();
4246 #[cfg(feature = "rayon")]
4247 {
4248 Ok(indices
4249 .par_iter()
4250 .map_init(
4251 || {
4252 (
4253 vec![Complex64::ZERO; amplitude_len],
4254 vec![DVector::zeros(grad_dim); amplitude_len],
4255 vec![Complex64::ZERO; slot_count],
4256 vec![Complex64::ZERO; slot_count * grad_dim],
4257 )
4258 },
4259 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
4260 let cache = &resources.caches[idx];
4261 self.fill_amplitude_values_and_gradients(
4262 amplitude_values,
4263 gradient_values,
4264 &active_indices,
4265 &resources.active,
4266 ¶meters,
4267 cache,
4268 );
4269 gradient_program.evaluate_gradient_into_flat(
4270 amplitude_values,
4271 gradient_values,
4272 value_slots,
4273 gradient_slots,
4274 grad_dim,
4275 )
4276 },
4277 )
4278 .collect())
4279 }
4280 #[cfg(not(feature = "rayon"))]
4281 {
4282 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4283 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4284 let mut value_slots = vec![Complex64::ZERO; slot_count];
4285 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4286 Ok(indices
4287 .iter()
4288 .map(|&idx| {
4289 let cache = &resources.caches[idx];
4290 self.fill_amplitude_values_and_gradients(
4291 &mut amplitude_values,
4292 &mut gradient_values,
4293 &active_indices,
4294 &resources.active,
4295 ¶meters,
4296 cache,
4297 );
4298 gradient_program.evaluate_gradient_into_flat(
4299 &litude_values,
4300 &gradient_values,
4301 &mut value_slots,
4302 &mut gradient_slots,
4303 grad_dim,
4304 )
4305 })
4306 .collect())
4307 }
4308 }
4309
4310 #[cfg(feature = "mpi")]
4313 fn evaluate_gradient_batch_mpi(
4314 &self,
4315 parameters: &[f64],
4316 indices: &[usize],
4317 world: &SimpleCommunicator,
4318 ) -> LadduResult<Vec<DVector<Complex64>>> {
4319 let total = self.dataset.n_events();
4320 let locals = world.locals_from_globals(indices, total);
4321 let flattened_local_evaluation = self
4322 .evaluate_gradient_batch_local(parameters, &locals)?
4323 .iter()
4324 .flat_map(|g| g.data.as_vec().to_vec())
4325 .collect::<Vec<Complex64>>();
4326 Ok(world
4327 .all_gather_batched_partitioned(
4328 &flattened_local_evaluation,
4329 indices,
4330 total,
4331 Some(parameters.len()),
4332 )
4333 .chunks(parameters.len())
4334 .map(DVector::from_row_slice)
4335 .collect())
4336 }
4337
4338 pub fn evaluate_gradient_batch(
4342 &self,
4343 parameters: &[f64],
4344 indices: &[usize],
4345 ) -> LadduResult<Vec<DVector<Complex64>>> {
4346 #[cfg(feature = "mpi")]
4347 {
4348 if let Some(world) = crate::mpi::get_world() {
4349 return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
4350 }
4351 }
4352 self.evaluate_gradient_batch_local(parameters, indices)
4353 }
4354
4355 pub fn evaluate_with_gradient_local(
4357 &self,
4358 parameters: &[f64],
4359 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
4360 let resources = self.resources.read();
4361 let parameters = resources.parameter_map.assemble(parameters)?;
4362 let amplitude_len = self.amplitude_use_sites.len();
4363 let grad_dim = parameters.len();
4364 let active_indices = resources.active_indices().to_vec();
4365 let lowered_runtime = self.lowered_runtime();
4366 let value_gradient_program = lowered_runtime.value_gradient_program();
4367 let slot_count = self.expression_value_gradient_slot_count();
4368 #[cfg(feature = "rayon")]
4369 {
4370 Ok(resources
4371 .caches
4372 .par_iter()
4373 .map_init(
4374 || {
4375 (
4376 vec![Complex64::ZERO; amplitude_len],
4377 vec![DVector::zeros(grad_dim); amplitude_len],
4378 vec![Complex64::ZERO; slot_count],
4379 vec![Complex64::ZERO; slot_count * grad_dim],
4380 )
4381 },
4382 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4383 self.fill_amplitude_values_and_gradients(
4384 amplitude_values,
4385 gradient_values,
4386 &active_indices,
4387 &resources.active,
4388 ¶meters,
4389 cache,
4390 );
4391 value_gradient_program.evaluate_value_gradient_into_flat(
4392 amplitude_values,
4393 gradient_values,
4394 value_slots,
4395 gradient_slots,
4396 grad_dim,
4397 )
4398 },
4399 )
4400 .collect())
4401 }
4402 #[cfg(not(feature = "rayon"))]
4403 {
4404 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4405 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4406 let mut value_slots = vec![Complex64::ZERO; slot_count];
4407 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4408 Ok(resources
4409 .caches
4410 .iter()
4411 .map(|cache| {
4412 self.fill_amplitude_values_and_gradients(
4413 &mut amplitude_values,
4414 &mut gradient_values,
4415 &active_indices,
4416 &resources.active,
4417 ¶meters,
4418 cache,
4419 );
4420 value_gradient_program.evaluate_value_gradient_into_flat(
4421 &litude_values,
4422 &gradient_values,
4423 &mut value_slots,
4424 &mut gradient_slots,
4425 grad_dim,
4426 )
4427 })
4428 .collect())
4429 }
4430 }
4431
4432 pub fn evaluate_with_gradient_local_with_active_mask(
4434 &self,
4435 parameters: &[f64],
4436 active_mask: &[bool],
4437 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
4438 let resources = self.resources.read();
4439 if active_mask.len() != resources.active.len() {
4440 return Err(LadduError::LengthMismatch {
4441 context: "active amplitude mask".to_string(),
4442 expected: resources.active.len(),
4443 actual: active_mask.len(),
4444 });
4445 }
4446 let parameters = resources.parameter_map.assemble(parameters)?;
4447 let amplitude_len = self.amplitude_use_sites.len();
4448 let grad_dim = parameters.len();
4449 let active_indices = active_mask
4450 .iter()
4451 .enumerate()
4452 .filter_map(|(index, &active)| if active { Some(index) } else { None })
4453 .collect::<Vec<_>>();
4454 let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
4455 let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
4456 #[cfg(feature = "rayon")]
4457 {
4458 Ok(resources
4459 .caches
4460 .par_iter()
4461 .map_init(
4462 || {
4463 (
4464 vec![Complex64::ZERO; amplitude_len],
4465 vec![DVector::zeros(grad_dim); amplitude_len],
4466 vec![Complex64::ZERO; slot_count],
4467 vec![Complex64::ZERO; slot_count * grad_dim],
4468 )
4469 },
4470 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4471 self.fill_amplitude_values_and_gradients(
4472 amplitude_values,
4473 gradient_values,
4474 &active_indices,
4475 active_mask,
4476 ¶meters,
4477 cache,
4478 );
4479 lowered_runtime
4480 .value_gradient_program()
4481 .evaluate_value_gradient_into_flat(
4482 amplitude_values,
4483 gradient_values,
4484 value_slots,
4485 gradient_slots,
4486 grad_dim,
4487 )
4488 },
4489 )
4490 .collect())
4491 }
4492 #[cfg(not(feature = "rayon"))]
4493 {
4494 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4495 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4496 let mut value_slots = vec![Complex64::ZERO; slot_count];
4497 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4498 Ok(resources
4499 .caches
4500 .iter()
4501 .map(|cache| {
4502 self.fill_amplitude_values_and_gradients(
4503 &mut amplitude_values,
4504 &mut gradient_values,
4505 &active_indices,
4506 active_mask,
4507 ¶meters,
4508 cache,
4509 );
4510 lowered_runtime
4511 .value_gradient_program()
4512 .evaluate_value_gradient_into_flat(
4513 &litude_values,
4514 &gradient_values,
4515 &mut value_slots,
4516 &mut gradient_slots,
4517 grad_dim,
4518 )
4519 })
4520 .collect())
4521 }
4522 }
4523
4524 pub fn evaluate_with_gradient_batch_local(
4526 &self,
4527 parameters: &[f64],
4528 indices: &[usize],
4529 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
4530 let resources = self.resources.read();
4531 let parameters = resources.parameter_map.assemble(parameters)?;
4532 let amplitude_len = self.amplitude_use_sites.len();
4533 let grad_dim = parameters.len();
4534 let active_indices = resources.active_indices().to_vec();
4535 let lowered_runtime = self.lowered_runtime();
4536 let value_gradient_program = lowered_runtime.value_gradient_program();
4537 let slot_count = self.expression_value_gradient_slot_count();
4538 #[cfg(feature = "rayon")]
4539 {
4540 Ok(indices
4541 .par_iter()
4542 .map_init(
4543 || {
4544 (
4545 vec![Complex64::ZERO; amplitude_len],
4546 vec![DVector::zeros(grad_dim); amplitude_len],
4547 vec![Complex64::ZERO; slot_count],
4548 vec![Complex64::ZERO; slot_count * grad_dim],
4549 )
4550 },
4551 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
4552 let cache = &resources.caches[idx];
4553 self.fill_amplitude_values_and_gradients(
4554 amplitude_values,
4555 gradient_values,
4556 &active_indices,
4557 &resources.active,
4558 ¶meters,
4559 cache,
4560 );
4561 value_gradient_program.evaluate_value_gradient_into_flat(
4562 amplitude_values,
4563 gradient_values,
4564 value_slots,
4565 gradient_slots,
4566 grad_dim,
4567 )
4568 },
4569 )
4570 .collect())
4571 }
4572 #[cfg(not(feature = "rayon"))]
4573 {
4574 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4575 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4576 let mut value_slots = vec![Complex64::ZERO; slot_count];
4577 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4578 Ok(indices
4579 .iter()
4580 .map(|&idx| {
4581 let cache = &resources.caches[idx];
4582 self.fill_amplitude_values_and_gradients(
4583 &mut amplitude_values,
4584 &mut gradient_values,
4585 &active_indices,
4586 &resources.active,
4587 ¶meters,
4588 cache,
4589 );
4590 value_gradient_program.evaluate_value_gradient_into_flat(
4591 &litude_values,
4592 &gradient_values,
4593 &mut value_slots,
4594 &mut gradient_slots,
4595 grad_dim,
4596 )
4597 })
4598 .collect())
4599 }
4600 }
4601}
4602
4603#[cfg(test)]
4604mod tests {
4605 use approx::assert_relative_eq;
4606 #[cfg(feature = "mpi")]
4607 use mpi_test::mpi_test;
4608 use serde::{Deserialize, Serialize};
4609
4610 use super::*;
4611 use crate::{
4612 amplitude::{AmplitudeID, Tags, TestAmplitude},
4613 data::{test_dataset, test_event, DatasetMetadata, Event, EventData},
4614 parameter,
4615 parameters::Parameter,
4616 resources::{Cache, ParameterID, Parameters, Resources, ScalarID},
4617 vectors::Vec4,
4618 };
4619
4620 #[derive(Clone, Serialize, Deserialize)]
4621 pub struct ComplexScalar {
4622 name: String,
4623 re: Parameter,
4624 pid_re: ParameterID,
4625 im: Parameter,
4626 pid_im: ParameterID,
4627 }
4628
4629 impl ComplexScalar {
4630 #[allow(clippy::new_ret_no_self)]
4631 pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
4632 Self {
4633 name: name.to_string(),
4634 re,
4635 pid_re: Default::default(),
4636 im,
4637 pid_im: Default::default(),
4638 }
4639 .into_expression()
4640 }
4641 }
4642
4643 #[typetag::serde]
4644 impl Amplitude for ComplexScalar {
4645 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4646 self.pid_re = resources.register_parameter(&self.re)?;
4647 self.pid_im = resources.register_parameter(&self.im)?;
4648 resources.register_amplitude(&self.name)
4649 }
4650
4651 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4652 Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
4653 }
4654
4655 fn compute_gradient(
4656 &self,
4657 parameters: &Parameters,
4658 _cache: &Cache,
4659 gradient: &mut DVector<Complex64>,
4660 ) {
4661 if let Some(ind) = parameters.free_index(self.pid_re) {
4662 gradient[ind] = Complex64::ONE;
4663 }
4664 if let Some(ind) = parameters.free_index(self.pid_im) {
4665 gradient[ind] = Complex64::I;
4666 }
4667 }
4668 }
4669
4670 #[derive(Clone, Serialize, Deserialize)]
4671 pub struct ParameterOnlyScalar {
4672 name: String,
4673 value: Parameter,
4674 pid: ParameterID,
4675 }
4676
4677 impl ParameterOnlyScalar {
4678 #[allow(clippy::new_ret_no_self)]
4679 pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
4680 Self {
4681 name: name.to_string(),
4682 value,
4683 pid: Default::default(),
4684 }
4685 .into_expression()
4686 }
4687 }
4688
4689 #[typetag::serde]
4690 impl Amplitude for ParameterOnlyScalar {
4691 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4692 self.pid = resources.register_parameter(&self.value)?;
4693 resources.register_amplitude(&self.name)
4694 }
4695
4696 fn dependence_hint(&self) -> ExpressionDependence {
4697 ExpressionDependence::ParameterOnly
4698 }
4699
4700 fn real_valued_hint(&self) -> bool {
4701 true
4702 }
4703
4704 fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
4705 Complex64::new(parameters.get(self.pid), 0.0)
4706 }
4707 }
4708
4709 #[derive(Clone, Serialize, Deserialize)]
4710 pub struct CacheOnlyScalar {
4711 name: String,
4712 beam_energy: ScalarID,
4713 }
4714
4715 impl CacheOnlyScalar {
4716 #[allow(clippy::new_ret_no_self)]
4717 pub fn new(name: &str) -> LadduResult<Expression> {
4718 Self {
4719 name: name.to_string(),
4720 beam_energy: Default::default(),
4721 }
4722 .into_expression()
4723 }
4724 }
4725
4726 #[typetag::serde]
4727 impl Amplitude for CacheOnlyScalar {
4728 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
4729 self.beam_energy =
4730 resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
4731 resources.register_amplitude(&self.name)
4732 }
4733
4734 fn dependence_hint(&self) -> ExpressionDependence {
4735 ExpressionDependence::CacheOnly
4736 }
4737
4738 fn real_valued_hint(&self) -> bool {
4739 true
4740 }
4741
4742 fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
4743 cache.store_scalar(self.beam_energy, event.p4_at(0).e());
4744 }
4745
4746 fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
4747 Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
4748 }
4749 }
4750
    /// Shape of the expression built by `make_deterministic_fixture`.
    #[derive(Clone, Copy)]
    enum DeterministicFixtureKind {
        /// `p1 * c1 + p2 * c2`: every term is a parameter-only times cache-only product.
        Separable,
        /// `p * c + m`: one separable product plus a `TestAmplitude` term.
        Partial,
        /// `m1 * m2`: a product of two `TestAmplitude`s with no separable term.
        NonSeparable,
    }

    /// An expression, the dataset it is evaluated on, and the parameter values to use.
    struct DeterministicFixture {
        /// Expression under test.
        expression: Expression,
        /// Shared dataset from `deterministic_fixture_dataset`.
        dataset: Arc<Dataset>,
        /// Parameter values passed to every evaluation of `expression`.
        parameters: Vec<f64>,
    }

    /// Absolute tolerance (`epsilon`) for `assert_relative_eq!` comparisons on fixtures.
    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
    /// Relative tolerance (`max_relative`) for `assert_relative_eq!` comparisons on fixtures.
    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
4766
4767 fn deterministic_fixture_dataset() -> Arc<Dataset> {
4768 let metadata = Arc::new(DatasetMetadata::default());
4769 let events = vec![
4770 Arc::new(EventData {
4771 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
4772 aux: vec![],
4773 weight: 0.5,
4774 }),
4775 Arc::new(EventData {
4776 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
4777 aux: vec![],
4778 weight: -1.25,
4779 }),
4780 Arc::new(EventData {
4781 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
4782 aux: vec![],
4783 weight: 2.0,
4784 }),
4785 Arc::new(EventData {
4786 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
4787 aux: vec![],
4788 weight: 0.75,
4789 }),
4790 ];
4791 Arc::new(Dataset::new_with_metadata(events, metadata))
4792 }
4793
4794 fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
4795 let dataset = deterministic_fixture_dataset();
4796 match kind {
4797 DeterministicFixtureKind::Separable => {
4798 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
4799 .expect("separable p1 should build");
4800 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
4801 .expect("separable p2 should build");
4802 let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
4803 let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
4804 DeterministicFixture {
4805 expression: (&p1 * &c1) + &(&p2 * &c2),
4806 dataset,
4807 parameters: vec![0.4, -0.3],
4808 }
4809 }
4810 DeterministicFixtureKind::Partial => {
4811 let p =
4812 ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
4813 let c = CacheOnlyScalar::new("c").expect("partial c should build");
4814 let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
4815 .expect("partial m should build");
4816 DeterministicFixture {
4817 expression: (&p * &c) + &m,
4818 dataset,
4819 parameters: vec![0.55, 0.2, -0.15],
4820 }
4821 }
4822 DeterministicFixtureKind::NonSeparable => {
4823 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
4824 .expect("non-separable m1 should build");
4825 let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
4826 .expect("non-separable m2 should build");
4827 DeterministicFixture {
4828 expression: &m1 * &m2,
4829 dataset,
4830 parameters: vec![0.25, -0.4, 0.6, 0.1],
4831 }
4832 }
4833 }
4834 }
4835
4836 fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
4837 let evaluator = fixture
4838 .expression
4839 .load(&fixture.dataset)
4840 .expect("fixture evaluator should load");
4841 let expected_value = evaluator
4842 .evaluate_local(&fixture.parameters)
4843 .expect("evaluation should succeed")
4844 .iter()
4845 .zip(fixture.dataset.weights_local().iter())
4846 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
4847 let expected_gradient = evaluator
4848 .evaluate_gradient_local(&fixture.parameters)
4849 .expect("evaluation should succeed")
4850 .iter()
4851 .zip(fixture.dataset.weights_local().iter())
4852 .fold(
4853 DVector::zeros(fixture.parameters.len()),
4854 |mut accum, (gradient, event)| {
4855 accum += gradient.map(|value| value.re).scale(*event);
4856 accum
4857 },
4858 );
4859 let actual_value = evaluator
4860 .evaluate_weighted_value_sum_local(&fixture.parameters)
4861 .expect("evaluation should succeed");
4862 let actual_gradient = evaluator
4863 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4864 .expect("evaluation should succeed");
4865 assert_relative_eq!(
4866 actual_value,
4867 expected_value,
4868 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4869 max_relative = DETERMINISTIC_STRICT_REL_TOL
4870 );
4871 for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
4872 assert_relative_eq!(
4873 *actual_item,
4874 *expected_item,
4875 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4876 max_relative = DETERMINISTIC_STRICT_REL_TOL
4877 );
4878 }
4879 }
4880 fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
4881 let evaluator = fixture
4882 .expression
4883 .load(&fixture.dataset)
4884 .expect("fixture evaluator should load");
4885 let state = {
4886 let resources = evaluator.resources.read();
4887 evaluator.ensure_cached_integral_cache_state(&resources)
4888 }
4889 .expect("state should be available");
4890 assert!(
4891 !state.values.is_empty(),
4892 "fixture should exercise cached normalization terms"
4893 );
4894 assert!(
4895 !state.execution_sets.residual_amplitudes.is_empty(),
4896 "fixture should exercise residual normalization amplitudes"
4897 );
4898
4899 let (residual_value_sum, cached_value_sum) = evaluator
4900 .evaluate_weighted_value_sum_local_components(&fixture.parameters)
4901 .expect("evaluation should succeed");
4902 assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4903 assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
4904 let combined_value = evaluator
4905 .evaluate_weighted_value_sum_local(&fixture.parameters)
4906 .expect("evaluation should succeed");
4907 assert_relative_eq!(
4908 residual_value_sum + cached_value_sum,
4909 combined_value,
4910 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4911 max_relative = DETERMINISTIC_STRICT_REL_TOL
4912 );
4913
4914 let (residual_gradient_sum, cached_gradient_sum) = evaluator
4915 .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
4916 .expect("evaluation should succeed");
4917 let combined_gradient = evaluator
4918 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
4919 .expect("evaluation should succeed");
4920 assert!(residual_gradient_sum
4921 .iter()
4922 .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4923 assert!(cached_gradient_sum
4924 .iter()
4925 .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
4926 for ((residual_item, cached_item), combined_item) in residual_gradient_sum
4927 .iter()
4928 .zip(cached_gradient_sum.iter())
4929 .zip(combined_gradient.iter())
4930 {
4931 assert_relative_eq!(
4932 residual_item + cached_item,
4933 *combined_item,
4934 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
4935 max_relative = DETERMINISTIC_STRICT_REL_TOL
4936 );
4937 }
4938 }
4939
    /// Isolating `p` and `c` and then restoring the original active mask must yield the
    /// same weighted value sum as before the toggle.
    #[test]
    fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");
        // Mask as it is right after loading.
        let original_mask = evaluator.active_mask();

        let original_value = evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");

        // The isolation must really change the mask, otherwise restoring it below would
        // be a no-op and the comparison would prove nothing.
        evaluator.isolate_many(&["p", "c"]);
        assert_ne!(evaluator.active_mask(), original_mask);

        evaluator
            .set_active_mask(&original_mask)
            .expect("original fixture active mask should restore");
        assert_eq!(evaluator.active_mask(), original_mask);
        // Re-evaluate under the restored mask; it must match the pre-toggle value.
        let actual_value = evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        assert_relative_eq!(
            actual_value,
            original_value,
            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
            max_relative = DETERMINISTIC_STRICT_REL_TOL
        );
    }
4970
4971 #[test]
4972 fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
4973 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4974 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4975 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4976
4977 assert_weighted_sum_matches_eventwise_baseline(&separable);
4978 assert_weighted_sum_matches_eventwise_baseline(&partial);
4979 assert_weighted_sum_matches_eventwise_baseline(&non_separable);
4980 }
4981 #[test]
4982 fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
4983 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
4984 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
4985 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
4986
4987 let separable_evaluator = separable
4988 .expression
4989 .load(&separable.dataset)
4990 .expect("separable evaluator should load");
4991 let partial_evaluator = partial
4992 .expression
4993 .load(&partial.dataset)
4994 .expect("partial evaluator should load");
4995 let non_separable_evaluator = non_separable
4996 .expression
4997 .load(&non_separable.dataset)
4998 .expect("non-separable evaluator should load");
4999
5000 assert_eq!(
5001 separable_evaluator
5002 .expression_precomputed_cached_integrals()
5003 .expect("integrals should be computed")
5004 .len(),
5005 2
5006 );
5007 assert_eq!(
5008 partial_evaluator
5009 .expression_precomputed_cached_integrals()
5010 .expect("integrals should be computed")
5011 .len(),
5012 1
5013 );
5014 assert!(non_separable_evaluator
5015 .expression_precomputed_cached_integrals()
5016 .expect("integrals should be computed")
5017 .is_empty());
5018 }
5019 #[test]
5020 fn test_partial_fixture_combined_normalization_components_match_total() {
5021 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5022 assert_mixed_normalization_components_match_combined_path(&partial);
5023 }
5024 #[test]
5025 fn test_non_separable_fixture_normalization_components_stay_residual_only() {
5026 let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
5027 let evaluator = fixture
5028 .expression
5029 .load(&fixture.dataset)
5030 .expect("fixture evaluator should load");
5031 let resources = evaluator.resources.read();
5032 let state = evaluator
5033 .ensure_cached_integral_cache_state(&resources)
5034 .expect("state should be available");
5035 assert!(state.values.is_empty());
5036
5037 let (residual_value_sum, cached_value_sum) = evaluator
5038 .evaluate_weighted_value_sum_local_components(&fixture.parameters)
5039 .expect("evaluation should succeed");
5040 assert_relative_eq!(
5041 cached_value_sum,
5042 0.0,
5043 epsilon = DETERMINISTIC_STRICT_ABS_TOL
5044 );
5045 assert_relative_eq!(
5046 residual_value_sum,
5047 evaluator
5048 .evaluate_weighted_value_sum_local(&fixture.parameters)
5049 .expect("evaluation should succeed"),
5050 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5051 max_relative = DETERMINISTIC_STRICT_REL_TOL
5052 );
5053
5054 let (residual_gradient_sum, cached_gradient_sum) = evaluator
5055 .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
5056 .expect("evaluation should succeed");
5057 assert!(cached_gradient_sum
5058 .iter()
5059 .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
5060 let combined_gradient = evaluator
5061 .evaluate_weighted_gradient_sum_local(&fixture.parameters)
5062 .expect("evaluation should succeed");
5063 for (residual_item, combined_item) in
5064 residual_gradient_sum.iter().zip(combined_gradient.iter())
5065 {
5066 assert_relative_eq!(
5067 *residual_item,
5068 *combined_item,
5069 epsilon = DETERMINISTIC_STRICT_ABS_TOL,
5070 max_relative = DETERMINISTIC_STRICT_REL_TOL
5071 );
5072 }
5073 }
5074
5075 #[test]
5076 fn test_batch_evaluation() {
5077 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
5078 let mut event1 = test_event();
5079 event1.p4s[0].t = 10.0;
5080 let mut event2 = test_event();
5081 event2.p4s[0].t = 11.0;
5082 let mut event3 = test_event();
5083 event3.p4s[0].t = 12.0;
5084 let dataset = Arc::new(Dataset::new_with_metadata(
5085 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5086 Arc::new(DatasetMetadata::default()),
5087 ));
5088 let evaluator = expr.load(&dataset).unwrap();
5089 let result = evaluator
5090 .evaluate_batch(&[1.1, 2.2], &[0, 2])
5091 .expect("evaluation should succeed");
5092 assert_eq!(result.len(), 2);
5093 assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
5094 assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
5095 let result_grad = evaluator
5096 .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
5097 .expect("evaluation should succeed");
5098 assert_eq!(result_grad.len(), 2);
5099 assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
5100 assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
5101 assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
5102 assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
5103 }
5104
5105 #[test]
5106 fn test_load_compiles_expression_ir_once() {
5107 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5108 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5109 .norm_sqr();
5110 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5111 let evaluator = expr.load(&dataset).unwrap();
5112 assert!(evaluator.expression_slot_count() > 0);
5113 }
5114 #[test]
5115 fn test_expression_ir_value_matches_lowered_runtime() {
5116 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5117 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5118 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
5119 .conj()
5120 .norm_sqr();
5121 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5122 let evaluator = expr.load(&dataset).unwrap();
5123 let resources = evaluator.resources.read();
5124 let parameters = resources
5125 .parameter_map
5126 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
5127 .expect("parameters should assemble");
5128 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5129 evaluator.fill_amplitude_values(
5130 &mut amplitude_values,
5131 resources.active_indices(),
5132 ¶meters,
5133 &resources.caches[0],
5134 );
5135 let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
5136 let lowered_runtime = evaluator.lowered_runtime();
5137 let lowered_program = lowered_runtime.value_program();
5138 let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
5139 let lowered_value =
5140 evaluator.evaluate_expression_value_with_scratch(&litude_values, &mut ir_slots);
5141 let direct_lowered_value =
5142 lowered_program.evaluate_into(&litude_values, &mut lowered_slots);
5143 let ir_value = evaluator
5144 .expression_ir()
5145 .evaluate_into(&litude_values, &mut ir_slots);
5146 assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
5147 assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
5148 assert_relative_eq!(lowered_value.re, ir_value.re);
5149 assert_relative_eq!(lowered_value.im, ir_value.im);
5150 }
5151 #[test]
5152 fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
5153 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
5154 .unwrap()
5155 .norm_sqr();
5156 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5157 let evaluator = expr.load(&dataset).unwrap();
5158 let lowered_runtime = evaluator.lowered_runtime();
5159 assert_eq!(
5160 lowered_runtime.value_program().kind(),
5161 lowered::LoweredProgramKind::Value
5162 );
5163 assert_eq!(
5164 lowered_runtime.gradient_program().kind(),
5165 lowered::LoweredProgramKind::Gradient
5166 );
5167 assert_eq!(
5168 lowered_runtime.value_gradient_program().kind(),
5169 lowered::LoweredProgramKind::ValueGradient
5170 );
5171 }
5172 #[test]
5173 fn test_expression_ir_gradient_matches_lowered_runtime() {
5174 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5175 * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5176 .norm_sqr();
5177 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5178 let evaluator = expr.load(&dataset).unwrap();
5179 let resources = evaluator.resources.read();
5180 let parameters = resources
5181 .parameter_map
5182 .assemble(&[1.0, 0.25, -0.8, 0.5])
5183 .expect("parameters should assemble");
5184 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5185 evaluator.fill_amplitude_values(
5186 &mut amplitude_values,
5187 resources.active_indices(),
5188 ¶meters,
5189 &resources.caches[0],
5190 );
5191 let mut active_mask = vec![false; evaluator.amplitudes.len()];
5192 for &index in resources.active_indices() {
5193 active_mask[index] = true;
5194 }
5195 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5196 .map(|_| DVector::zeros(parameters.len()))
5197 .collect::<Vec<_>>();
5198 evaluator.fill_amplitude_gradients(
5199 &mut amplitude_gradients,
5200 &active_mask,
5201 ¶meters,
5202 &resources.caches[0],
5203 );
5204 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
5205 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
5206 (0..evaluator.expression_ir().node_count())
5207 .map(|_| DVector::zeros(parameters.len()))
5208 .collect();
5209 let lowered_runtime = evaluator.lowered_runtime();
5210 let lowered_program = lowered_runtime.gradient_program();
5211 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
5212 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
5213 .scratch_slots())
5214 .map(|_| DVector::zeros(parameters.len()))
5215 .collect();
5216 let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
5217 &litude_values,
5218 &litude_gradients,
5219 &mut ir_value_slots,
5220 &mut ir_gradient_slots,
5221 );
5222 let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
5223 &litude_values,
5224 &litude_gradients,
5225 &mut ir_value_slots,
5226 &mut ir_gradient_slots,
5227 );
5228 let lowered_gradient = lowered_program.evaluate_gradient_into(
5229 &litude_values,
5230 &litude_gradients,
5231 &mut lowered_value_slots,
5232 &mut lowered_gradient_slots,
5233 );
5234 for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
5235 assert_relative_eq!(active.re, lowered.re);
5236 assert_relative_eq!(active.im, lowered.im);
5237 }
5238 for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
5239 assert_relative_eq!(lowered.re, ir.re);
5240 assert_relative_eq!(lowered.im, ir.im);
5241 }
5242 }
5243 #[test]
5244 fn test_expression_ir_value_gradient_matches_lowered_runtime() {
5245 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5246 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5247 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
5248 .norm_sqr();
5249 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5250 let evaluator = expr.load(&dataset).unwrap();
5251 let resources = evaluator.resources.read();
5252 let parameters = resources
5253 .parameter_map
5254 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
5255 .expect("parameters should assemble");
5256 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5257 evaluator.fill_amplitude_values(
5258 &mut amplitude_values,
5259 resources.active_indices(),
5260 ¶meters,
5261 &resources.caches[0],
5262 );
5263 let mut active_mask = vec![false; evaluator.amplitudes.len()];
5264 for &index in resources.active_indices() {
5265 active_mask[index] = true;
5266 }
5267 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5268 .map(|_| DVector::zeros(parameters.len()))
5269 .collect::<Vec<_>>();
5270 evaluator.fill_amplitude_gradients(
5271 &mut amplitude_gradients,
5272 &active_mask,
5273 ¶meters,
5274 &resources.caches[0],
5275 );
5276 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
5277 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
5278 (0..evaluator.expression_ir().node_count())
5279 .map(|_| DVector::zeros(parameters.len()))
5280 .collect();
5281 let lowered_runtime = evaluator.lowered_runtime();
5282 let lowered_program = lowered_runtime.value_gradient_program();
5283 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
5284 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
5285 .scratch_slots())
5286 .map(|_| DVector::zeros(parameters.len()))
5287 .collect();
5288
5289 let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
5290 &litude_values,
5291 &litude_gradients,
5292 &mut ir_value_slots,
5293 &mut ir_gradient_slots,
5294 );
5295 let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
5296 &litude_values,
5297 &litude_gradients,
5298 &mut ir_value_slots,
5299 &mut ir_gradient_slots,
5300 );
5301 let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
5302 &litude_values,
5303 &litude_gradients,
5304 &mut lowered_value_slots,
5305 &mut lowered_gradient_slots,
5306 );
5307
5308 assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
5309 assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
5310 for (active, lowered) in active_value_gradient
5311 .1
5312 .iter()
5313 .zip(lowered_value_gradient.1.iter())
5314 {
5315 assert_relative_eq!(active.re, lowered.re);
5316 assert_relative_eq!(active.im, lowered.im);
5317 }
5318 assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
5319 assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
5320 for (lowered, ir) in lowered_value_gradient
5321 .1
5322 .iter()
5323 .zip(ir_value_gradient.1.iter())
5324 {
5325 assert_relative_eq!(lowered.re, ir.re);
5326 assert_relative_eq!(lowered.im, ir.im);
5327 }
5328 }
5329 #[test]
5330 fn test_expression_runtime_diagnostics_reports_lowered_programs() {
5331 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5332 let evaluator = fixture
5333 .expression
5334 .load(&fixture.dataset)
5335 .expect("fixture evaluator should load");
5336
5337 let diagnostics = evaluator.expression_runtime_diagnostics();
5338 assert!(diagnostics.ir_planning_enabled);
5339 assert!(diagnostics.lowered_value_program_present);
5340 assert!(diagnostics.lowered_gradient_program_present);
5341 assert!(diagnostics.lowered_value_gradient_program_present);
5342 assert!(diagnostics.residual_runtime_present);
5343 assert_eq!(
5344 diagnostics.specialization_status,
5345 Some(ExpressionSpecializationStatus {
5346 origin: ExpressionSpecializationOrigin::InitialLoad,
5347 })
5348 );
5349 }
5350 #[test]
5351 fn test_expression_runtime_diagnostics_reports_specialization_origin() {
5352 let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5353 let evaluator = fixture
5354 .expression
5355 .load(&fixture.dataset)
5356 .expect("fixture evaluator should load");
5357
5358 assert_eq!(
5359 evaluator
5360 .expression_runtime_diagnostics()
5361 .specialization_status,
5362 Some(ExpressionSpecializationStatus {
5363 origin: ExpressionSpecializationOrigin::InitialLoad,
5364 })
5365 );
5366
5367 evaluator.isolate_many(&["p"]);
5368 assert_eq!(
5369 evaluator
5370 .expression_runtime_diagnostics()
5371 .specialization_status,
5372 Some(ExpressionSpecializationStatus {
5373 origin: ExpressionSpecializationOrigin::CacheMissRebuild,
5374 })
5375 );
5376 }
5377 #[test]
5378 fn test_compiled_expression_display_reports_dag_refs() {
5379 let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5380 let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
5381 let term = &a * &b;
5382 let expr = &term + &term;
5383 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5384 let evaluator = expr.load(&dataset).unwrap();
5385
5386 let compiled = evaluator.compiled_expression();
5387 let display = compiled.to_string();
5388
5389 assert_eq!(compiled.root(), compiled.nodes().len() - 1);
5390 assert!(display.contains("#"));
5391 assert!(display.contains("+"));
5392 assert!(display.contains("×"));
5393 assert!(display.contains("a(id=0)"));
5394 assert!(display.contains("b(id=1)"));
5395 assert!(display.contains("(ref)"));
5396 }
5397
5398 #[test]
5399 fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
5400 let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5401 let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
5402 let term = &a * &b;
5403 let expr = &term + &term;
5404
5405 let compiled = expr.compiled_expression();
5406 let display = compiled.to_string();
5407
5408 assert_eq!(compiled.root(), compiled.nodes().len() - 1);
5409 assert!(display.contains("#"));
5410 assert!(display.contains("+"));
5411 assert!(display.contains("×"));
5412 assert!(display.contains("a(id=0)"));
5413 assert!(display.contains("b(id=1)"));
5414 assert!(display.contains("(ref)"));
5415 }
5416
5417 #[test]
5418 fn test_compiled_expression_display_uses_current_active_mask() {
5419 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5420 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
5421 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5422 let evaluator = expr.load(&dataset).unwrap();
5423 evaluator.deactivate("b");
5424
5425 let compiled = evaluator.compiled_expression().to_string();
5426
5427 assert!(compiled.contains("a(id=0)"));
5428 assert!(!compiled.contains("b(id=1)"));
5429 assert!(!compiled.contains("const 0"));
5430 assert!(!compiled.contains("+"));
5431 }
5432
5433 fn assert_compiled_single_amplitude(expr: &Expression, expected_label: &str) {
5434 let compiled = expr.compiled_expression();
5435 assert_eq!(compiled.nodes().len(), 1);
5436 assert_eq!(compiled.root(), 0);
5437 match &compiled.nodes()[0] {
5438 CompiledExpressionNode::Amplitude { index, name } => {
5439 assert_eq!(*index, 0);
5440 assert_eq!(name, expected_label);
5441 }
5442 node => panic!("expected one amplitude node, got {node:?}"),
5443 }
5444 }
5445
5446 fn assert_compiled_constant(expr: &Expression, expected: Complex64) {
5447 let compiled = expr.compiled_expression();
5448 assert_eq!(compiled.nodes().len(), 1);
5449 assert_eq!(compiled.root(), 0);
5450 match compiled.nodes()[0] {
5451 CompiledExpressionNode::Constant(value) => {
5452 assert_relative_eq!(value.re, expected.re);
5453 assert_relative_eq!(value.im, expected.im);
5454 }
5455 ref node => panic!("expected one constant node, got {node:?}"),
5456 }
5457 }
5458
5459 #[test]
5460 fn test_compiled_expression_simplifies_arithmetic_identities() {
5461 let amp = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5462 let zero = Expression::zero();
5463 let one = Expression::one();
5464
5465 assert_compiled_single_amplitude(&(& + &zero), "a");
5466 assert_compiled_single_amplitude(&(&zero + &), "a");
5467 assert_compiled_single_amplitude(&(& - &zero), "a");
5468 assert_compiled_single_amplitude(&(& * &one), "a");
5469 assert_compiled_single_amplitude(&(&one * &), "a");
5470 assert_compiled_single_amplitude(&(& / &one), "a");
5471 assert_compiled_single_amplitude(&.pow(&one), "a");
5472 assert_compiled_single_amplitude(&.powi(1), "a");
5473 assert_compiled_single_amplitude(&.powf(1.0), "a");
5474
5475 let times_zero = & * &zero;
5476 assert_compiled_constant(×_zero, Complex64::ZERO);
5477 assert!(times_zero.parameters().contains_key("ar"));
5478 assert!(times_zero.parameters().contains_key("ai"));
5479
5480 assert_compiled_constant(&(&zero * &), Complex64::ZERO);
5481 assert_compiled_constant(&(&zero / &Expression::from(2.0)), Complex64::ZERO);
5482 assert_compiled_constant(&.powi(0), Complex64::ONE);
5483 assert_compiled_constant(
5484 &Expression::from(2.0).pow(&Expression::zero()),
5485 Complex64::ONE,
5486 );
5487 assert_compiled_constant(&Expression::from(2.0).powf(0.0), Complex64::ONE);
5488
5489 let unsafe_zero_division = (&zero / &).compiled_expression().to_string();
5490 assert!(unsafe_zero_division.contains("÷"));
5491 assert!(unsafe_zero_division.contains("a(id=0)"));
5492 }
5493
5494 #[test]
5495 fn test_compiled_expression_folds_unary_constant_functions() {
5496 assert_compiled_constant(&Expression::from(0.0).exp(), Complex64::ONE);
5497 assert_compiled_constant(&Expression::from(0.0).sin(), Complex64::ZERO);
5498 assert_compiled_constant(&Expression::from(0.0).cos(), Complex64::ONE);
5499 assert_compiled_constant(&Expression::from(1.0).log(), Complex64::ZERO);
5500 assert_compiled_constant(&Expression::from(4.0).sqrt(), Complex64::new(2.0, 0.0));
5501 assert_compiled_constant(&Expression::from(0.0).cis(), Complex64::ONE);
5502 }
5503
5504 #[test]
5505 fn test_evaluator_expression_reconstructs_expression() {
5506 let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
5507 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5508 let evaluator = expr.load(&dataset).unwrap();
5509
5510 assert_eq!(
5511 evaluator.expression().compiled_expression(),
5512 expr.compiled_expression()
5513 );
5514 }
5515
5516 #[test]
5517 fn test_active_mask_override_ignores_current_ir_specialization() {
5518 let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
5519 .unwrap()
5520 .norm_sqr();
5521 let dataset = Arc::new(test_dataset());
5522 let evaluator = expr.load(&dataset).unwrap();
5523 let params = vec![2.0];
5524
5525 evaluator.deactivate("amp");
5526 assert_eq!(
5527 evaluator
5528 .evaluate(¶ms)
5529 .expect("evaluation should succeed")[0],
5530 Complex64::new(0.0, 0.0)
5531 );
5532
5533 let overridden = evaluator
5534 .evaluate_local_with_active_mask(¶ms, &[true])
5535 .unwrap();
5536 assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
5537
5538 let overridden_fused = evaluator
5539 .evaluate_with_gradient_local_with_active_mask(¶ms, &[true])
5540 .unwrap();
5541 assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
5542 assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
5543 }
5544 #[test]
5545 fn test_expression_ir_dependence_diagnostics_surface() {
5546 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5547 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5548 .norm_sqr();
5549 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5550 let evaluator = expr.load(&dataset).unwrap();
5551 let annotations = evaluator
5552 .expression_node_dependence_annotations()
5553 .expect("annotations should exist");
5554 assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
5555 assert!(annotations
5556 .iter()
5557 .all(|dependence| *dependence == ExpressionDependence::Mixed));
5558 assert_eq!(
5559 evaluator
5560 .expression_root_dependence()
5561 .expect("root dependence should exist"),
5562 ExpressionDependence::Mixed
5563 );
5564 }
5565 #[test]
5566 fn test_expression_ir_default_dependence_hint_is_mixed() {
5567 let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
5568 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5569 let evaluator = expr.load(&dataset).unwrap();
5570 assert_eq!(
5571 evaluator
5572 .expression_root_dependence()
5573 .expect("root dependence should exist"),
5574 ExpressionDependence::Mixed
5575 );
5576 }
5577 #[test]
5578 fn test_expression_ir_parameter_only_dependence_hint_propagates() {
5579 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
5580 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5581 let evaluator = expr.load(&dataset).unwrap();
5582 assert_eq!(
5583 evaluator
5584 .expression_root_dependence()
5585 .expect("root dependence should exist"),
5586 ExpressionDependence::ParameterOnly
5587 );
5588 }
5589 #[test]
5590 fn test_expression_ir_cache_only_dependence_hint_propagates() {
5591 let expr = CacheOnlyScalar::new("k").unwrap();
5592 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5593 let evaluator = expr.load(&dataset).unwrap();
5594 assert_eq!(
5595 evaluator
5596 .expression_root_dependence()
5597 .expect("root dependence should exist"),
5598 ExpressionDependence::CacheOnly
5599 );
5600 }
5601 #[test]
5602 fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
5603 let expr = ParameterOnlyScalar::new("p", parameter!("p"))
5604 .unwrap()
5605 .imag();
5606 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5607 let evaluator = expr.load(&dataset).unwrap();
5608 let ir = evaluator.expression_ir();
5609
5610 assert!(matches!(
5611 ir.nodes()[ir.root()],
5612 ir::IrNode::Constant(value) if value == Complex64::ZERO
5613 ));
5614 assert_eq!(
5615 evaluator
5616 .evaluate(&[2.5])
5617 .expect("evaluation should succeed")[0],
5618 Complex64::ZERO
5619 );
5620 }
5621 #[test]
5622 fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
5623 let expr = ParameterOnlyScalar::new("p", parameter!("p"))
5624 .unwrap()
5625 .conj();
5626 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5627 let evaluator = expr.load(&dataset).unwrap();
5628 let ir = evaluator.expression_ir();
5629
5630 assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
5631 assert_eq!(
5632 evaluator
5633 .evaluate(&[2.5])
5634 .expect("evaluation should succeed")[0],
5635 Complex64::new(2.5, 0.0)
5636 );
5637 }
5638 #[test]
5639 fn test_expression_ir_dependence_warnings_surface() {
5640 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5641 + &CacheOnlyScalar::new("k").unwrap();
5642 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5643 let evaluator = expr.load(&dataset).unwrap();
5644 assert!(evaluator
5645 .expression_dependence_warnings()
5646 .expect("warnings should exist")
5647 .iter()
5648 .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
5649 }
5650 #[test]
5651 fn test_expression_ir_normalization_plan_explain_surface() {
5652 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5653 * &CacheOnlyScalar::new("k").unwrap();
5654 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5655 let evaluator = expr.load(&dataset).unwrap();
5656 let explain = evaluator
5657 .expression_normalization_plan_explain()
5658 .expect("plan should exist");
5659 assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
5660 assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
5661 assert_eq!(
5662 explain.cached_separable_nodes,
5663 explain.separable_mul_candidate_nodes
5664 );
5665 assert!(explain.residual_terms.iter().all(|index| {
5666 !explain
5667 .separable_mul_candidate_nodes
5668 .iter()
5669 .any(|candidate| candidate == index)
5670 }));
5671 }
5672 #[test]
5673 fn test_expression_ir_normalization_execution_sets_surface() {
5674 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5675 * &CacheOnlyScalar::new("k").unwrap();
5676 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5677 let evaluator = expr.load(&dataset).unwrap();
5678 let sets = evaluator
5679 .expression_normalization_execution_sets()
5680 .expect("sets should exist");
5681 assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5682 assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5683 assert!(sets.residual_amplitudes.is_empty());
5684 }
5685 #[test]
5686 fn test_expression_ir_normalization_execution_sets_partial_surface() {
5687 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5688 * &CacheOnlyScalar::new("k").unwrap())
5689 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5690 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5691 let evaluator = expr.load(&dataset).unwrap();
5692 let sets = evaluator
5693 .expression_normalization_execution_sets()
5694 .expect("sets should exist");
5695 assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
5696 assert_eq!(sets.cached_cache_amplitudes, vec![1]);
5697 assert_eq!(sets.residual_amplitudes, vec![2]);
5698 }
5699 #[test]
5700 fn test_expression_ir_precomputed_cached_integrals_at_load() {
5701 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5702 * &CacheOnlyScalar::new("k").unwrap();
5703 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5704 let evaluator = expr.load(&dataset).unwrap();
5705 let precomputed = evaluator
5706 .expression_precomputed_cached_integrals()
5707 .expect("integrals should exist");
5708 assert_eq!(precomputed.len(), 1);
5709 let cache_reference = CacheOnlyScalar::new("k_ref")
5710 .unwrap()
5711 .load(&dataset)
5712 .unwrap();
5713 let cache_values = cache_reference
5714 .evaluate_local(&[])
5715 .expect("evaluation should succeed");
5716 let expected_weighted_sum = cache_values
5717 .iter()
5718 .zip(dataset.weights_local().iter())
5719 .fold(Complex64::ZERO, |acc, (value, event)| {
5720 acc + (*value * *event)
5721 });
5722 assert_relative_eq!(
5723 precomputed[0].weighted_cache_sum.re,
5724 expected_weighted_sum.re
5725 );
5726 assert_relative_eq!(
5727 precomputed[0].weighted_cache_sum.im,
5728 expected_weighted_sum.im
5729 );
5730 }
5731 #[test]
5732 fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
5733 let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5734 * &CacheOnlyScalar::new("k").unwrap();
5735 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5736 let evaluator = expr.load(&dataset).unwrap();
5737 assert!(evaluator
5738 .expression_precomputed_cached_integrals()
5739 .expect("integrals should exist")
5740 .is_empty());
5741 }
5742 #[test]
5743 fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
5744 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5745 * &CacheOnlyScalar::new("k").unwrap();
5746 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5747 let evaluator = expr.load(&dataset).unwrap();
5748 assert_eq!(
5749 evaluator
5750 .expression_precomputed_cached_integrals()
5751 .expect("integrals should exist")
5752 .len(),
5753 1
5754 );
5755
5756 evaluator.isolate_many(&["p"]);
5757 assert!(evaluator
5758 .expression_precomputed_cached_integrals()
5759 .expect("integrals should exist")
5760 .is_empty());
5761 }
5762 #[test]
5763 fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
5764 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5765 * &CacheOnlyScalar::new("k").unwrap();
5766 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5767 let mut evaluator = expr.load(&dataset).unwrap();
5768 drop(dataset);
5769 let before = evaluator
5770 .expression_precomputed_cached_integrals()
5771 .expect("integrals should exist");
5772 assert_eq!(before.len(), 1);
5773
5774 Arc::get_mut(&mut evaluator.dataset)
5775 .expect("evaluator should own dataset Arc in this test")
5776 .clear_events_local();
5777 let after = evaluator
5778 .expression_precomputed_cached_integrals()
5779 .expect("integrals should exist");
5780 assert_eq!(after.len(), 1);
5781 assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
5782 assert!(before[0].weighted_cache_sum != after[0].weighted_cache_sum);
5783 }
5784 #[test]
5785 fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
5786 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5787 * &CacheOnlyScalar::new("k").unwrap();
5788 let dataset = Arc::new(Dataset::new(vec![
5789 Arc::new(test_event()),
5790 Arc::new(test_event()),
5791 ]));
5792 let evaluator = expr.load(&dataset).unwrap();
5793 let cached_integrals = evaluator
5794 .expression_precomputed_cached_integrals()
5795 .expect("integrals should exist");
5796 assert_eq!(cached_integrals.len(), 1);
5797 let gradient_terms = evaluator
5798 .expression_precomputed_cached_integral_gradient_terms(&[1.25])
5799 .expect("evaluation should succeed");
5800 assert_eq!(gradient_terms.len(), 1);
5801 assert_eq!(gradient_terms[0].weighted_gradient.len(), 1);
5802 assert_relative_eq!(
5803 gradient_terms[0].weighted_gradient[0].re,
5804 cached_integrals[0].weighted_cache_sum.re,
5805 epsilon = 1e-6
5806 );
5807 assert_relative_eq!(
5808 gradient_terms[0].weighted_gradient[0].im,
5809 cached_integrals[0].weighted_cache_sum.im,
5810 epsilon = 1e-6
5811 );
5812 }
5813 #[test]
5814 fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
5815 let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
5816 * &CacheOnlyScalar::new("k").unwrap();
5817 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5818 let evaluator = expr.load(&dataset).unwrap();
5819 assert!(evaluator
5820 .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
5821 .expect("evaluation should succeed")
5822 .is_empty());
5823 }
5824 #[test]
5825 fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
5826 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5827 * &CacheOnlyScalar::new("k").unwrap())
5828 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5829 let dataset = Arc::new(test_dataset());
5830 let evaluator = expr.load(&dataset).unwrap();
5831 let resources = evaluator.resources.read();
5832 let state = evaluator
5833 .ensure_cached_integral_cache_state(&resources)
5834 .expect("state should be available");
5835 let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5836 let parameters = resources
5837 .parameter_map
5838 .assemble(&[0.55, 0.2, -0.15])
5839 .expect("parameters should assemble");
5840
5841 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5842 evaluator.fill_amplitude_values(
5843 &mut amplitude_values,
5844 &state.execution_sets.cached_parameter_amplitudes,
5845 ¶meters,
5846 &resources.caches[0],
5847 );
5848 let cached_value_ir =
5849 evaluator.evaluate_cached_weighted_value_sum_ir(&state, &litude_values);
5850 let cached_value_lowered = evaluator
5851 .evaluate_cached_weighted_value_sum_lowered(
5852 &state,
5853 lowered_artifacts.as_ref(),
5854 &litude_values,
5855 )
5856 .expect("cached value lowering should succeed");
5857 assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);
5858
5859 let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
5860 for &index in &state.execution_sets.cached_parameter_amplitudes {
5861 cached_parameter_mask[index] = true;
5862 }
5863 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5864 .map(|_| DVector::zeros(parameters.len()))
5865 .collect::<Vec<_>>();
5866 evaluator.fill_amplitude_gradients(
5867 &mut amplitude_gradients,
5868 &cached_parameter_mask,
5869 ¶meters,
5870 &resources.caches[0],
5871 );
5872 let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
5873 &state,
5874 &litude_values,
5875 &litude_gradients,
5876 parameters.len(),
5877 );
5878 let cached_gradient_lowered = evaluator
5879 .evaluate_cached_weighted_gradient_sum_lowered(
5880 &state,
5881 lowered_artifacts.as_ref(),
5882 &litude_values,
5883 &litude_gradients,
5884 parameters.len(),
5885 )
5886 .expect("cached gradient lowering should succeed");
5887 for (lowered, ir) in cached_gradient_lowered
5888 .iter()
5889 .zip(cached_gradient_ir.iter())
5890 {
5891 assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
5892 }
5893 }
5894 #[test]
5895 fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
5896 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5897 * &CacheOnlyScalar::new("k").unwrap())
5898 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5899 let dataset = Arc::new(test_dataset());
5900 let evaluator = expr.load(&dataset).unwrap();
5901 let resources = evaluator.resources.read();
5902 let state = evaluator
5903 .ensure_cached_integral_cache_state(&resources)
5904 .expect("state should be available");
5905 let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
5906 let parameters = resources
5907 .parameter_map
5908 .assemble(&[0.55, 0.2, -0.15])
5909 .expect("parameters should assemble");
5910
5911 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5912 evaluator.fill_amplitude_values(
5913 &mut amplitude_values,
5914 &state.execution_sets.residual_amplitudes,
5915 ¶meters,
5916 &resources.caches[0],
5917 );
5918 let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &litude_values);
5919 let residual_program = lowered_artifacts
5920 .residual_runtime
5921 .as_ref()
5922 .map(|runtime| runtime.value_program())
5923 .expect("residual value lowering should succeed");
5924 let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
5925 let residual_value_lowered =
5926 residual_program.evaluate_into(&litude_values, &mut value_slots);
5927 assert_relative_eq!(
5928 residual_value_lowered.re,
5929 residual_value_ir.re,
5930 epsilon = 1e-12
5931 );
5932 assert_relative_eq!(
5933 residual_value_lowered.im,
5934 residual_value_ir.im,
5935 epsilon = 1e-12
5936 );
5937
5938 let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
5939 for &index in &state.execution_sets.residual_amplitudes {
5940 residual_active_mask[index] = true;
5941 }
5942 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
5943 .map(|_| DVector::zeros(parameters.len()))
5944 .collect::<Vec<_>>();
5945 evaluator.fill_amplitude_gradients(
5946 &mut amplitude_gradients,
5947 &residual_active_mask,
5948 ¶meters,
5949 &resources.caches[0],
5950 );
5951 let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
5952 &state,
5953 &litude_values,
5954 &litude_gradients,
5955 parameters.len(),
5956 );
5957
5958 let program = lowered_artifacts
5959 .residual_runtime
5960 .as_ref()
5961 .map(|runtime| runtime.gradient_program())
5962 .expect("gradient lowering should succeed");
5963 let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
5964 let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
5965 let residual_gradient_lowered = program.evaluate_gradient_into_flat(
5966 &litude_values,
5967 &litude_gradients,
5968 &mut value_slots,
5969 &mut gradient_slots,
5970 parameters.len(),
5971 );
5972
5973 for (lowered, ir) in residual_gradient_lowered
5974 .iter()
5975 .zip(residual_gradient_ir.iter())
5976 {
5977 assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
5978 assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
5979 }
5980 }
5981 #[test]
5982 fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
5983 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
5984 * &CacheOnlyScalar::new("k").unwrap())
5985 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
5986 let dataset = Arc::new(test_dataset());
5987 let mut evaluator = expr.load(&dataset).unwrap();
5988 drop(dataset);
5989
5990 assert_eq!(evaluator.specialization_cache_len(), 1);
5991 assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
5992
5993 evaluator.reset_expression_compile_metrics();
5994 evaluator.reset_expression_specialization_metrics();
5995
5996 Arc::get_mut(&mut evaluator.dataset)
5997 .expect("evaluator should own dataset Arc in this test")
5998 .clear_events_local();
5999
6000 let cached_integrals = evaluator
6001 .expression_precomputed_cached_integrals()
6002 .expect("integrals should exist");
6003 assert_eq!(cached_integrals.len(), 1);
6004 assert_eq!(cached_integrals[0].weighted_cache_sum, Complex64::ZERO);
6005
6006 assert_eq!(evaluator.specialization_cache_len(), 2);
6007 assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
6008 assert_eq!(
6009 evaluator.expression_specialization_metrics(),
6010 ExpressionSpecializationMetrics {
6011 cache_hits: 0,
6012 cache_misses: 1,
6013 }
6014 );
6015
6016 let compile_metrics = evaluator.expression_compile_metrics();
6017 assert_eq!(compile_metrics.specialization_cache_hits, 0);
6018 assert_eq!(compile_metrics.specialization_cache_misses, 1);
6019 assert_eq!(compile_metrics.specialization_lowering_cache_hits, 1);
6020 assert_eq!(compile_metrics.specialization_lowering_cache_misses, 0);
6021 assert!(compile_metrics.specialization_ir_compile_nanos > 0);
6022 assert!(compile_metrics.specialization_cached_integrals_nanos > 0);
6023 assert_eq!(compile_metrics.specialization_lowering_nanos, 0);
6024 }
6025
6026 #[test]
6027 fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
6028 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6029 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6030 let c1 = CacheOnlyScalar::new("c1").unwrap();
6031 let c2 = CacheOnlyScalar::new("c2").unwrap();
6032 let c3 = CacheOnlyScalar::new("c3").unwrap();
6033 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6034 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6035 let dataset = Arc::new(test_dataset());
6036 let evaluator = expr.load(&dataset).unwrap();
6037 assert_eq!(
6038 evaluator
6039 .expression_precomputed_cached_integrals()
6040 .expect("integrals should exist")
6041 .len(),
6042 2
6043 );
6044 let params = vec![0.2, -0.3, 1.1, -0.7];
6045 let expected = evaluator
6046 .evaluate_gradient_local(¶ms)
6047 .expect("evaluation should succeed")
6048 .iter()
6049 .zip(dataset.weights_local().iter())
6050 .fold(
6051 DVector::zeros(params.len()),
6052 |mut accum, (gradient, event)| {
6053 accum += gradient.map(|value| value.re).scale(*event);
6054 accum
6055 },
6056 );
6057 let actual = evaluator
6058 .evaluate_weighted_gradient_sum_local(¶ms)
6059 .expect("evaluation should succeed");
6060 for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6061 assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6062 }
6063 }
6064
6065 #[test]
6066 fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
6067 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6068 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6069 let c1 = CacheOnlyScalar::new("c1").unwrap();
6070 let c2 = CacheOnlyScalar::new("c2").unwrap();
6071 let c3 = CacheOnlyScalar::new("c3").unwrap();
6072 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6073 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6074 let dataset = Arc::new(test_dataset());
6075 let evaluator = expr.load(&dataset).unwrap();
6076 assert_eq!(
6077 evaluator
6078 .expression_precomputed_cached_integrals()
6079 .expect("integrals should exist")
6080 .len(),
6081 2
6082 );
6083 let params = vec![0.2, -0.3, 1.1, -0.7];
6084 let expected = evaluator
6085 .evaluate_local(¶ms)
6086 .expect("evaluation should succeed")
6087 .iter()
6088 .zip(dataset.weights_local().iter())
6089 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6090 let actual = evaluator
6091 .evaluate_weighted_value_sum_local(¶ms)
6092 .expect("evaluation should succeed");
6093 assert_relative_eq!(actual, expected, epsilon = 1e-10);
6094 }
6095
6096 #[test]
6097 fn test_weighted_sums_match_hardcoded_reference_values() {
6098 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6099 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6100 let c1 = CacheOnlyScalar::new("c1").unwrap();
6101 let c2 = CacheOnlyScalar::new("c2").unwrap();
6102 let c3 = CacheOnlyScalar::new("c3").unwrap();
6103 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6104 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6105
6106 let metadata = Arc::new(DatasetMetadata::default());
6107 let dataset = Arc::new(Dataset::new_with_metadata(
6108 vec![
6109 Arc::new(EventData {
6110 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6111 aux: vec![],
6112 weight: 0.5,
6113 }),
6114 Arc::new(EventData {
6115 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6116 aux: vec![],
6117 weight: -1.25,
6118 }),
6119 Arc::new(EventData {
6120 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6121 aux: vec![],
6122 weight: 2.0,
6123 }),
6124 ],
6125 metadata,
6126 ));
6127 let evaluator = expr.load(&dataset).unwrap();
6128 let params = vec![0.7, -1.1, 0.9, -0.4];
6129
6130 let weighted_value_sum = evaluator
6131 .evaluate_weighted_value_sum_local(¶ms)
6132 .expect("evaluation should succeed");
6133 assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);
6134
6135 let weighted_gradient_sum = evaluator
6136 .evaluate_weighted_gradient_sum_local(¶ms)
6137 .expect("evaluation should succeed");
6138 let free_parameters = evaluator
6139 .parameters()
6140 .free()
6141 .names()
6142 .into_iter()
6143 .map(|name| name.to_string())
6144 .collect::<Vec<_>>();
6145 assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
6146 let expected_gradient = [43.925, 7.25, 28.525, 0.0];
6147 assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
6148 for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
6149 assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
6150 }
6151 }
6152 #[test]
6153 fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
6154 let expr = Expression::one()
6155 - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6156 * &CacheOnlyScalar::new("k").unwrap());
6157 let dataset = Arc::new(test_dataset());
6158 let evaluator = expr.load(&dataset).unwrap();
6159 assert_eq!(
6160 evaluator
6161 .expression_precomputed_cached_integrals()
6162 .expect("integrals should exist")
6163 .len(),
6164 1
6165 );
6166 assert_eq!(
6167 evaluator
6168 .expression_precomputed_cached_integrals()
6169 .expect("integrals should exist")[0]
6170 .coefficient,
6171 -1
6172 );
6173 let params = vec![0.75];
6174 let expected = evaluator
6175 .evaluate_gradient_local(¶ms)
6176 .expect("evaluation should succeed")
6177 .iter()
6178 .zip(dataset.weights_local().iter())
6179 .fold(
6180 DVector::zeros(params.len()),
6181 |mut accum, (gradient, event)| {
6182 accum += gradient.map(|value| value.re).scale(*event);
6183 accum
6184 },
6185 );
6186 let actual = evaluator
6187 .evaluate_weighted_gradient_sum_local(¶ms)
6188 .expect("evaluation should succeed");
6189 for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6190 assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6191 }
6192 }
6193 #[test]
6194 fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
6195 let expr = Expression::one()
6196 - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6197 * &CacheOnlyScalar::new("k").unwrap());
6198 let dataset = Arc::new(test_dataset());
6199 let evaluator = expr.load(&dataset).unwrap();
6200 assert_eq!(
6201 evaluator
6202 .expression_precomputed_cached_integrals()
6203 .expect("integrals should exist")
6204 .len(),
6205 1
6206 );
6207 assert_eq!(
6208 evaluator
6209 .expression_precomputed_cached_integrals()
6210 .expect("integrals should exist")[0]
6211 .coefficient,
6212 -1
6213 );
6214 let params = vec![0.75];
6215 let expected = evaluator
6216 .evaluate_local(¶ms)
6217 .expect("evaluation should succeed")
6218 .iter()
6219 .zip(dataset.weights_local().iter())
6220 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6221 let actual = evaluator
6222 .evaluate_weighted_value_sum_local(¶ms)
6223 .expect("evaluation should succeed");
6224 assert_relative_eq!(actual, expected, epsilon = 1e-10);
6225 }
6226 #[test]
6227 fn test_expression_ir_diagnostics_follow_activation_changes() {
6228 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6229 * &CacheOnlyScalar::new("k").unwrap())
6230 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6231 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6232 let evaluator = expr.load(&dataset).unwrap();
6233
6234 let all_active = evaluator
6235 .expression_normalization_plan_explain()
6236 .expect("plan should exist");
6237 assert_eq!(all_active.cached_separable_nodes.len(), 1);
6238 assert_eq!(
6239 evaluator
6240 .expression_root_dependence()
6241 .expect("root dependence should exist"),
6242 ExpressionDependence::Mixed
6243 );
6244
6245 evaluator.isolate_many(&["p"]);
6246 let param_only = evaluator
6247 .expression_normalization_plan_explain()
6248 .expect("plan should exist");
6249 assert!(param_only.cached_separable_nodes.is_empty());
6250 assert_eq!(
6251 evaluator
6252 .expression_root_dependence()
6253 .expect("root dependence should exist"),
6254 ExpressionDependence::ParameterOnly
6255 );
6256 }
    /// Specializations compiled for an activation mask are cached and reused when that mask
    /// comes back, instead of being recompiled.
    ///
    /// The test walks through three masks:
    /// - all active (compiled at load);
    /// - only `p` active (specialization-cache miss);
    /// - all active again via `activate_many(&["k", "m"])` (specialization-cache hit).
    #[test]
    fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();

        // Loading records nonzero `initial_*` timings. The specialization counters in the compile
        // metrics start at zero, while one lowering-cache miss is already recorded.
        let initial_compile_metrics = evaluator.expression_compile_metrics();
        assert!(initial_compile_metrics.initial_ir_compile_nanos > 0);
        assert!(initial_compile_metrics.initial_cached_integrals_nanos > 0);
        assert!(initial_compile_metrics.initial_lowering_nanos > 0);
        assert_eq!(initial_compile_metrics.specialization_cache_hits, 0);
        assert_eq!(initial_compile_metrics.specialization_cache_misses, 0);
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_misses,
            1
        );

        // After load, one specialization (the all-active mask) and one lowered artifact exist.
        // `expression_specialization_metrics` already counts that initial build as a miss.
        assert_eq!(evaluator.specialization_cache_len(), 1);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 1,
            }
        );
        // Snapshot the all-active integrals; they are compared again once the mask is restored.
        let all_active_cached_integrals = evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist");

        // Isolating `p` is a mask not seen before: a miss that adds a second cache entry.
        evaluator.isolate_many(&["p"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 2,
            }
        );
        // NOTE(review): the compile-metrics miss counter reads 1 here, while
        // `expression_specialization_metrics` reads 2. It appears the former excludes the
        // specialization built during load; confirm.
        let after_cache_miss_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_miss_metrics.specialization_cache_hits, 0);
        assert_eq!(after_cache_miss_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_misses,
            2
        );
        // The miss paid for a fresh IR compile, integral precompute and lowering.
        assert!(after_cache_miss_metrics.specialization_ir_compile_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_cached_integrals_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_lowering_nanos > 0);
        // With only `p` active there are no cached integrals.
        assert!(evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .is_empty());

        // Re-activating `k` and `m` returns to the all-active mask, which is served from the
        // cache: a hit, with no new entry.
        evaluator.activate_many(&["k", "m"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 1,
                cache_misses: 2,
            }
        );
        // The restored specialization carries the same integrals as the original one.
        assert_eq!(
            evaluator
                .expression_precomputed_cached_integrals()
                .expect("integrals should exist"),
            all_active_cached_integrals
        );
        // The hit leaves the lowering-cache counters where the previous step left them and
        // records a nonzero restore time.
        let after_cache_hit_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_hit_metrics.specialization_cache_hits, 1);
        assert_eq!(after_cache_hit_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_misses,
            2
        );
        assert!(after_cache_hit_metrics.specialization_cache_restore_nanos > 0);
    }
6349
6350 #[test]
6351 fn test_weighted_sums_match_baseline_after_activation_changes() {
6352 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6353 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6354 let c1 = CacheOnlyScalar::new("c1").unwrap();
6355 let c2 = CacheOnlyScalar::new("c2").unwrap();
6356 let c3 = CacheOnlyScalar::new("c3").unwrap();
6357 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6358 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6359 let dataset = Arc::new(test_dataset());
6360 let evaluator = expr.load(&dataset).unwrap();
6361 let params = vec![0.2, -0.3, 1.1, -0.7];
6362
6363 evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);
6364
6365 let expected_value = evaluator
6366 .evaluate_local(¶ms)
6367 .expect("evaluation should succeed")
6368 .iter()
6369 .zip(dataset.weights_local().iter())
6370 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6371 assert_relative_eq!(
6372 evaluator
6373 .evaluate_weighted_value_sum_local(¶ms)
6374 .expect("evaluation should succeed"),
6375 expected_value,
6376 epsilon = 1e-10
6377 );
6378 }
6379
6380 #[test]
6381 fn test_evaluate_local_does_not_depend_on_dataset_rows() {
6382 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6383 .unwrap()
6384 .norm_sqr();
6385 let mut event1 = test_event();
6386 event1.p4s[0].t = 7.5;
6387 let mut event2 = test_event();
6388 event2.p4s[0].t = 8.25;
6389 let mut event3 = test_event();
6390 event3.p4s[0].t = 9.0;
6391 let dataset = Arc::new(Dataset::new_with_metadata(
6392 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
6393 Arc::new(DatasetMetadata::default()),
6394 ));
6395 let mut evaluator = expr.load(&dataset).unwrap();
6396 drop(dataset);
6397 let expected_len = evaluator.resources.read().caches.len();
6398 Arc::get_mut(&mut evaluator.dataset)
6399 .expect("evaluator should own dataset Arc in this test")
6400 .clear_events_local();
6401 let cached = evaluator
6402 .evaluate_local(&[1.25, -0.75])
6403 .expect("evaluation should succeed");
6404 assert_eq!(cached.len(), expected_len);
6405 }
6406
6407 #[test]
6408 fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
6409 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6410 .unwrap()
6411 .norm_sqr();
6412 let mut event1 = test_event();
6413 event1.p4s[0].t = 7.5;
6414 let mut event2 = test_event();
6415 event2.p4s[0].t = 8.25;
6416 let mut event3 = test_event();
6417 event3.p4s[0].t = 9.0;
6418 let dataset = Arc::new(Dataset::new_with_metadata(
6419 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
6420 Arc::new(DatasetMetadata::default()),
6421 ));
6422 let mut evaluator = expr.load(&dataset).unwrap();
6423 drop(dataset);
6424 let expected_len = evaluator.resources.read().caches.len();
6425 Arc::get_mut(&mut evaluator.dataset)
6426 .expect("evaluator should own dataset Arc in this test")
6427 .clear_events_local();
6428 let cached = evaluator
6429 .evaluate_gradient_local(&[1.25, -0.75])
6430 .expect("evaluation should succeed");
6431 assert_eq!(cached.len(), expected_len);
6432 }
6433
6434 #[test]
6435 fn test_evaluate_with_gradient_local_matches_separate_paths() {
6436 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6437 .unwrap()
6438 .norm_sqr();
6439 let dataset = Arc::new(Dataset::new(vec![
6440 Arc::new(test_event()),
6441 Arc::new(test_event()),
6442 Arc::new(test_event()),
6443 ]));
6444 let evaluator = expr.load(&dataset).unwrap();
6445 let params = [1.25, -0.75];
6446 let values = evaluator
6447 .evaluate_local(¶ms)
6448 .expect("evaluation should succeed");
6449 let gradients = evaluator
6450 .evaluate_gradient_local(¶ms)
6451 .expect("evaluation should succeed");
6452 let fused = evaluator
6453 .evaluate_with_gradient_local(¶ms)
6454 .expect("evaluation should succeed");
6455 assert_eq!(fused.len(), values.len());
6456 assert_eq!(fused.len(), gradients.len());
6457 for ((value_gradient, value), gradient) in
6458 fused.iter().zip(values.iter()).zip(gradients.iter())
6459 {
6460 let (fused_value, fused_gradient) = value_gradient;
6461 assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
6462 assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
6463 assert_eq!(fused_gradient.len(), gradient.len());
6464 for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
6465 assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
6466 assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
6467 }
6468 }
6469 }
6470
6471 #[test]
6472 fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
6473 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
6474 .unwrap()
6475 .norm_sqr();
6476 let dataset = Arc::new(Dataset::new(vec![
6477 Arc::new(test_event()),
6478 Arc::new(test_event()),
6479 Arc::new(test_event()),
6480 Arc::new(test_event()),
6481 ]));
6482 let evaluator = expr.load(&dataset).unwrap();
6483 let params = [0.5, -1.25];
6484 let indices = vec![0, 2, 3];
6485 let values = evaluator
6486 .evaluate_batch_local(¶ms, &indices)
6487 .expect("evaluation should succeed");
6488 let gradients = evaluator
6489 .evaluate_gradient_batch_local(¶ms, &indices)
6490 .expect("evaluation should succeed");
6491 let fused = evaluator
6492 .evaluate_with_gradient_batch_local(¶ms, &indices)
6493 .expect("evaluation should succeed");
6494 assert_eq!(fused.len(), values.len());
6495 assert_eq!(fused.len(), gradients.len());
6496 for ((value_gradient, value), gradient) in
6497 fused.iter().zip(values.iter()).zip(gradients.iter())
6498 {
6499 let (fused_value, fused_gradient) = value_gradient;
6500 assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
6501 assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
6502 assert_eq!(fused_gradient.len(), gradient.len());
6503 for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
6504 assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
6505 assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
6506 }
6507 }
6508 }
6509
6510 #[test]
6511 fn test_precompute_all_columnar_populates_cache() {
6512 let mut event1 = test_event();
6513 event1.p4s[0].t = 7.5;
6514 let mut event2 = test_event();
6515 event2.p4s[0].t = 8.25;
6516 let mut event3 = test_event();
6517 event3.p4s[0].t = 9.0;
6518 let dataset = Dataset::new_with_metadata(
6519 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
6520 Arc::new(DatasetMetadata::default()),
6521 );
6522 let mut amplitude = TestAmplitude {
6523 tags: Tags::new(["test"]),
6524 re: parameter!("real"),
6525 pid_re: ParameterID::default(),
6526 im: parameter!("imag"),
6527 pid_im: ParameterID::default(),
6528 beam_energy: Default::default(),
6529 };
6530 let mut resources = Resources::default();
6531 amplitude
6532 .register(&mut resources)
6533 .expect("test amplitude should register");
6534 resources.reserve_cache(dataset.n_events());
6535 amplitude.precompute_all(&dataset, &mut resources);
6536 for cache in &resources.caches {
6537 assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
6538 }
6539 }
6540
6541 #[cfg(feature = "mpi")]
6542 #[mpi_test(np = [2])]
6543 fn test_load_reserves_local_cache_size_in_mpi() {
6544 use crate::mpi::{finalize_mpi, get_world, use_mpi};
6545
6546 use_mpi(true);
6547 assert!(get_world().is_some(), "MPI world should be initialized");
6548
6549 let expr = ComplexScalar::new(
6550 "constant",
6551 parameter!("const_re", 2.0),
6552 parameter!("const_im", 3.0),
6553 )
6554 .expect("constant amplitude should construct");
6555 let events = vec![
6556 Arc::new(test_event()),
6557 Arc::new(test_event()),
6558 Arc::new(test_event()),
6559 Arc::new(test_event()),
6560 ];
6561 let dataset = Arc::new(Dataset::new_with_metadata(
6562 events,
6563 Arc::new(DatasetMetadata::default()),
6564 ));
6565 let evaluator = expr.load(&dataset).expect("evaluator should load");
6566 let local_events = dataset.n_events_local();
6567 let cache_len = evaluator.resources.read().caches.len();
6568
6569 assert_eq!(
6570 cache_len, local_events,
6571 "cache length must match local event count under MPI"
6572 );
6573 finalize_mpi();
6574 }
6575
6576 #[cfg(feature = "mpi")]
6577 #[mpi_test(np = [2])]
6578 fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
6579 use mpi::{collective::SystemOperation, topology::Communicator, traits::*};
6580
6581 use crate::mpi::{finalize_mpi, get_world, use_mpi};
6582
6583 use_mpi(true);
6584 let world = get_world().expect("MPI world should be initialized");
6585
6586 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6587 * &CacheOnlyScalar::new("k").unwrap();
6588 let events = vec![
6589 Arc::new(EventData {
6590 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
6591 aux: vec![],
6592 weight: 0.5,
6593 }),
6594 Arc::new(EventData {
6595 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6596 aux: vec![],
6597 weight: 1.0,
6598 }),
6599 Arc::new(EventData {
6600 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6601 aux: vec![],
6602 weight: 1.5,
6603 }),
6604 Arc::new(EventData {
6605 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
6606 aux: vec![],
6607 weight: 2.0,
6608 }),
6609 Arc::new(EventData {
6610 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6611 aux: vec![],
6612 weight: 2.5,
6613 }),
6614 Arc::new(EventData {
6615 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
6616 aux: vec![],
6617 weight: 3.0,
6618 }),
6619 Arc::new(EventData {
6620 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
6621 aux: vec![],
6622 weight: 3.5,
6623 }),
6624 Arc::new(EventData {
6625 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
6626 aux: vec![],
6627 weight: 4.0,
6628 }),
6629 ];
6630 let dataset = Arc::new(Dataset::new_with_metadata(
6631 events,
6632 Arc::new(DatasetMetadata::default()),
6633 ));
6634 let evaluator = expr.load(&dataset).expect("evaluator should load");
6635 let cached_integrals = evaluator
6636 .expression_precomputed_cached_integrals()
6637 .expect("integrals should exist");
6638 assert_eq!(cached_integrals.len(), 1);
6639
6640 let local_expected =
6641 dataset
6642 .weights_local()
6643 .iter()
6644 .enumerate()
6645 .fold(0.0, |acc, (index, weight)| {
6646 let event = dataset.event_local(index).expect("event should exist");
6647 acc + *weight * event.p4_at(0).e()
6648 });
6649 let cached_local = cached_integrals[0].weighted_cache_sum;
6650 assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
6651 assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);
6652
6653 let weighted_value_sum = evaluator
6654 .evaluate_weighted_value_sum_local(&[2.0])
6655 .expect("evaluate should succeed");
6656 assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);
6657
6658 let mut global_expected = 0.0;
6659 world.all_reduce_into(
6660 &local_expected,
6661 &mut global_expected,
6662 SystemOperation::sum(),
6663 );
6664 if world.size() > 1 {
6665 assert!(
6666 (cached_local.re - global_expected).abs() > 1e-12,
6667 "cached integral should remain rank-local before MPI reduction"
6668 );
6669 }
6670 finalize_mpi();
6671 }
6672
6673 #[cfg(feature = "mpi")]
6674 #[mpi_test(np = [2])]
6675 fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
6676 use mpi::{collective::SystemOperation, traits::*};
6677
6678 use crate::mpi::{finalize_mpi, get_world, use_mpi};
6679
6680 use_mpi(true);
6681 let world = get_world().expect("MPI world should be initialized");
6682
6683 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6684 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6685 let c1 = CacheOnlyScalar::new("c1").unwrap();
6686 let c2 = CacheOnlyScalar::new("c2").unwrap();
6687 let c3 = CacheOnlyScalar::new("c3").unwrap();
6688 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6689 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6690 let events = vec![
6691 Arc::new(EventData {
6692 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
6693 aux: vec![],
6694 weight: 0.5,
6695 }),
6696 Arc::new(EventData {
6697 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6698 aux: vec![],
6699 weight: -1.25,
6700 }),
6701 Arc::new(EventData {
6702 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6703 aux: vec![],
6704 weight: 0.75,
6705 }),
6706 Arc::new(EventData {
6707 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
6708 aux: vec![],
6709 weight: 1.5,
6710 }),
6711 Arc::new(EventData {
6712 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6713 aux: vec![],
6714 weight: 2.25,
6715 }),
6716 Arc::new(EventData {
6717 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
6718 aux: vec![],
6719 weight: -0.5,
6720 }),
6721 Arc::new(EventData {
6722 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
6723 aux: vec![],
6724 weight: 3.5,
6725 }),
6726 Arc::new(EventData {
6727 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
6728 aux: vec![],
6729 weight: 1.25,
6730 }),
6731 ];
6732 let dataset = Arc::new(Dataset::new_with_metadata(
6733 events,
6734 Arc::new(DatasetMetadata::default()),
6735 ));
6736 let evaluator = expr.load(&dataset).expect("evaluator should load");
6737 let params = vec![0.2, -0.3, 1.1, -0.7];
6738
6739 let local_expected_value = evaluator
6740 .evaluate_local(¶ms)
6741 .expect("evaluate should succeed")
6742 .iter()
6743 .zip(dataset.weights_local().iter())
6744 .fold(0.0, |accum, (value, event)| accum + *event * value.re);
6745 let mut global_expected_value = 0.0;
6746 world.all_reduce_into(
6747 &local_expected_value,
6748 &mut global_expected_value,
6749 SystemOperation::sum(),
6750 );
6751 let mpi_value = evaluator
6752 .evaluate_weighted_value_sum_mpi(¶ms, &world)
6753 .expect("evaluate should succeed");
6754 assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);
6755
6756 let local_expected_gradient = evaluator
6757 .evaluate_gradient_local(¶ms)
6758 .expect("evaluate should succeed")
6759 .iter()
6760 .zip(dataset.weights_local().iter())
6761 .fold(
6762 DVector::zeros(params.len()),
6763 |mut accum, (gradient, event)| {
6764 accum += gradient.map(|value| value.re).scale(*event);
6765 accum
6766 },
6767 );
6768 let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
6769 world.all_reduce_into(
6770 local_expected_gradient.as_slice(),
6771 &mut global_expected_gradient,
6772 SystemOperation::sum(),
6773 );
6774 let mpi_gradient = evaluator
6775 .evaluate_weighted_gradient_sum_mpi(¶ms, &world)
6776 .expect("evaluate should succeed");
6777 for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
6778 assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
6779 }
6780
6781 finalize_mpi();
6782 }
6783
6784 #[test]
6785 fn test_evaluate_local_succeeds_for_constant_amplitude() {
6786 let expr = ComplexScalar::new(
6787 "constant",
6788 parameter!("const_re", 2.0),
6789 parameter!("const_im", 3.0),
6790 )
6791 .unwrap();
6792 let dataset = Arc::new(Dataset::new_with_metadata(
6793 vec![Arc::new(test_event())],
6794 Arc::new(DatasetMetadata::default()),
6795 ));
6796 let evaluator = expr.load(&dataset).unwrap();
6797 let values = evaluator
6798 .evaluate_local(&[])
6799 .expect("evaluation should succeed");
6800 assert_eq!(values.len(), 1);
6801 let gradients = evaluator
6802 .evaluate_gradient_local(&[])
6803 .expect("evaluation should succeed");
6804 assert_eq!(gradients.len(), 1);
6805 }
6806
6807 #[test]
6808 fn test_constant_amplitude() {
6809 let expr = ComplexScalar::new(
6810 "constant",
6811 parameter!("const_re", 2.0),
6812 parameter!("const_im", 3.0),
6813 )
6814 .unwrap();
6815 let dataset = Arc::new(Dataset::new_with_metadata(
6816 vec![Arc::new(test_event())],
6817 Arc::new(DatasetMetadata::default()),
6818 ));
6819 let evaluator = expr.load(&dataset).unwrap();
6820 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
6821 assert_eq!(result[0], Complex64::new(2.0, 3.0));
6822 }
6823
6824 #[test]
6825 fn test_parametric_amplitude() {
6826 let expr = ComplexScalar::new(
6827 "parametric",
6828 parameter!("test_param_re"),
6829 parameter!("test_param_im"),
6830 )
6831 .unwrap();
6832 let dataset = Arc::new(test_dataset());
6833 let evaluator = expr.load(&dataset).unwrap();
6834 let result = evaluator
6835 .evaluate(&[2.0, 3.0])
6836 .expect("evaluation should succeed");
6837 assert_eq!(result[0], Complex64::new(2.0, 3.0));
6838 }
6839
6840 #[test]
6841 fn test_expression_operations() {
6842 let expr1 = ComplexScalar::new(
6843 "const1",
6844 parameter!("const1_re", 2.0),
6845 parameter!("const1_im", 0.0),
6846 )
6847 .unwrap();
6848 let expr2 = ComplexScalar::new(
6849 "const2",
6850 parameter!("const2_re", 0.0),
6851 parameter!("const2_im", 1.0),
6852 )
6853 .unwrap();
6854 let expr3 = ComplexScalar::new(
6855 "const3",
6856 parameter!("const3_re", 3.0),
6857 parameter!("const3_im", 4.0),
6858 )
6859 .unwrap();
6860
6861 let dataset = Arc::new(test_dataset());
6862
6863 let expr_add = &expr1 + &expr2;
6865 let result_add = expr_add
6866 .load(&dataset)
6867 .unwrap()
6868 .evaluate(&[])
6869 .expect("evaluation should succeed");
6870 assert_eq!(result_add[0], Complex64::new(2.0, 1.0));
6871
6872 let expr_sub = &expr1 - &expr2;
6874 let result_sub = expr_sub
6875 .load(&dataset)
6876 .unwrap()
6877 .evaluate(&[])
6878 .expect("evaluation should succeed");
6879 assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));
6880
6881 let expr_mul = &expr1 * &expr2;
6883 let result_mul = expr_mul
6884 .load(&dataset)
6885 .unwrap()
6886 .evaluate(&[])
6887 .expect("evaluation should succeed");
6888 assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));
6889
6890 let expr_div = &expr1 / &expr3;
6892 let result_div = expr_div
6893 .load(&dataset)
6894 .unwrap()
6895 .evaluate(&[])
6896 .expect("evaluation should succeed");
6897 assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));
6898
6899 let expr_neg = -&expr3;
6901 let result_neg = expr_neg
6902 .load(&dataset)
6903 .unwrap()
6904 .evaluate(&[])
6905 .expect("evaluation should succeed");
6906 assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));
6907
6908 let expr_add2 = &expr_add + &expr_mul;
6910 let result_add2 = expr_add2
6911 .load(&dataset)
6912 .unwrap()
6913 .evaluate(&[])
6914 .expect("evaluation should succeed");
6915 assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));
6916
6917 let expr_sub2 = &expr_add - &expr_mul;
6919 let result_sub2 = expr_sub2
6920 .load(&dataset)
6921 .unwrap()
6922 .evaluate(&[])
6923 .expect("evaluation should succeed");
6924 assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));
6925
6926 let expr_mul2 = &expr_add * &expr_mul;
6928 let result_mul2 = expr_mul2
6929 .load(&dataset)
6930 .unwrap()
6931 .evaluate(&[])
6932 .expect("evaluation should succeed");
6933 assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));
6934
6935 let expr_div2 = &expr_add / &expr_add2;
6937 let result_div2 = expr_div2
6938 .load(&dataset)
6939 .unwrap()
6940 .evaluate(&[])
6941 .expect("evaluation should succeed");
6942 assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));
6943
6944 let expr_neg2 = -&expr_mul2;
6946 let result_neg2 = expr_neg2
6947 .load(&dataset)
6948 .unwrap()
6949 .evaluate(&[])
6950 .expect("evaluation should succeed");
6951 assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));
6952
6953 let expr_real = expr3.real();
6955 let result_real = expr_real
6956 .load(&dataset)
6957 .unwrap()
6958 .evaluate(&[])
6959 .expect("evaluation should succeed");
6960 assert_eq!(result_real[0], Complex64::new(3.0, 0.0));
6961
6962 let expr_mul2_real = expr_mul2.real();
6964 let result_mul2_real = expr_mul2_real
6965 .load(&dataset)
6966 .unwrap()
6967 .evaluate(&[])
6968 .expect("evaluation should succeed");
6969 assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));
6970
6971 let expr_imag = expr3.imag();
6973 let result_imag = expr_imag
6974 .load(&dataset)
6975 .unwrap()
6976 .evaluate(&[])
6977 .expect("evaluation should succeed");
6978 assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));
6979
6980 let expr_mul2_imag = expr_mul2.imag();
6982 let result_mul2_imag = expr_mul2_imag
6983 .load(&dataset)
6984 .unwrap()
6985 .evaluate(&[])
6986 .expect("evaluation should succeed");
6987 assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));
6988
6989 let expr_conj = expr3.conj();
6991 let result_conj = expr_conj
6992 .load(&dataset)
6993 .unwrap()
6994 .evaluate(&[])
6995 .expect("evaluation should succeed");
6996 assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));
6997
6998 let expr_mul2_conj = expr_mul2.conj();
7000 let result_mul2_conj = expr_mul2_conj
7001 .load(&dataset)
7002 .unwrap()
7003 .evaluate(&[])
7004 .expect("evaluation should succeed");
7005 assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));
7006
7007 let expr_norm = expr1.norm_sqr();
7009 let result_norm = expr_norm
7010 .load(&dataset)
7011 .unwrap()
7012 .evaluate(&[])
7013 .expect("evaluation should succeed");
7014 assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));
7015
7016 let expr_mul2_norm = expr_mul2.norm_sqr();
7018 let result_mul2_norm = expr_mul2_norm
7019 .load(&dataset)
7020 .unwrap()
7021 .evaluate(&[])
7022 .expect("evaluation should succeed");
7023 assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
7024 }
7025
7026 #[test]
7027 fn test_amplitude_activation() {
7028 let expr1 = ComplexScalar::new(
7029 "const1",
7030 parameter!("const1_re_act", 1.0),
7031 parameter!("const1_im_act", 0.0),
7032 )
7033 .unwrap();
7034 let expr2 = ComplexScalar::new(
7035 "const2",
7036 parameter!("const2_re_act", 2.0),
7037 parameter!("const2_im_act", 0.0),
7038 )
7039 .unwrap();
7040
7041 let dataset = Arc::new(test_dataset());
7042 let expr = &expr1 + &expr2;
7043 let evaluator = expr.load(&dataset).unwrap();
7044
7045 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7047 assert_eq!(result[0], Complex64::new(3.0, 0.0));
7048
7049 evaluator.deactivate_strict("const1").unwrap();
7051 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7052 assert_eq!(result[0], Complex64::new(2.0, 0.0));
7053
7054 evaluator.isolate_strict("const1").unwrap();
7056 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7057 assert_eq!(result[0], Complex64::new(1.0, 0.0));
7058
7059 evaluator.activate_all();
7061 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7062 assert_eq!(result[0], Complex64::new(3.0, 0.0));
7063 }
7064
/// Checks the analytic gradients of the arithmetic operators and the complex
/// projections against hand-derived values.
///
/// The parameters `[2.0, 3.0, 4.0, 5.0]` are laid out as
/// `[re1, im1, re2, im2]`, so `z1 = 2 + 3i` and `z2 = 4 + 5i`.
#[test]
fn test_gradient() {
    let expr1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();

    let dataset = Arc::new(test_dataset());
    let params = vec![2.0, 3.0, 4.0, 5.0];

    // z1 + z2: d/dre = 1, d/dim = i for each operand.
    let expr = &expr1 + &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 1.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 0.0);
    assert_relative_eq!(gradient[0][1].im, 1.0);
    assert_relative_eq!(gradient[0][2].re, 1.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 0.0);
    assert_relative_eq!(gradient[0][3].im, 1.0);

    // z1 - z2: the second operand's derivatives flip sign.
    let expr = &expr1 - &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 1.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 0.0);
    assert_relative_eq!(gradient[0][1].im, 1.0);
    assert_relative_eq!(gradient[0][2].re, -1.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 0.0);
    assert_relative_eq!(gradient[0][3].im, -1.0);

    // z1 * z2: d/dre1 = z2, d/dim1 = i*z2, d/dre2 = z1, d/dim2 = i*z1.
    let expr = &expr1 * &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 5.0);
    assert_relative_eq!(gradient[0][1].re, -5.0);
    assert_relative_eq!(gradient[0][1].im, 4.0);
    assert_relative_eq!(gradient[0][2].re, 2.0);
    assert_relative_eq!(gradient[0][2].im, 3.0);
    assert_relative_eq!(gradient[0][3].re, -3.0);
    assert_relative_eq!(gradient[0][3].im, 2.0);

    // z1 / z2: d/dre1 = 1/z2, d/dim1 = i/z2, d/dre2 = -z1/z2^2,
    // d/dim2 = -i*z1/z2^2, with |z2|^2 = 41 and |z2^2|^2 = 1681.
    let expr = &expr1 / &expr2;
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
    assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
    assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
    assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
    assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
    assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
    assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
    assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);

    // -(z1 * z2): negation of the product gradient.
    let expr = -(&expr1 * &expr2);
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, -4.0);
    assert_relative_eq!(gradient[0][0].im, -5.0);
    assert_relative_eq!(gradient[0][1].re, 5.0);
    assert_relative_eq!(gradient[0][1].im, -4.0);
    assert_relative_eq!(gradient[0][2].re, -2.0);
    assert_relative_eq!(gradient[0][2].im, -3.0);
    assert_relative_eq!(gradient[0][3].re, 3.0);
    assert_relative_eq!(gradient[0][3].im, -2.0);

    // Re(z1 * z2): real parts of the product gradient, zero imaginary part.
    let expr = (&expr1 * &expr2).real();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, -5.0);
    assert_relative_eq!(gradient[0][1].im, 0.0);
    assert_relative_eq!(gradient[0][2].re, 2.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, -3.0);
    assert_relative_eq!(gradient[0][3].im, 0.0);

    // Im(z1 * z2): imaginary parts of the product gradient, reported as real values.
    let expr = (&expr1 * &expr2).imag();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 5.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 4.0);
    assert_relative_eq!(gradient[0][1].im, 0.0);
    assert_relative_eq!(gradient[0][2].re, 3.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 2.0);
    assert_relative_eq!(gradient[0][3].im, 0.0);

    // conj(z1 * z2): complex conjugate of the product gradient.
    let expr = (&expr1 * &expr2).conj();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, -5.0);
    assert_relative_eq!(gradient[0][1].re, -5.0);
    assert_relative_eq!(gradient[0][1].im, -4.0);
    assert_relative_eq!(gradient[0][2].re, 2.0);
    assert_relative_eq!(gradient[0][2].im, -3.0);
    assert_relative_eq!(gradient[0][3].re, -3.0);
    assert_relative_eq!(gradient[0][3].im, -2.0);

    // |z1 * z2|^2 = |z1|^2 * |z2|^2 = 13 * 41, so d/dre1 = 2*re1*41 = 164,
    // d/dim1 = 2*im1*41 = 246, d/dre2 = 2*re2*13 = 104, d/dim2 = 2*im2*13 = 130.
    let expr = (&expr1 * &expr2).norm_sqr();
    let evaluator = expr.load(&dataset).unwrap();

    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    assert_relative_eq!(gradient[0][0].re, 164.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
    assert_relative_eq!(gradient[0][1].re, 246.0);
    assert_relative_eq!(gradient[0][1].im, 0.0);
    assert_relative_eq!(gradient[0][2].re, 104.0);
    assert_relative_eq!(gradient[0][2].im, 0.0);
    assert_relative_eq!(gradient[0][3].re, 130.0);
    assert_relative_eq!(gradient[0][3].im, 0.0);
}
7227
/// Validates the analytic gradient of a sum of elementary functions against a
/// central finite difference of the evaluated value, one parameter at a time.
#[test]
fn test_expression_function_gradients() {
    let expr1 = ComplexScalar::new(
        "function_parametric_1",
        parameter!("function_test_param_re_1"),
        parameter!("function_test_param_im_1"),
    )
    .unwrap();
    let expr2 = ComplexScalar::new(
        "function_parametric_2",
        parameter!("function_test_param_re_2"),
        parameter!("function_test_param_im_2"),
    )
    .unwrap();

    // Fold several functions of expr1 (plus a complex power z1^z2) into one sum,
    // so that a single finite-difference sweep exercises all of their gradients.
    let sin = expr1.sin();
    let cos = expr1.cos();
    let trig = &sin * &cos;
    let pow = expr1.pow(&expr2);
    let mut expr = expr1.sqrt();
    expr = &expr + &expr1.exp();
    expr = &expr + &expr1.powi(2);
    expr = &expr + &expr1.powf(1.7);
    expr = &expr + &trig;
    expr = &expr + &expr1.log();
    expr = &expr + &expr1.cis();
    expr = &expr + &pow;

    let dataset = Arc::new(test_dataset());
    let evaluator = expr.load(&dataset).unwrap();
    // NOTE(review): assuming params map to [re1, im1, re2, im2], expr1 = 2 + 0.5i
    // lies away from the negative real axis (the principal-branch cut of
    // sqrt/log), so the finite differences below are smooth.
    let params = vec![2.0, 0.5, 1.2, -0.3];
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");
    // Central-difference step: truncation error is O(eps^2), well below the
    // 1e-6 tolerances used in the assertions.
    let eps = 1e-6;

    for param_index in 0..params.len() {
        let mut plus = params.clone();
        plus[param_index] += eps;
        let mut minus = params.clone();
        minus[param_index] -= eps;
        let finite_diff = (evaluator
            .evaluate(&plus)
            .expect("evaluation should succeed")[0]
            - evaluator
                .evaluate(&minus)
                .expect("evaluation should succeed")[0])
            / Complex64::new(2.0 * eps, 0.0);

        assert_relative_eq!(
            gradient[0][param_index].re,
            finite_diff.re,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
        assert_relative_eq!(
            gradient[0][param_index].im,
            finite_diff.im,
            epsilon = 1e-6,
            max_relative = 1e-6
        );
    }
}
7291
/// Checks that multiplying by `Expression::one()` and adding `Expression::zero()`
/// does not alter the value or the gradient of `|amp|^2`.
#[test]
fn test_zeros_and_ones() {
    // amp = re + 2i: the real part is free, the imaginary part is fixed at 2.
    let amp = ComplexScalar::new(
        "parametric",
        parameter!("test_param_re"),
        parameter!("fixed_two", 2.0),
    )
    .unwrap();
    let dataset = Arc::new(test_dataset());
    let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
    let evaluator = expr.load(&dataset).unwrap();

    // Only the free real part takes a slot in the parameter vector.
    let params = vec![2.0];
    let value = evaluator
        .evaluate(&params)
        .expect("evaluation should succeed");
    let gradient = evaluator
        .evaluate_gradient(&params)
        .expect("evaluation should succeed");

    // |2 + 2i|^2 = 8.
    assert_relative_eq!(value[0].re, 8.0);
    assert_relative_eq!(value[0].im, 0.0);

    // d(re^2 + 4)/d(re) = 2 * re = 4.
    assert_relative_eq!(gradient[0][0].re, 4.0);
    assert_relative_eq!(gradient[0][0].im, 0.0);
}
/// A default build should plan through the IR and carry all three lowered
/// programs (value, gradient, value+gradient), while still evaluating correctly.
#[test]
fn test_default_build_uses_lowered_expression_runtime() {
    // Both components are fixed, so evaluation needs an empty parameter slice.
    let gate_re = parameter!("opt_in_gate_re", 2.0);
    let gate_im = parameter!("opt_in_gate_im", 0.0);
    let gate = ComplexScalar::new("opt_in_gate", gate_re, gate_im).unwrap();
    let model = gate.norm_sqr();

    let dataset = Arc::new(test_dataset());
    let evaluator = model.load(&dataset).unwrap();

    let runtime = evaluator.expression_runtime_diagnostics();
    assert!(runtime.ir_planning_enabled);
    assert!(runtime.lowered_value_program_present);
    assert!(runtime.lowered_gradient_program_present);
    assert!(runtime.lowered_value_gradient_program_present);

    // |2 + 0i|^2 = 4.
    let values = evaluator.evaluate(&[]).expect("evaluation should succeed");
    assert_eq!(values[0], Complex64::new(4.0, 0.0));
}
7342
/// `parameter!(name)` alone yields a free parameter with every optional
/// attribute left unset.
#[test]
fn parameter_name_only_creates_free_parameter() {
    let param = parameter!("mass");

    assert_eq!(param.name(), "mass");
    // No value and no fit metadata were supplied.
    assert!(param.fixed().is_none());
    assert!(param.initial().is_none());
    assert_eq!(param.bounds(), (None, None));
    assert!(param.unit().is_none());
    assert!(param.latex().is_none());
    assert!(param.description().is_none());
    // Without a value the parameter is free rather than fixed.
    assert!(param.is_free() && !param.is_fixed());
}
7357
/// A positional value after the name pins the parameter and also serves as
/// its initial value.
#[test]
fn parameter_name_and_value_creates_fixed_parameter() {
    let param = parameter!("width", 0.15);
    let pinned = Some(0.15);

    assert_eq!(param.name(), "width");
    assert_eq!(param.fixed(), pinned);
    assert_eq!(param.initial(), pinned);
    assert!(param.is_fixed() && !param.is_free());
}
7368
/// `initial:` seeds a starting value without fixing the parameter.
#[test]
fn keyword_initial_sets_initial_only() {
    let param = parameter!("alpha", initial: 1.25);

    assert_eq!(param.name(), "alpha");
    assert!(param.fixed().is_none());
    assert_eq!(param.initial(), Some(1.25));
    let (lower, upper) = param.bounds();
    assert!(lower.is_none() && upper.is_none());
    assert!(param.is_free());
}
7379
/// `fixed:` pins the parameter and mirrors the value into `initial`.
#[test]
fn keyword_fixed_sets_fixed_and_initial() {
    let param = parameter!("beta", fixed: 2.5);
    let expected = Some(2.5);

    assert_eq!(param.name(), "beta");
    assert_eq!(param.fixed(), expected);
    assert_eq!(param.initial(), expected);
    assert!(param.is_fixed());
}
7389
/// Bare numbers on both sides of `bounds:` become `Some` limits.
#[test]
fn bounds_accept_plain_numbers() {
    let param = parameter!("x", bounds: (0.0, 10.0));

    let (lower, upper) = param.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));
}
7396
/// `None` as the lower bound leaves the parameter unbounded below.
#[test]
fn bounds_accept_none_and_number() {
    let param = parameter!("x", bounds: (None, 10.0));

    let (lower, upper) = param.bounds();
    assert!(lower.is_none());
    assert_eq!(upper, Some(10.0));
}
7403
/// `None` as the upper bound leaves the parameter unbounded above.
#[test]
fn bounds_accept_number_and_none() {
    let param = parameter!("x", bounds: (-1.0, None));

    let (lower, upper) = param.bounds();
    assert_eq!(lower, Some(-1.0));
    assert!(upper.is_none());
}
7410
/// Explicit `None` on both sides leaves the parameter unbounded.
#[test]
fn bounds_accept_both_none() {
    let param = parameter!("x", bounds: (None, None));

    let (lower, upper) = param.bounds();
    assert!(lower.is_none());
    assert!(upper.is_none());
}
7417
/// Bound entries may be arbitrary expressions, evaluated at the call site.
#[test]
fn bounds_accept_arbitrary_expressions() {
    let low = 1.0;
    let high = 2.0 * 3.0;
    let param = parameter!("x", bounds: (low - 0.5, high));

    let (lower, upper) = param.bounds();
    assert_eq!(lower, Some(0.5));
    assert_eq!(upper, Some(6.0));
}
7426
/// All keyword arguments can be supplied together and each lands in its own field.
#[test]
fn multiple_keyword_arguments_work_together() {
    let param = parameter!(
        "gamma",
        initial: 1.0,
        bounds: (0.0, 5.0),
        unit: "GeV",
        latex: r"\gamma",
        description: "test parameter",
    );

    // Numeric configuration: free, seeded at 1.0, limited to [0, 5].
    assert_eq!(param.name(), "gamma");
    assert!(param.fixed().is_none());
    assert_eq!(param.initial(), Some(1.0));
    let (lower, upper) = param.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(5.0));

    // Descriptive metadata.
    assert_eq!(param.unit().as_deref(), Some("GeV"));
    assert_eq!(param.latex().as_deref(), Some(r"\gamma"));
    assert_eq!(param.description().as_deref(), Some("test parameter"));
}
7446
/// `fixed:` coexists with other keyword arguments such as `bounds:` and `unit:`.
#[test]
fn fixed_can_be_combined_with_other_fields() {
    let param = parameter!(
        "delta",
        fixed: 3.0,
        bounds: (0.0, 10.0),
        unit: "rad",
    );
    let pinned = Some(3.0);

    assert_eq!(param.name(), "delta");
    assert_eq!(param.fixed(), pinned);
    assert_eq!(param.initial(), pinned);
    let (lower, upper) = param.bounds();
    assert_eq!(lower, Some(0.0));
    assert_eq!(upper, Some(10.0));
    assert_eq!(param.unit().as_deref(), Some("rad"));
}
7462
/// A trailing comma after the last keyword argument must still parse.
#[test]
fn trailing_comma_is_accepted() {
    let param = parameter!(
        "eps",
        initial: 0.5,
        bounds: (None, 1.0),
        unit: "arb",
    );

    let (lower, upper) = param.bounds();
    assert_eq!(param.initial(), Some(0.5));
    assert!(lower.is_none());
    assert_eq!(upper, Some(1.0));
    assert_eq!(param.unit().as_deref(), Some("arb"));
}
7476
/// Only the unfixed component of an amplitude is reported among its free parameters.
#[test]
fn test_parameter_registration() {
    let free_re = parameter!("test_param_re");
    let fixed_im = parameter!("fixed_two", 2.0);
    let amplitude = ComplexScalar::new("parametric", free_re, fixed_im).unwrap();

    let free_names = amplitude.parameters().free().names();
    assert_eq!(free_names.len(), 1);
    assert_eq!(free_names[0], "test_param_re");
}
7489
/// Two amplitudes sharing the tag "same_name" can be combined. Their fixed
/// parameters are listed with the first operand's parameters before the second's.
#[test]
fn test_duplicate_amplitude_tag_registration_is_allowed() {
    let first = ComplexScalar::new(
        "same_name",
        parameter!("dup_re1", 1.0),
        parameter!("dup_im1", 0.0),
    )
    .unwrap();
    let second = ComplexScalar::new(
        "same_name",
        parameter!("dup_re2", 2.0),
        parameter!("dup_im2", 0.0),
    )
    .unwrap();

    let combined = first + second;
    let fixed_names = combined.parameters().fixed().names();
    assert_eq!(fixed_names, vec!["dup_re1", "dup_im1", "dup_re2", "dup_im2"]);
}
7510
/// The `Display` impl renders an expression as a box-drawing tree with one node
/// per line. Amplitudes appear as `name(id=N)` and exact constants carry an
/// `(exact)` suffix.
#[test]
fn test_tree_printing() {
    let amp1 = ComplexScalar::new(
        "parametric_1",
        parameter!("test_param_re_1"),
        parameter!("test_param_im_1"),
    )
    .unwrap();
    let amp2 = ComplexScalar::new(
        "parametric_2",
        parameter!("test_param_re_2"),
        parameter!("test_param_im_2"),
    )
    .unwrap();
    // Left-associative: ((((Re a1 + Im conj a2) + 1 * c) - 0 / 1) + |a1 a2|^2).
    let expr =
        &amp1.real() + &amp2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
            - Expression::zero() / 1.0
            + (&amp1 * &amp2).norm_sqr();
    // NOTE(review): each nesting level is assumed to indent by three columns
    // ("│  " or "   "), matching the width of the "├─ " / "└─ " connectors.
    // Confirm against the tree renderer if this comparison fails.
    assert_eq!(
        expr.to_string(),
        concat!(
            "+\n",
            "├─ -\n",
            "│  ├─ +\n",
            "│  │  ├─ +\n",
            "│  │  │  ├─ Re\n",
            "│  │  │  │  └─ parametric_1(id=0)\n",
            "│  │  │  └─ Im\n",
            "│  │  │     └─ *\n",
            "│  │  │        └─ parametric_2(id=1)\n",
            "│  │  └─ ×\n",
            "│  │     ├─ 1 (exact)\n",
            "│  │     └─ -1.4+2i\n",
            "│  └─ ÷\n",
            "│     ├─ 0 (exact)\n",
            "│     └─ 1 (exact)\n",
            "└─ NormSqr\n",
            "   └─ ×\n",
            "      ├─ parametric_1(id=0)\n",
            "      └─ parametric_2(id=1)\n",
        )
    );
}
7554}