1use std::collections::HashMap;
2use std::hash::Hash;
3use std::ops::Index;
4use std::time::Instant;
5use std::{
6 fmt::{Debug, Display},
7 sync::{
8 atomic::{AtomicU64, Ordering},
9 Arc,
10 },
11};
12
13use auto_ops::*;
14use dyn_clone::DynClone;
15use nalgebra::DVector;
16use num::complex::Complex64;
17
18use parking_lot::{Mutex, RwLock};
19#[cfg(feature = "rayon")]
20use rayon::prelude::*;
21use serde::{Deserialize, Serialize};
22
// Process-wide monotonically increasing counter used to tag each amplitude
// instance with a unique identity (see `ExpressionRegistry::amplitude_ids`).
static AMPLITUDE_INSTANCE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns the next unique amplitude instance id.
///
/// `Ordering::Relaxed` is sufficient here: only uniqueness of the returned
/// values matters, not any ordering relative to other memory operations.
fn next_amplitude_id() -> u64 {
    AMPLITUDE_INSTANCE_COUNTER.fetch_add(1, Ordering::Relaxed)
}
28#[allow(dead_code)]
29mod ir;
30#[allow(dead_code)]
31mod lowered;
32
/// Classifies what an expression (or sub-expression) depends on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionDependence {
    /// Depends only on fit parameters (constant over events).
    ParameterOnly,
    /// Depends only on per-event cached values (constant over parameters).
    CacheOnly,
    /// Depends on both parameters and cached values.
    Mixed,
}
43impl From<ir::DependenceClass> for ExpressionDependence {
44 fn from(value: ir::DependenceClass) -> Self {
45 match value {
46 ir::DependenceClass::ParameterOnly => Self::ParameterOnly,
47 ir::DependenceClass::CacheOnly => Self::CacheOnly,
48 ir::DependenceClass::Mixed => Self::Mixed,
49 }
50 }
51}
/// Human-inspectable summary of how the normalization plan classified the
/// expression tree (mirrors `ir::NormalizationPlanExplain` with node indices).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationPlanExplain {
    /// Dependence classification of the root node.
    pub root_dependence: ExpressionDependence,
    /// Diagnostic warnings produced while planning.
    pub warnings: Vec<String>,
    /// Node indices of `Mul` nodes that were candidates for separable caching.
    pub separable_mul_candidate_nodes: Vec<usize>,
    /// Node indices of separable nodes that were actually cached.
    pub cached_separable_nodes: Vec<usize>,
    /// Node indices of terms left to be evaluated per-event (not cached).
    pub residual_terms: Vec<usize>,
}
/// Human-inspectable view of the amplitude execution sets chosen by the
/// normalization plan (mirrors `ir::NormalizationExecutionSets`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NormalizationExecutionSetsExplain {
    /// Amplitudes evaluated once per parameter set (parameter-only factors).
    pub cached_parameter_amplitudes: Vec<usize>,
    /// Amplitudes evaluated once per event and cached (cache-only factors).
    pub cached_cache_amplitudes: Vec<usize>,
    /// Amplitudes that must be evaluated per event and per parameter set.
    pub residual_amplitudes: Vec<usize>,
}
/// A cached separable-integral term: the event-weighted sum over the
/// cache-only factor of a separable `Mul` node, reusable across parameter sets.
#[derive(Clone, Debug, PartialEq)]
pub struct PrecomputedCachedIntegral {
    /// Index of the separable `Mul` node in the expression IR.
    pub mul_node_index: usize,
    /// Index of the parameter-only factor node.
    pub parameter_node_index: usize,
    /// Index of the cache-only factor node.
    pub cache_node_index: usize,
    /// Signed multiplicity of this term in the expanded sum.
    pub coefficient: i32,
    /// Weighted sum of the cache-only factor over all (local) events.
    pub weighted_cache_sum: Complex64,
}
/// Gradient counterpart of [`PrecomputedCachedIntegral`]: the event-weighted
/// gradient of the cache-only factor for one separable `Mul` node.
#[derive(Clone, Debug, PartialEq)]
pub struct PrecomputedCachedIntegralGradientTerm {
    /// Index of the separable `Mul` node in the expression IR.
    pub mul_node_index: usize,
    /// Index of the parameter-only factor node.
    pub parameter_node_index: usize,
    /// Index of the cache-only factor node.
    pub cache_node_index: usize,
    /// Signed multiplicity of this term in the expanded sum.
    pub coefficient: i32,
    /// Weighted gradient (one entry per free parameter) of the cache factor.
    pub weighted_gradient: DVector<Complex64>,
}
/// Key identifying the dataset/activation state under which cached integrals
/// were computed; a mismatch forces recomputation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CachedIntegralCacheKey {
    // Which amplitudes were active when the integrals were computed.
    active_mask: Vec<bool>,
    // Number of locally held events.
    n_events_local: usize,
    // Length of the local event buffer.
    events_local_len: usize,
    // Bit pattern of the weight sum, used as a cheap content fingerprint.
    weighted_sum_bits: u64,
    // Address of the event storage; detects dataset replacement.
    // NOTE(review): pointer identity is only a heuristic — a reallocated
    // dataset at the same address would be indistinguishable; the other
    // fields guard against that.
    events_ptr: usize,
}
/// Cached separable-integral results together with the key and IR snapshot
/// they were computed from.
#[derive(Clone, Debug)]
struct CachedIntegralCacheState {
    // Key describing the dataset/activation state at computation time.
    key: CachedIntegralCacheKey,
    // The expression IR these integrals correspond to.
    expression_ir: ir::ExpressionIR,
    // The precomputed integral terms.
    values: Vec<PrecomputedCachedIntegral>,
    // Execution sets chosen by the normalization plan for this IR.
    execution_sets: ir::NormalizationExecutionSets,
}
/// Cached lowering artifacts so specialization does not re-lower unchanged
/// expressions.
#[derive(Clone, Debug)]
struct LoweredArtifactCacheState {
    // IR node indices of parameter-only factors.
    parameter_node_indices: Vec<usize>,
    // IR node indices of the separable `Mul` nodes.
    mul_node_indices: Vec<usize>,
    // Lowered runtimes for each parameter-only factor (None if not lowered).
    lowered_parameter_factors: Vec<Option<lowered::LoweredFactorRuntime>>,
    // Lowered runtime for the residual (non-separable) part, if any.
    residual_runtime: Option<lowered::LoweredExpressionRuntime>,
    // Lowered runtime for the full expression.
    lowered_runtime: lowered::LoweredExpressionRuntime,
}
/// Immutable snapshot of a fully specialized expression: cached integrals
/// plus lowering artifacts. `Arc`s make cloning the snapshot cheap.
#[derive(Clone)]
struct ExpressionSpecializationState {
    cached_integrals: Arc<CachedIntegralCacheState>,
    lowered_artifacts: Arc<LoweredArtifactCacheState>,
}
/// Hit/miss counters for the specialization cache.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionSpecializationMetrics {
    /// Number of times a specialization was served from cache.
    pub cache_hits: usize,
    /// Number of times a specialization had to be recomputed.
    pub cache_misses: usize,
}
/// Timing and cache counters collected while compiling and specializing an
/// expression. All durations are in nanoseconds.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionCompileMetrics {
    /// Time spent compiling the initial IR.
    pub initial_ir_compile_nanos: u64,
    /// Time spent computing cached integrals during initial compilation.
    pub initial_cached_integrals_nanos: u64,
    /// Time spent lowering during initial compilation.
    pub initial_lowering_nanos: u64,
    /// Specialization-cache hits.
    pub specialization_cache_hits: usize,
    /// Specialization-cache misses.
    pub specialization_cache_misses: usize,
    /// Time spent recompiling IR during specialization.
    pub specialization_ir_compile_nanos: u64,
    /// Time spent recomputing cached integrals during specialization.
    pub specialization_cached_integrals_nanos: u64,
    /// Time spent re-lowering during specialization.
    pub specialization_lowering_nanos: u64,
    /// Lowering-artifact cache hits during specialization.
    pub specialization_lowering_cache_hits: usize,
    /// Lowering-artifact cache misses during specialization.
    pub specialization_lowering_cache_misses: usize,
    /// Time spent restoring state from the specialization cache.
    pub specialization_cache_restore_nanos: u64,
}
166impl From<ir::NormalizationPlanExplain> for NormalizationPlanExplain {
167 fn from(value: ir::NormalizationPlanExplain) -> Self {
168 Self {
169 root_dependence: value.root_dependence.into(),
170 warnings: value.warnings,
171 separable_mul_candidate_nodes: value
172 .separable_mul_candidates
173 .into_iter()
174 .map(|candidate| candidate.node_index)
175 .collect(),
176 cached_separable_nodes: value.cached_separable_nodes,
177 residual_terms: value.residual_terms,
178 }
179 }
180}
/// Converts the internal execution-set record into its public counterpart;
/// this is a direct field-by-field move.
impl From<ir::NormalizationExecutionSets> for NormalizationExecutionSetsExplain {
    fn from(value: ir::NormalizationExecutionSets) -> Self {
        Self {
            cached_parameter_amplitudes: value.cached_parameter_amplitudes,
            cached_cache_amplitudes: value.cached_cache_amplitudes,
            residual_amplitudes: value.residual_amplitudes,
        }
    }
}
190impl From<ExpressionDependence> for ir::DependenceClass {
191 fn from(value: ExpressionDependence) -> Self {
192 match value {
193 ExpressionDependence::ParameterOnly => Self::ParameterOnly,
194 ExpressionDependence::CacheOnly => Self::CacheOnly,
195 ExpressionDependence::Mixed => Self::Mixed,
196 }
197 }
198}
199
200#[cfg(feature = "execution-context-prototype")]
201use crate::ExecutionContext;
202#[cfg(all(feature = "execution-context-prototype", feature = "rayon"))]
203use crate::ThreadPolicy;
204use crate::{
205 data::{Dataset, DatasetMetadata, NamedEventView},
206 resources::{Cache, Parameters, Resources},
207 LadduError, LadduResult, ParameterID, ReadWrite,
208};
209
210#[cfg(feature = "mpi")]
211use crate::mpi::LadduMPI;
212
213#[cfg(feature = "mpi")]
214use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
215
/// Interior data of a [`Parameter`] handle. All fields are accessed through
/// the handle's `Mutex`, so concurrent readers/writers stay consistent.
#[derive(Clone, Default, Serialize, Deserialize, Debug)]
struct ParameterMetadata {
    // Unique name; equality and hashing of `Parameter` use only this field.
    name: String,
    // `Some(v)` when the parameter is fixed at value `v`; `None` when free.
    fixed: Option<f64>,
    // Starting value for fitting; setting a fixed value also sets this.
    initial: Option<f64>,
    // (lower, upper) bounds; either side may be unbounded (`None`).
    bounds: (Option<f64>, Option<f64>),
    // Optional physical unit label for display.
    unit: Option<String>,
    // Optional LaTeX representation for plotting/reports.
    latex: Option<String>,
    // Optional free-form description.
    description: Option<String>,
}
235
236#[derive(Clone, Default, Serialize, Deserialize, Debug)]
238pub struct Parameter(Arc<Mutex<ParameterMetadata>>);
239
240impl PartialEq for Parameter {
242 fn eq(&self, other: &Self) -> bool {
243 self.0.lock().name == other.0.lock().name
244 }
245}
246impl Eq for Parameter {}
/// Hashes only the name, matching the `PartialEq` implementation so that
/// `a == b` implies `hash(a) == hash(b)`.
impl Hash for Parameter {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.lock().name.hash(state);
    }
}
252
/// Conversion helper so bound setters accept either a bare `f64` (a finite
/// bound) or an `Option<f64>` (where `None` means unbounded).
pub trait IntoBound {
    /// Converts the value into an optional bound.
    fn into_bound(self) -> Option<f64>;
}

// A bare float is a finite bound.
impl IntoBound for f64 {
    fn into_bound(self) -> Option<f64> {
        Some(self)
    }
}

// An `Option` passes through unchanged (`None` = unbounded).
impl IntoBound for Option<f64> {
    fn into_bound(self) -> Option<f64> {
        self
    }
}
268
impl Parameter {
    /// Creates a new free parameter with the given name and default metadata.
    pub fn new(name: impl Into<String>) -> Self {
        Self(Arc::new(Mutex::new(ParameterMetadata {
            name: name.into(),
            ..Default::default()
        })))
    }

    /// Creates a new parameter fixed at `value`.
    pub fn new_fixed(name: impl Into<String>, value: f64) -> Self {
        Self(Arc::new(Mutex::new(ParameterMetadata {
            name: name.into(),
            fixed: Some(value),
            ..Default::default()
        })))
    }

    /// Returns a clone of the parameter's name.
    pub fn name(&self) -> String {
        self.0.lock().name.clone()
    }

    /// Returns the fixed value, or `None` if the parameter is free.
    pub fn fixed(&self) -> Option<f64> {
        self.0.lock().fixed
    }

    /// Returns the initial (starting) value, if one has been set.
    pub fn initial(&self) -> Option<f64> {
        self.0.lock().initial
    }

    /// Returns the (lower, upper) bounds; `None` on a side means unbounded.
    pub fn bounds(&self) -> (Option<f64>, Option<f64>) {
        self.0.lock().bounds
    }

    /// Returns the unit label, if set.
    pub fn unit(&self) -> Option<String> {
        self.0.lock().unit.clone()
    }

    /// Returns the LaTeX representation, if set.
    pub fn latex(&self) -> Option<String> {
        self.0.lock().latex.clone()
    }

    /// Returns the free-form description, if set.
    pub fn description(&self) -> Option<String> {
        self.0.lock().description.clone()
    }

    // Renames in place; private because `ParameterMap` must keep its
    // name-to-index map in sync (see `ParameterMap::rename_parameter`).
    fn set_name(&self, name: impl Into<String>) {
        self.0.lock().name = name.into();
    }

    /// Fixes the parameter at `Some(value)` (also updating `initial`), or
    /// frees it with `None` (leaving the previous `initial` untouched).
    pub fn set_fixed_value(&self, value: Option<f64>) {
        {
            let mut guard = self.0.lock();
            if let Some(value) = value {
                guard.fixed = Some(value);
                // Fixing also pins the starting value so fits begin there.
                guard.initial = Some(value);
            } else {
                guard.fixed = None;
            }
        }
    }

    /// Sets the starting value for a free parameter.
    ///
    /// # Panics
    ///
    /// Panics if the parameter is fixed (fixed parameters derive their
    /// `initial` from the fixed value).
    pub fn set_initial(&self, value: f64) {
        assert!(
            self.is_free(),
            "cannot manually set `initial` on a fixed parameter"
        );
        self.0.lock().initial = Some(value);
    }

    /// Sets the bounds; each side accepts `f64` or `Option<f64>` via
    /// [`IntoBound`], so `set_bounds(0.0, None)` means "lower bound only".
    pub fn set_bounds<L, U>(&self, min: L, max: U)
    where
        L: IntoBound,
        U: IntoBound,
    {
        self.0.lock().bounds = (IntoBound::into_bound(min), IntoBound::into_bound(max));
    }

    /// Sets the unit label.
    pub fn set_unit(&self, unit: impl Into<String>) {
        self.0.lock().unit = Some(unit.into());
    }

    /// Sets the LaTeX representation.
    pub fn set_latex(&self, latex: impl Into<String>) {
        self.0.lock().latex = Some(latex.into());
    }

    /// Sets the free-form description.
    pub fn set_description(&self, description: impl Into<String>) {
        self.0.lock().description = Some(description.into());
    }

    /// Returns `true` if the parameter is free (no fixed value).
    pub fn is_free(&self) -> bool {
        self.0.lock().fixed.is_none()
    }

    /// Returns `true` if the parameter is fixed.
    pub fn is_fixed(&self) -> bool {
        self.0.lock().fixed.is_some()
    }
}
388
/// Constructs a [`Parameter`](crate::amplitudes::Parameter) with optional
/// metadata.
///
/// Forms:
/// - `parameter!("name")` — a free parameter.
/// - `parameter!("name", 1.0)` — fixed at the given value.
/// - `parameter!("name", key: value, ...)` — keyword form; keys are `fixed`,
///   `initial`, `bounds: (min, max)`, `unit`, `latex`, `description`.
///
/// `fixed` and `initial` are mutually exclusive; combining them is a
/// compile-time error. The `@parse` rules below thread a
/// `[fixed = .., initial = ..]` state through the recursion to detect this.
#[macro_export]
macro_rules! parameter {
    // Bare name: a free parameter.
    ($name:expr) => {{
        $crate::amplitudes::Parameter::new($name)
    }};

    // Name plus a single value: fixed parameter.
    ($name:expr, $value:expr) => {{
        let p = $crate::amplitudes::Parameter::new($name);
        p.set_fixed_value(Some($value));
        p
    }};

    // Name plus keyword list: delegate to the `@parse` state machine.
    ($name:expr, $($rest:tt)+) => {{
        let p = $crate::amplitudes::Parameter::new($name);
        $crate::parameter!(@parse p, [fixed = false, initial = false]; $($rest)+);
        p
    }};

    // Recursion base case: nothing left to parse.
    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; ) => {};

    // `fixed:` is only legal when neither `fixed` nor `initial` was seen.
    (@parse $p:ident, [fixed = false, initial = false]; fixed : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_fixed_value(Some($value));
        $crate::parameter!(@parse $p, [fixed = true, initial = false]; $($($rest)*)?);
    }};

    // `initial:` is only legal when neither `fixed` nor `initial` was seen.
    (@parse $p:ident, [fixed = false, initial = false]; initial : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_initial($value);
        $crate::parameter!(@parse $p, [fixed = false, initial = true]; $($($rest)*)?);
    }};

    // `initial:` after `fixed:` — reject at compile time.
    (@parse $p:ident, [fixed = true, initial = false]; initial : $value:expr $(, $($rest:tt)*)?) => {
        compile_error!("parameter!: cannot specify both `fixed` and `initial`");
    };

    // `fixed:` after `initial:` — reject at compile time.
    (@parse $p:ident, [fixed = false, initial = true]; fixed : $value:expr $(, $($rest:tt)*)?) => {
        compile_error!("parameter!: cannot specify both `fixed` and `initial`");
    };

    // Remaining keys are allowed in any state and preserve it.
    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; bounds : ($min:expr, $max:expr) $(, $($rest:tt)*)?) => {{
        $p.set_bounds($min, $max);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; unit : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_unit($value);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; latex : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_latex($value);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};

    (@parse $p:ident, [fixed = $f:tt, initial = $i:tt]; description : $value:expr $(, $($rest:tt)*)?) => {{
        $p.set_description($value);
        $crate::parameter!(@parse $p, [fixed = $f, initial = $i]; $($($rest)*)?);
    }};
}
449
/// An ordered collection of [`Parameter`]s with O(1) lookup by name.
///
/// Invariant: `name_to_index[p.name()] == i` for every `parameters[i] == p`;
/// all mutating methods maintain this.
#[derive(Default, Debug, Clone)]
pub struct ParameterMap {
    // Insertion-ordered storage; indices here are "storage indices".
    parameters: Vec<Parameter>,
    // Name → storage index. Rebuilt on deserialization.
    name_to_index: HashMap<String, usize>,
}
456
/// Serialization surrogate for [`ParameterMap`]: only the parameter list is
/// persisted; the name index is rebuilt on load.
#[derive(Serialize, Deserialize)]
struct ParameterMapSerde {
    parameters: Vec<Parameter>,
}
461
/// Indexing by storage index; panics (like `Vec`) when out of range.
impl Index<usize> for ParameterMap {
    type Output = Parameter;

    fn index(&self, index: usize) -> &Self::Output {
        &self.parameters[index]
    }
}
469
470impl Index<&str> for ParameterMap {
471 type Output = Parameter;
472
473 fn index(&self, key: &str) -> &Self::Output {
474 self.get(key)
475 .unwrap_or_else(|| panic!("parameter '{key}' not found"))
476 }
477}
478
/// Consuming iteration over parameters in storage order.
impl IntoIterator for ParameterMap {
    type Item = Parameter;

    type IntoIter = std::vec::IntoIter<Parameter>;

    fn into_iter(self) -> Self::IntoIter {
        self.parameters.into_iter()
    }
}
488
/// Borrowing iteration over parameters in storage order (enables `for p in &map`).
impl<'a> IntoIterator for &'a ParameterMap {
    type Item = &'a Parameter;

    type IntoIter = std::slice::Iter<'a, Parameter>;

    fn into_iter(self) -> Self::IntoIter {
        self.parameters.iter()
    }
}
498
/// Serializes via [`ParameterMapSerde`], persisting only the parameter list.
impl Serialize for ParameterMap {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        ParameterMapSerde {
            parameters: self.parameters.clone(),
        }
        .serialize(serializer)
    }
}
510
/// Deserializes via [`ParameterMapSerde`]; the name-to-index map is rebuilt
/// from the parameter list by `from_parameters`.
impl<'de> Deserialize<'de> for ParameterMap {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let serde = ParameterMapSerde::deserialize(deserializer)?;
        Ok(Self::from_parameters(serde.parameters))
    }
}
520
impl ParameterMap {
    /// Rebuilds a map from a parameter list, regenerating the name index.
    /// If names collide, the later index wins in the name index.
    fn from_parameters(parameters: Vec<Parameter>) -> Self {
        let name_to_index = parameters
            .iter()
            .enumerate()
            .map(|(index, parameter)| (parameter.name(), index))
            .collect();
        Self {
            parameters,
            name_to_index,
        }
    }

    /// Registers a parameter, deduplicating by name.
    ///
    /// Returns the [`ParameterID`] for the (new or existing) entry.
    ///
    /// # Errors
    ///
    /// - `UnregisteredParameter` when the parameter has an empty name.
    /// - `ParameterConflict` when an existing parameter of the same name has
    ///   a different fixed value, or disagrees on fixed-vs-free status.
    pub fn register_parameter(&mut self, p: &Parameter) -> LadduResult<ParameterID> {
        let name = p.name();
        if name.is_empty() {
            return Err(LadduError::UnregisteredParameter {
                name: "<unnamed>".to_string(),
                reason: "Parameter was not initialized with a name".to_string(),
            });
        }

        if let Some((index, existing)) = self.get_indexed(&name) {
            match (existing.fixed(), p.fixed()) {
                // Both fixed but at (meaningfully) different values.
                (Some(a), Some(b)) if (a - b).abs() > f64::EPSILON => {
                    return Err(LadduError::ParameterConflict {
                        name,
                        reason: "conflicting fixed values for the same parameter name".to_string(),
                    });
                }
                (Some(_), None) => {
                    return Err(LadduError::ParameterConflict {
                        name,
                        reason: "attempted to use a fixed parameter name as free".to_string(),
                    });
                }
                (None, Some(_)) => {
                    return Err(LadduError::ParameterConflict {
                        name,
                        reason: "attempted to use a free parameter name as fixed".to_string(),
                    });
                }
                // Compatible duplicate: reuse the existing entry.
                (Some(_), Some(_)) | (None, None) => return Ok(self.parameter_id(index)),
            }
        }

        let index = self.parameters.len();
        self.insert(p.clone());
        Ok(self.parameter_id(index))
    }

    /// Returns `0..n_free`: the assembled-slice indices of the free
    /// parameters (free parameters occupy the front of an assembled vector;
    /// see [`Self::assemble`]). These are NOT storage indices.
    pub fn free_parameter_indices(&self) -> Vec<usize> {
        (0..self.free().len()).collect()
    }

    /// Renames a parameter, keeping the name index in sync.
    ///
    /// # Errors
    ///
    /// - `ParameterConflict` when `new` already names a parameter.
    /// - `UnregisteredParameter` when `old` does not exist.
    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
        if old == new {
            return Ok(());
        }
        if self.contains_key(new) {
            return Err(LadduError::ParameterConflict {
                name: new.to_string(),
                reason: "rename target already exists".to_string(),
            });
        }
        if let Some(index) = self.index(old) {
            let parameter = self.parameters[index].clone();
            parameter.set_name(new);
            self.name_to_index.remove(old);
            self.name_to_index.insert(new.to_string(), index);
        } else {
            // `old` missing: surface the standard not-found error.
            self.assert_parameter_exists(old)?;
        }
        Ok(())
    }

    /// Applies [`Self::rename_parameter`] for each `(old, new)` pair,
    /// stopping at the first error. Note: iteration order over the `HashMap`
    /// is unspecified, which matters if renames chain into each other.
    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
        for (old, new) in mapping {
            self.rename_parameter(old, new)?;
        }
        Ok(())
    }

    /// Fixes the named parameter at `value` (shared-handle mutation, so `&self`).
    ///
    /// # Errors
    ///
    /// `UnregisteredParameter` when the name does not exist.
    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
        self.assert_parameter_exists(name)?;
        if let Some(parameter) = self.get(name) {
            parameter.set_fixed_value(Some(value));
        }
        Ok(())
    }

    /// Frees the named parameter.
    ///
    /// # Errors
    ///
    /// `UnregisteredParameter` when the name does not exist.
    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
        self.assert_parameter_exists(name)?;
        if let Some(parameter) = self.get(name) {
            parameter.set_fixed_value(None);
        }
        Ok(())
    }

    /// Returns `true` when a parameter with the given name exists.
    pub fn contains_key(&self, name: &str) -> bool {
        self.name_to_index.contains_key(name)
    }

    /// Returns the storage index for a name, if present.
    pub fn index(&self, name: &str) -> Option<usize> {
        self.name_to_index.get(name).copied()
    }

    /// Inserts a parameter, replacing (and returning) any existing entry with
    /// the same name; otherwise appends and returns `None`.
    pub fn insert(&mut self, parameter: Parameter) -> Option<Parameter> {
        let name = parameter.name();
        if let Some(index) = self.index(&name) {
            Some(std::mem::replace(&mut self.parameters[index], parameter))
        } else {
            let index = self.parameters.len();
            self.parameters.push(parameter);
            self.name_to_index.insert(name, index);
            None
        }
    }

    /// Number of parameters (free and fixed).
    pub fn len(&self) -> usize {
        self.parameters.len()
    }

    /// Returns `true` when the map holds no parameters.
    pub fn is_empty(&self) -> bool {
        self.parameters.is_empty()
    }

    /// Iterates over parameters in storage order.
    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
        self.parameters.iter()
    }

    /// Looks up a parameter by name.
    pub fn get(&self, key: &str) -> Option<&Parameter> {
        self.index(key).map(|index| &self.parameters[index])
    }

    /// Looks up a parameter by name, also returning its storage index.
    pub fn get_indexed(&self, key: &str) -> Option<(usize, &Parameter)> {
        self.index(key)
            .map(|index| (index, &self.parameters[index]))
    }

    /// All parameter names in storage order.
    pub fn names(&self) -> Vec<String> {
        self.parameters.iter().map(Parameter::name).collect()
    }

    /// Returns a new map containing clones of the parameters matching
    /// `predicate` (clones share the underlying metadata handles).
    pub fn filter(&self, predicate: impl Fn(&Parameter) -> bool) -> Self {
        Self::from_parameters(
            self.parameters
                .iter()
                .filter(|parameter| predicate(parameter))
                .cloned()
                .collect(),
        )
    }

    /// Sub-map of free parameters.
    pub fn free(&self) -> Self {
        self.filter(|p| p.is_free())
    }

    /// Sub-map of fixed parameters.
    pub fn fixed(&self) -> Self {
        self.filter(|p| p.is_fixed())
    }

    /// Sub-map of parameters with an initial value set.
    pub fn initialized(&self) -> Self {
        self.filter(|p| p.initial().is_some())
    }

    /// Sub-map of parameters without an initial value.
    pub fn uninitialized(&self) -> Self {
        self.filter(|p| p.initial().is_none())
    }

    /// Builds a full [`Parameters`] vector from the free values supplied by
    /// an optimizer.
    ///
    /// Layout of the assembled vector: all free values first (in storage
    /// order), followed by all fixed values (in storage order) — the same
    /// convention [`Self::assembled_index`] computes.
    ///
    /// # Errors
    ///
    /// `LengthMismatch` when `free_values` has fewer or more entries than
    /// there are free parameters.
    pub fn assemble(&self, free_values: &[f64]) -> LadduResult<Parameters> {
        let expected_free = self.free().len();
        let n_fixed = self.fixed().len();
        let mut values = vec![0.0; expected_free + n_fixed];
        // storage index -> position in the assembled `values` vector
        let mut storage_to_assembled = vec![0; self.len()];
        let mut free_iter = free_values.iter();
        let mut free_index = 0;
        let mut fixed_index = expected_free;
        for (storage_index, parameter) in self.parameters.iter().enumerate() {
            if let Some(value) = parameter.fixed() {
                values[fixed_index] = value;
                storage_to_assembled[storage_index] = fixed_index;
                fixed_index += 1;
            } else if let Some(value) = free_iter.next() {
                values[free_index] = *value;
                storage_to_assembled[storage_index] = free_index;
                free_index += 1;
            } else {
                // Ran out of supplied free values too early.
                return Err(LadduError::LengthMismatch {
                    context: "parameter values".to_string(),
                    expected: expected_free,
                    actual: free_values.len(),
                });
            }
        }
        // Leftover supplied values means the caller passed too many.
        if free_iter.next().is_some() {
            return Err(LadduError::LengthMismatch {
                context: "parameter values".to_string(),
                expected: expected_free,
                actual: free_values.len(),
            });
        }
        Ok(Parameters::new(values, expected_free, storage_to_assembled))
    }

    /// Merges two maps by name, returning the merged map plus, for each side,
    /// the mapping from that side's storage indices to assembled indices in
    /// the merged map.
    pub fn merge(&self, other: &Self) -> (Self, Vec<usize>, Vec<usize>) {
        let mut merged = self.clone();
        let mut right_map = Vec::with_capacity(other.len());
        for parameter in other {
            let idx = merged.ensure_parameter(parameter.clone());
            right_map.push(idx);
        }
        let left_map: Vec<usize> = (0..self.len())
            .map(|index| merged.assembled_index(index))
            .collect();
        let right_map = right_map
            .into_iter()
            .map(|index| merged.assembled_index(index))
            .collect();
        (merged, left_map, right_map)
    }

    /// Like [`Self::merge`], but only returns the assembled-index mapping for
    /// `other` (self's layout is unchanged except for appended parameters).
    pub fn extend_from(&self, other: &Self) -> (Self, Vec<usize>) {
        let mut merged = self.clone();
        let mut indices = Vec::with_capacity(other.len());
        for parameter in other {
            let idx = merged.ensure_parameter(parameter.clone());
            indices.push(idx);
        }
        let indices = indices
            .into_iter()
            .map(|index| merged.assembled_index(index))
            .collect();
        (merged, indices)
    }

    // Returns the storage index of the named parameter, inserting it first
    // if absent. Unlike `register_parameter`, performs no conflict checks.
    fn ensure_parameter(&mut self, parameter: Parameter) -> usize {
        let name = parameter.name();
        if let Some(idx) = self.index(&name) {
            return idx;
        }
        let idx = self.len();
        self.insert(parameter);
        idx
    }

    // Maps a storage index to its position in an assembled vector
    // (free parameters first, then fixed; see `assemble`). O(n) per call.
    fn assembled_index(&self, storage_index: usize) -> usize {
        let n_free = self
            .parameters
            .iter()
            .filter(|parameter| parameter.is_free())
            .count();
        // Rank of this parameter within its own group (free or fixed).
        let preceding_in_group = self.parameters[..storage_index]
            .iter()
            .filter(|parameter| self.parameters[storage_index].is_free() == parameter.is_free())
            .count();
        if self.parameters[storage_index].is_free() {
            preceding_in_group
        } else {
            n_free + preceding_in_group
        }
    }

    // Wraps a storage index in the appropriate `ParameterID` variant.
    fn parameter_id(&self, storage_index: usize) -> ParameterID {
        if self.parameters[storage_index].is_fixed() {
            ParameterID::Constant(storage_index)
        } else {
            ParameterID::Parameter(storage_index)
        }
    }

    // Returns `Ok(())` if the name exists, else an `UnregisteredParameter`.
    fn assert_parameter_exists(&self, name: &str) -> LadduResult<()> {
        if self.contains_key(name) {
            Ok(())
        } else {
            Err(LadduError::UnregisteredParameter {
                name: name.to_string(),
                reason: "parameter not found".to_string(),
            })
        }
    }
}
813
814#[typetag::serde(tag = "type")]
822pub trait Amplitude: DynClone + Send + Sync {
823 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID>;
832 fn semantic_key(&self) -> Option<AmplitudeSemanticKey> {
839 None
840 }
841 fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
847 Ok(())
848 }
849 fn dependence_hint(&self) -> ExpressionDependence {
853 ExpressionDependence::Mixed
854 }
855 fn real_valued_hint(&self) -> bool {
861 false
862 }
863 #[allow(unused_variables)]
868 fn precompute(&self, event: &NamedEventView<'_>, cache: &mut Cache) {}
869
870 #[cfg(feature = "rayon")]
872 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
873 resources
874 .caches
875 .par_iter_mut()
876 .enumerate()
877 .for_each(|(event_index, cache)| {
878 let event = dataset.event_view(event_index);
879 self.precompute(&event, cache);
880 });
881 }
882
883 #[cfg(not(feature = "rayon"))]
885 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
886 dataset.for_each_named_event_local(|event_index, event| {
887 let cache = &mut resources.caches[event_index];
888 self.precompute(&event, cache);
889 });
890 }
891 fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64;
894
895 fn compute_gradient(
910 &self,
911 parameters: &Parameters,
912 cache: &Cache,
913 gradient: &mut DVector<Complex64>,
914 ) {
915 self.central_difference_with_indices(
916 &Vec::from_iter(0..parameters.len()),
917 parameters,
918 cache,
919 gradient,
920 )
921 }
922
923 fn central_difference_with_indices(
929 &self,
930 indices: &[usize],
931 parameters: &Parameters,
932 cache: &Cache,
933 gradient: &mut DVector<Complex64>,
934 ) {
935 let x = parameters.values().to_owned();
936 let h: DVector<f64> = x
937 .iter()
938 .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
939 .collect::<Vec<_>>()
940 .into();
941 for i in indices {
942 let mut x_plus = x.clone();
943 let mut x_minus = x.clone();
944 x_plus[*i] += h[*i];
945 x_minus[*i] -= h[*i];
946 let f_plus = self.compute(¶meters.with_values(x_plus), cache);
947 let f_minus = self.compute(¶meters.with_values(x_minus), cache);
948 gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
949 }
950 }
951
952 fn into_expression(self) -> LadduResult<Expression>
957 where
958 Self: Sized + 'static,
959 {
960 Expression::from_amplitude(Box::new(self))
961 }
962}
963
964pub fn central_difference<F: Fn(&[f64]) -> f64>(parameters: &[f64], func: F) -> DVector<f64> {
966 let mut gradient = DVector::zeros(parameters.len());
967 let x = parameters.to_owned();
968 let h: DVector<f64> = x
969 .iter()
970 .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
971 .collect::<Vec<_>>()
972 .into();
973 for i in 0..parameters.len() {
974 let mut x_plus = x.clone();
975 let mut x_minus = x.clone();
976 x_plus[i] += h[i];
977 x_minus[i] -= h[i];
978 let f_plus = func(&x_plus);
979 let f_minus = func(&x_minus);
980 gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
981 }
982 gradient
983}
984
985dyn_clone::clone_trait_object!(Amplitude);
986
/// One evaluated complex value per registered amplitude (for a single event).
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex64>);
990
/// Per-amplitude gradient vectors; the leading `usize` is the number of
/// parameters each gradient vector spans.
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex64>>);
994
/// Identifier for a registered amplitude: its name and registry slot index.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
999
1000impl Display for AmplitudeID {
1001 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1002 write!(f, "{}(id={})", self.0, self.1)
1003 }
1004}
1005
/// A single named field of an [`AmplitudeSemanticKey`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AmplitudeSemanticField {
    // Field label (e.g. a constructor-argument name).
    name: String,
    // Stringified field value.
    value: String,
}
1012
1013impl AmplitudeSemanticField {
1014 pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
1016 Self {
1017 name: name.into(),
1018 value: value.into(),
1019 }
1020 }
1021}
1022
/// Structural identity of an amplitude: a kind tag plus named fields.
/// Used to decide whether two same-named amplitudes are interchangeable
/// when registries are merged.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AmplitudeSemanticKey {
    // Category of the amplitude (e.g. its type name).
    kind: String,
    // Ordered list of identifying fields.
    fields: Vec<AmplitudeSemanticField>,
}
1032
impl AmplitudeSemanticKey {
    /// Creates a key with the given kind and no fields.
    pub fn new(kind: impl Into<String>) -> Self {
        Self {
            kind: kind.into(),
            fields: Vec::new(),
        }
    }

    /// Builder-style addition of one identifying field.
    pub fn with_field(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.fields.push(AmplitudeSemanticField::new(name, value));
        self
    }

    // Linear lookup of a field's value by name.
    fn field_value(&self, name: &str) -> Option<&str> {
        self.fields
            .iter()
            .find(|field| field.name == name)
            .map(|field| field.value.as_str())
    }

    // Produces a human-readable, `;`-joined description of how `self`
    // (the existing key) differs from `other` (the incoming key). Falls
    // back to a generic message if no concrete mismatch is found.
    fn mismatch_summary(&self, other: &Self) -> String {
        let mut mismatches = Vec::new();
        if self.kind != other.kind {
            mismatches.push(format!(
                "kind differs (existing: {:?}, incoming: {:?})",
                self.kind, other.kind
            ));
        }
        // Fields present in the existing key: differing or missing values.
        for field in &self.fields {
            match other.field_value(&field.name) {
                Some(value) if value == field.value => {}
                Some(value) => mismatches.push(format!(
                    "{} differs (existing: {}, incoming: {})",
                    field.name, field.value, value
                )),
                None => mismatches.push(format!(
                    "{} is missing from the incoming key (existing: {})",
                    field.name, field.value
                )),
            }
        }
        // Fields only present in the incoming key.
        for field in &other.fields {
            if self.field_value(&field.name).is_none() {
                mismatches.push(format!(
                    "{} is missing from the existing key (incoming: {})",
                    field.name, field.value
                ));
            }
        }
        if mismatches.is_empty() {
            "semantic keys differ".to_string()
        } else {
            mismatches.join("; ")
        }
    }
}
1091
/// A full amplitude expression: the registry of amplitudes plus the syntax
/// tree combining them.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Expression {
    // Registered amplitudes and their shared resources.
    registry: ExpressionRegistry,
    // The combining expression tree (nodes reference registry slots).
    tree: ExpressionNode,
}
1099
1100impl ReadWrite for Expression {
1101 fn create_null() -> Self {
1102 Self {
1103 registry: ExpressionRegistry::default(),
1104 tree: ExpressionNode::default(),
1105 }
1106 }
1107}
1108
/// Holds the amplitudes referenced by an [`Expression`].
///
/// The three vectors are parallel: slot `i` has amplitude `amplitudes[i]`,
/// display name `amplitude_names[i]`, and unique instance id
/// `amplitude_ids[i]` (from `next_amplitude_id`).
#[derive(Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub struct ExpressionRegistry {
    amplitudes: Vec<Box<dyn Amplitude>>,
    amplitude_names: Vec<String>,
    amplitude_ids: Vec<u64>,
    // Shared parameter/cache resources the amplitudes registered into.
    resources: Resources,
}
1118
1119impl ExpressionRegistry {
1120 fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
1121 let mut resources = Resources::default();
1122 let aid = amplitude.register(&mut resources)?;
1123 let amp_id = next_amplitude_id();
1124 Ok(Self {
1125 amplitudes: vec![amplitude],
1126 amplitude_names: vec![aid.0],
1127 amplitude_ids: vec![amp_id],
1128 resources,
1129 })
1130 }
1131
1132 fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
1133 let mut resources = Resources::default();
1134 let mut amplitudes = Vec::new();
1135 let mut amplitude_names = Vec::new();
1136 let mut amplitude_ids = Vec::new();
1137 let mut amplitude_semantic_keys = Vec::new();
1138 let mut name_to_index = HashMap::new();
1139
1140 let mut left_map = Vec::with_capacity(self.amplitudes.len());
1141 for ((amp, name), amp_id) in self
1142 .amplitudes
1143 .iter()
1144 .zip(&self.amplitude_names)
1145 .zip(&self.amplitude_ids)
1146 {
1147 let semantic_key = amp.semantic_key();
1148 let mut cloned_amp = dyn_clone::clone_box(&**amp);
1149 let aid = cloned_amp.register(&mut resources)?;
1150 amplitudes.push(cloned_amp);
1151 amplitude_names.push(name.clone());
1152 amplitude_ids.push(*amp_id);
1153 amplitude_semantic_keys.push(semantic_key);
1154 name_to_index.insert(name.clone(), aid.1);
1155 left_map.push(aid.1);
1156 }
1157
1158 let mut right_map = Vec::with_capacity(other.amplitudes.len());
1159 for ((amp, name), amp_id) in other
1160 .amplitudes
1161 .iter()
1162 .zip(&other.amplitude_names)
1163 .zip(&other.amplitude_ids)
1164 {
1165 if let Some(existing) = name_to_index.get(name) {
1166 let existing_amp_id = amplitude_ids[*existing];
1167 let incoming_semantic_key = amp.semantic_key();
1168 if existing_amp_id != *amp_id {
1169 match (&litude_semantic_keys[*existing], &incoming_semantic_key) {
1170 (Some(existing_key), Some(incoming_key))
1171 if existing_key == incoming_key => {}
1172 (Some(existing_key), Some(incoming_key)) => {
1173 return Err(LadduError::Custom(format!(
1174 "Amplitude name \"{name}\" refers to different underlying amplitudes; semantic keys differ: {}",
1175 existing_key.mismatch_summary(incoming_key)
1176 )));
1177 }
1178 _ => {
1179 return Err(LadduError::Custom(format!(
1180 "Amplitude name \"{name}\" refers to different underlying amplitudes; rename to avoid conflicts"
1181 )));
1182 }
1183 }
1184 }
1185 right_map.push(*existing);
1186 continue;
1187 }
1188 let semantic_key = amp.semantic_key();
1189 let mut cloned_amp = dyn_clone::clone_box(&**amp);
1190 let aid = cloned_amp.register(&mut resources)?;
1191 amplitudes.push(cloned_amp);
1192 amplitude_names.push(name.clone());
1193 amplitude_ids.push(*amp_id);
1194 amplitude_semantic_keys.push(semantic_key);
1195 name_to_index.insert(name.clone(), aid.1);
1196 right_map.push(aid.1);
1197 }
1198
1199 Ok((
1200 Self {
1201 amplitudes,
1202 amplitude_names,
1203 amplitude_ids,
1204 resources,
1205 },
1206 left_map,
1207 right_map,
1208 ))
1209 }
1210}
1211
/// Syntax tree of an [`Expression`]. `Amp(i)` leaves reference slot `i` in
/// the owning [`ExpressionRegistry`]; the remaining variants are complex
/// arithmetic and elementary functions over sub-trees.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default, Debug)]
pub enum ExpressionNode {
    /// Additive identity; also the `Default` (used by `create_null`).
    #[default]
    Zero,
    /// Multiplicative identity.
    One,
    /// A real constant.
    Constant(f64),
    /// A complex constant.
    ComplexConstant(Complex64),
    /// A registered amplitude, by registry slot index.
    Amp(usize),
    Add(Box<ExpressionNode>, Box<ExpressionNode>),
    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
    Div(Box<ExpressionNode>, Box<ExpressionNode>),
    Neg(Box<ExpressionNode>),
    /// Real part of the operand.
    Real(Box<ExpressionNode>),
    /// Imaginary part of the operand.
    Imag(Box<ExpressionNode>),
    /// Complex conjugate.
    Conj(Box<ExpressionNode>),
    /// Squared magnitude |z|^2.
    NormSqr(Box<ExpressionNode>),
    Sqrt(Box<ExpressionNode>),
    /// General power with an expression exponent.
    Pow(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Power with a constant integer exponent.
    PowI(Box<ExpressionNode>, i32),
    /// Power with a constant real exponent.
    PowF(Box<ExpressionNode>, f64),
    Exp(Box<ExpressionNode>),
    Sin(Box<ExpressionNode>),
    Cos(Box<ExpressionNode>),
    Log(Box<ExpressionNode>),
    /// e^(i * operand).
    Cis(Box<ExpressionNode>),
}
1255
/// A flattened, linear evaluation program compiled from an
/// [`ExpressionNode`] tree: each op writes into a numbered slot, and the
/// final result is read from `root_slot`.
#[derive(Clone, Debug)]
struct ExpressionProgram {
    // Ops in evaluation order (operands always precede their consumers).
    ops: Vec<ExpressionOp>,
    // Total number of value slots the program needs.
    slot_count: usize,
    // Slot holding the final result after all ops run.
    root_slot: usize,
}
1268
/// One instruction of an [`ExpressionProgram`]. Each variant mirrors the
/// [`ExpressionNode`] variant of the same name: `dst` is the slot written,
/// `left`/`right`/`input`/`value`/`power` are slots read, and `amp_idx` is a
/// registry slot index.
#[derive(Clone, Debug)]
enum ExpressionOp {
    // --- constant loads ---
    LoadZero {
        dst: usize,
    },
    LoadOne {
        dst: usize,
    },
    LoadConstant {
        dst: usize,
        value: f64,
    },
    LoadComplexConstant {
        dst: usize,
        value: Complex64,
    },
    // Load the precomputed value of amplitude `amp_idx`.
    LoadAmp {
        dst: usize,
        amp_idx: usize,
    },
    // --- binary arithmetic ---
    Add {
        dst: usize,
        left: usize,
        right: usize,
    },
    Sub {
        dst: usize,
        left: usize,
        right: usize,
    },
    Mul {
        dst: usize,
        left: usize,
        right: usize,
    },
    Div {
        dst: usize,
        left: usize,
        right: usize,
    },
    // --- unary operations ---
    Neg {
        dst: usize,
        input: usize,
    },
    Real {
        dst: usize,
        input: usize,
    },
    Imag {
        dst: usize,
        input: usize,
    },
    Conj {
        dst: usize,
        input: usize,
    },
    NormSqr {
        dst: usize,
        input: usize,
    },
    Sqrt {
        dst: usize,
        input: usize,
    },
    // --- powers: slot exponent, constant integer, constant float ---
    Pow {
        dst: usize,
        value: usize,
        power: usize,
    },
    PowI {
        dst: usize,
        input: usize,
        power: i32,
    },
    PowF {
        dst: usize,
        input: usize,
        power: f64,
    },
    // --- elementary functions ---
    Exp {
        dst: usize,
        input: usize,
    },
    Sin {
        dst: usize,
        input: usize,
    },
    Cos {
        dst: usize,
        input: usize,
    },
    Log {
        dst: usize,
        input: usize,
    },
    // e^(i * input)
    Cis {
        dst: usize,
        input: usize,
    },
}
1369
/// Accumulates ops and slot indices while lowering an expression tree.
#[derive(Default)]
struct ExpressionProgramBuilder {
    // Ops emitted so far, in dependency order.
    ops: Vec<ExpressionOp>,
    // Next unused slot index; becomes the program's `slot_count` on build.
    next_slot: usize,
}
1375
1376impl ExpressionProgramBuilder {
1377 fn alloc_slot(&mut self) -> usize {
1378 let slot = self.next_slot;
1379 self.next_slot += 1;
1380 slot
1381 }
1382
1383 fn build(self, root: usize) -> ExpressionProgram {
1384 ExpressionProgram {
1385 ops: self.ops,
1386 slot_count: self.next_slot,
1387 root_slot: root,
1388 }
1389 }
1390
1391 fn emit(&mut self, op: ExpressionOp) {
1392 self.ops.push(op);
1393 }
1394
1395 fn compile(&mut self, node: &ExpressionNode) -> usize {
1396 match node {
1397 ExpressionNode::Zero => {
1398 let dst = self.alloc_slot();
1399 self.emit(ExpressionOp::LoadZero { dst });
1400 dst
1401 }
1402 ExpressionNode::One => {
1403 let dst = self.alloc_slot();
1404 self.emit(ExpressionOp::LoadOne { dst });
1405 dst
1406 }
1407 ExpressionNode::Constant(value) => {
1408 let dst = self.alloc_slot();
1409 self.emit(ExpressionOp::LoadConstant { dst, value: *value });
1410 dst
1411 }
1412 ExpressionNode::ComplexConstant(value) => {
1413 let dst = self.alloc_slot();
1414 self.emit(ExpressionOp::LoadComplexConstant { dst, value: *value });
1415 dst
1416 }
1417 ExpressionNode::Amp(idx) => {
1418 let dst = self.alloc_slot();
1419 self.emit(ExpressionOp::LoadAmp { dst, amp_idx: *idx });
1420 dst
1421 }
1422 ExpressionNode::Add(a, b) => {
1423 let left = self.compile(a);
1424 let right = self.compile(b);
1425 let dst = self.alloc_slot();
1426 self.emit(ExpressionOp::Add { dst, left, right });
1427 dst
1428 }
1429 ExpressionNode::Sub(a, b) => {
1430 let left = self.compile(a);
1431 let right = self.compile(b);
1432 let dst = self.alloc_slot();
1433 self.emit(ExpressionOp::Sub { dst, left, right });
1434 dst
1435 }
1436 ExpressionNode::Mul(a, b) => {
1437 let left = self.compile(a);
1438 let right = self.compile(b);
1439 let dst = self.alloc_slot();
1440 self.emit(ExpressionOp::Mul { dst, left, right });
1441 dst
1442 }
1443 ExpressionNode::Div(a, b) => {
1444 let left = self.compile(a);
1445 let right = self.compile(b);
1446 let dst = self.alloc_slot();
1447 self.emit(ExpressionOp::Div { dst, left, right });
1448 dst
1449 }
1450 ExpressionNode::Neg(a) => {
1451 let input = self.compile(a);
1452 let dst = self.alloc_slot();
1453 self.emit(ExpressionOp::Neg { dst, input });
1454 dst
1455 }
1456 ExpressionNode::Real(a) => {
1457 let input = self.compile(a);
1458 let dst = self.alloc_slot();
1459 self.emit(ExpressionOp::Real { dst, input });
1460 dst
1461 }
1462 ExpressionNode::Imag(a) => {
1463 let input = self.compile(a);
1464 let dst = self.alloc_slot();
1465 self.emit(ExpressionOp::Imag { dst, input });
1466 dst
1467 }
1468 ExpressionNode::Conj(a) => {
1469 let input = self.compile(a);
1470 let dst = self.alloc_slot();
1471 self.emit(ExpressionOp::Conj { dst, input });
1472 dst
1473 }
1474 ExpressionNode::NormSqr(a) => {
1475 let input = self.compile(a);
1476 let dst = self.alloc_slot();
1477 self.emit(ExpressionOp::NormSqr { dst, input });
1478 dst
1479 }
1480 ExpressionNode::Sqrt(a) => {
1481 let input = self.compile(a);
1482 let dst = self.alloc_slot();
1483 self.emit(ExpressionOp::Sqrt { dst, input });
1484 dst
1485 }
1486 ExpressionNode::Pow(a, b) => {
1487 let value = self.compile(a);
1488 let power = self.compile(b);
1489 let dst = self.alloc_slot();
1490 self.emit(ExpressionOp::Pow { dst, value, power });
1491 dst
1492 }
1493 ExpressionNode::PowI(a, power) => {
1494 let input = self.compile(a);
1495 let dst = self.alloc_slot();
1496 self.emit(ExpressionOp::PowI {
1497 dst,
1498 input,
1499 power: *power,
1500 });
1501 dst
1502 }
1503 ExpressionNode::PowF(a, power) => {
1504 let input = self.compile(a);
1505 let dst = self.alloc_slot();
1506 self.emit(ExpressionOp::PowF {
1507 dst,
1508 input,
1509 power: *power,
1510 });
1511 dst
1512 }
1513 ExpressionNode::Exp(a) => {
1514 let input = self.compile(a);
1515 let dst = self.alloc_slot();
1516 self.emit(ExpressionOp::Exp { dst, input });
1517 dst
1518 }
1519 ExpressionNode::Sin(a) => {
1520 let input = self.compile(a);
1521 let dst = self.alloc_slot();
1522 self.emit(ExpressionOp::Sin { dst, input });
1523 dst
1524 }
1525 ExpressionNode::Cos(a) => {
1526 let input = self.compile(a);
1527 let dst = self.alloc_slot();
1528 self.emit(ExpressionOp::Cos { dst, input });
1529 dst
1530 }
1531 ExpressionNode::Log(a) => {
1532 let input = self.compile(a);
1533 let dst = self.alloc_slot();
1534 self.emit(ExpressionOp::Log { dst, input });
1535 dst
1536 }
1537 ExpressionNode::Cis(a) => {
1538 let input = self.compile(a);
1539 let dst = self.alloc_slot();
1540 self.emit(ExpressionOp::Cis { dst, input });
1541 dst
1542 }
1543 }
1544 }
1545}
1546
impl ExpressionProgram {
    /// Flatten an expression tree into a linear program via a post-order walk.
    fn from_node(node: &ExpressionNode) -> Self {
        let mut builder = ExpressionProgramBuilder::default();
        let root = builder.compile(node);
        builder.build(root)
    }

    /// Forward pass: execute every op in order, writing each result into its
    /// `dst` slot. `slots` must hold at least `slot_count` entries; ops are
    /// stored in dependency order, so operands are written before being read.
    fn fill_values(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) {
        debug_assert!(slots.len() >= self.slot_count);
        for op in &self.ops {
            match *op {
                ExpressionOp::LoadZero { dst } => slots[dst] = Complex64::ZERO,
                ExpressionOp::LoadOne { dst } => slots[dst] = Complex64::ONE,
                ExpressionOp::LoadConstant { dst, value } => slots[dst] = Complex64::from(value),
                ExpressionOp::LoadComplexConstant { dst, value } => slots[dst] = value,
                ExpressionOp::LoadAmp { dst, amp_idx } => {
                    // Out-of-range amplitude indices load zero instead of panicking.
                    slots[dst] = amplitude_values.get(amp_idx).copied().unwrap_or_default();
                }
                ExpressionOp::Add { dst, left, right } => {
                    slots[dst] = slots[left] + slots[right];
                }
                ExpressionOp::Sub { dst, left, right } => {
                    slots[dst] = slots[left] - slots[right];
                }
                ExpressionOp::Mul { dst, left, right } => {
                    slots[dst] = slots[left] * slots[right];
                }
                ExpressionOp::Div { dst, left, right } => {
                    slots[dst] = slots[left] / slots[right];
                }
                ExpressionOp::Neg { dst, input } => {
                    slots[dst] = -slots[input];
                }
                ExpressionOp::Real { dst, input } => {
                    slots[dst] = Complex64::new(slots[input].re, 0.0);
                }
                ExpressionOp::Imag { dst, input } => {
                    slots[dst] = Complex64::new(slots[input].im, 0.0);
                }
                ExpressionOp::Conj { dst, input } => {
                    slots[dst] = slots[input].conj();
                }
                ExpressionOp::NormSqr { dst, input } => {
                    slots[dst] = Complex64::new(slots[input].norm_sqr(), 0.0);
                }
                ExpressionOp::Sqrt { dst, input } => {
                    slots[dst] = slots[input].sqrt();
                }
                ExpressionOp::Pow { dst, value, power } => {
                    slots[dst] = slots[value].powc(slots[power]);
                }
                ExpressionOp::PowI { dst, input, power } => {
                    slots[dst] = slots[input].powi(power);
                }
                ExpressionOp::PowF { dst, input, power } => {
                    slots[dst] = slots[input].powc(Complex64::new(power, 0.0));
                }
                ExpressionOp::Exp { dst, input } => {
                    slots[dst] = slots[input].exp();
                }
                ExpressionOp::Sin { dst, input } => {
                    slots[dst] = slots[input].sin();
                }
                ExpressionOp::Cos { dst, input } => {
                    slots[dst] = slots[input].cos();
                }
                ExpressionOp::Log { dst, input } => {
                    slots[dst] = slots[input].ln();
                }
                ExpressionOp::Cis { dst, input } => {
                    // e^(i·z)
                    slots[dst] = (Complex64::new(0.0, 1.0) * slots[input]).exp();
                }
            }
        }
    }

    /// Evaluate using caller-provided scratch `slots`; returns the root value.
    /// An empty program evaluates to zero.
    fn evaluate_into(&self, amplitude_values: &[Complex64], slots: &mut [Complex64]) -> Complex64 {
        if self.slot_count == 0 {
            return Complex64::ZERO;
        }
        self.fill_values(amplitude_values, slots);
        slots[self.root_slot]
    }

    /// Evaluate the program, allocating fresh scratch storage per call.
    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
        if self.slot_count == 0 {
            return Complex64::ZERO;
        }
        let mut slots = vec![Complex64::ZERO; self.slot_count];
        self.evaluate_into(amplitude_values, &mut slots)
    }

    /// Forward pass plus gradient propagation using caller-provided scratch.
    ///
    /// The gradient dimension is taken from the first entry of
    /// `gradient_values`; an empty program yields a zero vector of that
    /// dimension.
    pub fn evaluate_gradient_into(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_slots: &mut [Complex64],
        gradient_slots: &mut [DVector<Complex64>],
    ) -> DVector<Complex64> {
        if self.slot_count == 0 {
            let dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
            return DVector::zeros(dim);
        }
        self.fill_values(amplitude_values, value_slots);
        self.fill_gradients(gradient_values, value_slots, gradient_slots);
        gradient_slots[self.root_slot].clone()
    }

    /// Convenience wrapper around `evaluate_gradient_into` that allocates the
    /// scratch buffers per call.
    pub fn evaluate_gradient(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
    ) -> DVector<Complex64> {
        let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
        let mut value_slots = vec![Complex64::ZERO; self.slot_count];
        let mut gradient_slots: Vec<DVector<Complex64>> = (0..self.slot_count)
            .map(|_| DVector::zeros(grad_dim))
            .collect();
        self.evaluate_gradient_into(
            amplitude_values,
            gradient_values,
            &mut value_slots,
            &mut gradient_slots,
        )
    }

    /// Forward-mode gradient pass: for each op, combine the operand gradients
    /// with the chain rule, using the node values computed by `fill_values`.
    ///
    /// NOTE(review): Conj and NormSqr propagate as d/dp conj(f) = conj(f')
    /// and d/dp |f|² = 2·Re(conj(f)·f'), which is valid for differentiation
    /// with respect to *real* parameters — presumably the fit parameters;
    /// confirm against callers.
    fn fill_gradients(
        &self,
        amplitude_gradients: &[DVector<Complex64>],
        values: &[Complex64],
        gradients: &mut [DVector<Complex64>],
    ) {
        debug_assert!(gradients.len() >= self.slot_count);
        debug_assert!(values.len() >= self.slot_count);
        // Split `gradients` at `dst` so the destination can be mutated while
        // operand gradients are read. Sound because the builder allocates a
        // node's `dst` slot after compiling its operands, so every operand
        // slot index is strictly less than `dst` and lands in `before`.
        fn borrow_dst(
            gradients: &mut [DVector<Complex64>],
            dst: usize,
        ) -> (&[DVector<Complex64>], &mut DVector<Complex64>) {
            let (before, tail) = gradients.split_at_mut(dst);
            let (dst_ref, _) = tail.split_first_mut().expect("dst slot should exist");
            (before, dst_ref)
        }
        for op in &self.ops {
            match *op {
                // Constants carry zero gradient.
                ExpressionOp::LoadZero { dst }
                | ExpressionOp::LoadOne { dst }
                | ExpressionOp::LoadConstant { dst, .. }
                | ExpressionOp::LoadComplexConstant { dst, .. } => {
                    let (_, dst_grad) = borrow_dst(gradients, dst);
                    for item in dst_grad.iter_mut() {
                        *item = Complex64::ZERO;
                    }
                }
                // Amplitude gradient is copied in; missing indices read as zero,
                // mirroring `fill_values`.
                ExpressionOp::LoadAmp { dst, amp_idx } => {
                    let (_, dst_grad) = borrow_dst(gradients, dst);
                    if let Some(source) = amplitude_gradients.get(amp_idx) {
                        dst_grad.clone_from(source);
                    } else {
                        for item in dst_grad.iter_mut() {
                            *item = Complex64::ZERO;
                        }
                    }
                }
                ExpressionOp::Add { dst, left, right } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    dst_grad.clone_from(&before_dst[left]);
                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
                    {
                        *dst_item += *right_item;
                    }
                }
                ExpressionOp::Sub { dst, left, right } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    dst_grad.clone_from(&before_dst[left]);
                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
                    {
                        *dst_item -= *right_item;
                    }
                }
                // Product rule: (fg)' = f·g' + f'·g.
                ExpressionOp::Mul { dst, left, right } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let f_left = values[left];
                    let f_right = values[right];
                    dst_grad.clone_from(&before_dst[right]);
                    for item in dst_grad.iter_mut() {
                        *item *= f_left;
                    }
                    for (dst_item, left_item) in dst_grad.iter_mut().zip(before_dst[left].iter()) {
                        *dst_item += *left_item * f_right;
                    }
                }
                // Quotient rule: (f/g)' = (f'·g - f·g') / g².
                ExpressionOp::Div { dst, left, right } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let f_left = values[left];
                    let f_right = values[right];
                    let denom = f_right * f_right;
                    dst_grad.clone_from(&before_dst[left]);
                    for item in dst_grad.iter_mut() {
                        *item *= f_right;
                    }
                    for (dst_item, right_item) in dst_grad.iter_mut().zip(before_dst[right].iter())
                    {
                        *dst_item -= *right_item * f_left;
                    }
                    for item in dst_grad.iter_mut() {
                        *item /= denom;
                    }
                }
                ExpressionOp::Neg { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    dst_grad.clone_from(&before_dst[input]);
                    for item in dst_grad.iter_mut() {
                        *item = -*item;
                    }
                }
                ExpressionOp::Real { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = Complex64::new(input_item.re, 0.0);
                    }
                }
                ExpressionOp::Imag { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = Complex64::new(input_item.im, 0.0);
                    }
                }
                ExpressionOp::Conj { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = input_item.conj();
                    }
                }
                // d|f|² = 2·Re(f'·conj(f)); result is real by construction.
                ExpressionOp::NormSqr { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let conj_value = values[input].conj();
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = Complex64::new(2.0 * (*input_item * conj_value).re, 0.0);
                    }
                }
                // d√f = f' / (2·√f); recomputes the sqrt rather than reading
                // values[dst] (same value, modest extra cost).
                ExpressionOp::Sqrt { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = Complex64::new(0.5, 0.0) / values[input].sqrt();
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
                // d(f^g) = f^g · (g'·ln f + g·f'/f).
                ExpressionOp::Pow { dst, value, power } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let base = values[value];
                    let exponent = values[power];
                    let output = values[dst];
                    for ((dst_item, value_item), power_item) in dst_grad
                        .iter_mut()
                        .zip(before_dst[value].iter())
                        .zip(before_dst[power].iter())
                    {
                        *dst_item =
                            output * (*power_item * base.ln() + exponent * *value_item / base);
                    }
                }
                // d(f^n) = n·f^(n-1)·f'. `checked_sub(1)` only fails for
                // i32::MIN, where we fall back to dividing f^n by f.
                ExpressionOp::PowI { dst, input, power } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = match power {
                        0 => Complex64::ZERO,
                        1 => Complex64::ONE,
                        _ => {
                            let base = values[input];
                            let multiplier = Complex64::new(power as f64, 0.0);
                            if let Some(derivative_power) = power.checked_sub(1) {
                                multiplier * base.powi(derivative_power)
                            } else {
                                multiplier * base.powi(power) / base
                            }
                        }
                    };
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
                // d(f^p) = p·f^(p-1)·f' for a fixed real exponent.
                ExpressionOp::PowF { dst, input, power } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = if power == 0.0 {
                        Complex64::ZERO
                    } else {
                        Complex64::new(power, 0.0)
                            * values[input].powc(Complex64::new(power - 1.0, 0.0))
                    };
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
                // d(e^f) = e^f·f' — reuses the forward-pass output.
                ExpressionOp::Exp { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let output = values[dst];
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * output;
                    }
                }
                ExpressionOp::Sin { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = values[input].cos();
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
                ExpressionOp::Cos { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = -values[input].sin();
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
                ExpressionOp::Log { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = Complex64::ONE / values[input];
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
                // d(e^{i·f}) = i·e^{i·f}·f' — reuses the forward-pass output.
                ExpressionOp::Cis { dst, input } => {
                    let (before_dst, dst_grad) = borrow_dst(gradients, dst);
                    let factor = Complex64::new(0.0, 1.0) * values[dst];
                    for (dst_item, input_item) in dst_grad.iter_mut().zip(before_dst[input].iter())
                    {
                        *dst_item = *input_item * factor;
                    }
                }
            }
        }
    }
}
1890
impl ExpressionNode {
    /// Rebuild the tree with every `Amp(i)` index replaced by `mapping[i]`.
    ///
    /// Used after two expressions' registries are merged and amplitude slots
    /// are renumbered. Panics if an amplitude index is outside `mapping`.
    fn remap(&self, mapping: &[usize]) -> Self {
        match self {
            Self::Amp(idx) => Self::Amp(mapping[*idx]),
            Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
            Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
            Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
            Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
            Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
            Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
            Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
            Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
            Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
            Self::Zero => Self::Zero,
            Self::One => Self::One,
            Self::Constant(v) => Self::Constant(*v),
            Self::ComplexConstant(v) => Self::ComplexConstant(*v),
            Self::Sqrt(a) => Self::Sqrt(Box::new(a.remap(mapping))),
            Self::Pow(a, b) => Self::Pow(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
            Self::PowI(a, power) => Self::PowI(Box::new(a.remap(mapping)), *power),
            Self::PowF(a, power) => Self::PowF(Box::new(a.remap(mapping)), *power),
            Self::Exp(a) => Self::Exp(Box::new(a.remap(mapping))),
            Self::Sin(a) => Self::Sin(Box::new(a.remap(mapping))),
            Self::Cos(a) => Self::Cos(Box::new(a.remap(mapping))),
            Self::Log(a) => Self::Log(Box::new(a.remap(mapping))),
            Self::Cis(a) => Self::Cis(Box::new(a.remap(mapping))),
        }
    }

    /// Lower this tree into a flat evaluation program.
    fn program(&self) -> ExpressionProgram {
        ExpressionProgram::from_node(self)
    }

    /// Evaluate the tree against per-amplitude values.
    ///
    /// Compiles a fresh program on every call; hot paths should compile once
    /// and reuse the `ExpressionProgram`.
    pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
        self.program().evaluate(amplitude_values)
    }

    /// Evaluate the tree's gradient given per-amplitude values and their
    /// gradients. Also compiles a fresh program on every call.
    pub fn evaluate_gradient(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
    ) -> DVector<Complex64> {
        self.program()
            .evaluate_gradient(amplitude_values, gradient_values)
    }
}
1939
1940impl From<f64> for Expression {
1941 fn from(value: f64) -> Self {
1942 if value == 0.0 {
1943 Self {
1944 registry: ExpressionRegistry::default(),
1945 tree: ExpressionNode::Zero,
1946 }
1947 } else if value == 1.0 {
1948 Self {
1949 registry: ExpressionRegistry::default(),
1950 tree: ExpressionNode::One,
1951 }
1952 } else {
1953 Self {
1954 registry: ExpressionRegistry::default(),
1955 tree: ExpressionNode::Constant(value),
1956 }
1957 }
1958 }
1959}
1960impl From<&f64> for Expression {
1961 fn from(value: &f64) -> Self {
1962 (*value).into()
1963 }
1964}
1965impl From<Complex64> for Expression {
1966 fn from(value: Complex64) -> Self {
1967 if value == Complex64::ZERO {
1968 Self {
1969 registry: ExpressionRegistry::default(),
1970 tree: ExpressionNode::Zero,
1971 }
1972 } else if value == Complex64::ONE {
1973 Self {
1974 registry: ExpressionRegistry::default(),
1975 tree: ExpressionNode::One,
1976 }
1977 } else {
1978 Self {
1979 registry: ExpressionRegistry::default(),
1980 tree: ExpressionNode::ComplexConstant(value),
1981 }
1982 }
1983 }
1984}
1985impl From<&Complex64> for Expression {
1986 fn from(value: &Complex64) -> Self {
1987 (*value).into()
1988 }
1989}
1990
1991impl Expression {
1992 pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
1994 let registry = ExpressionRegistry::singleton(amplitude)?;
1995 Ok(Self {
1996 tree: ExpressionNode::Amp(0),
1997 registry,
1998 })
1999 }
2000
2001 pub fn zero() -> Self {
2003 Self {
2004 registry: ExpressionRegistry::default(),
2005 tree: ExpressionNode::Zero,
2006 }
2007 }
2008
2009 pub fn one() -> Self {
2011 Self {
2012 registry: ExpressionRegistry::default(),
2013 tree: ExpressionNode::One,
2014 }
2015 }
2016
2017 fn binary_op(
2018 a: &Expression,
2019 b: &Expression,
2020 build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
2021 ) -> Expression {
2022 let (registry, left_map, right_map) = a
2023 .registry
2024 .merge(&b.registry)
2025 .expect("merging expression registries should not fail");
2026 let left_tree = a.tree.remap(&left_map);
2027 let right_tree = b.tree.remap(&right_map);
2028 Expression {
2029 registry,
2030 tree: build(Box::new(left_tree), Box::new(right_tree)),
2031 }
2032 }
2033
2034 fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
2035 Expression {
2036 registry: a.registry.clone(),
2037 tree: build(Box::new(a.tree.clone())),
2038 }
2039 }
2040
2041 pub fn parameters(&self) -> Vec<String> {
2043 self.registry.resources.parameter_names()
2044 }
2045
2046 pub fn free_parameters(&self) -> Vec<String> {
2048 self.registry.resources.free_parameter_names()
2049 }
2050
2051 pub fn fixed_parameters(&self) -> Vec<String> {
2053 self.registry.resources.fixed_parameter_names()
2054 }
2055
2056 pub fn n_free(&self) -> usize {
2058 self.registry.resources.n_free_parameters()
2059 }
2060
2061 pub fn n_fixed(&self) -> usize {
2063 self.registry.resources.n_fixed_parameters()
2064 }
2065
2066 pub fn n_parameters(&self) -> usize {
2068 self.registry.resources.n_parameters()
2069 }
2070
2071 pub fn compiled_expression(&self) -> CompiledExpression {
2077 let active_amplitudes = vec![true; self.registry.amplitudes.len()];
2078 let amplitude_dependencies = self
2079 .registry
2080 .amplitudes
2081 .iter()
2082 .map(|amp| ir::DependenceClass::from(amp.dependence_hint()))
2083 .collect::<Vec<_>>();
2084 let amplitude_realness = self
2085 .registry
2086 .amplitudes
2087 .iter()
2088 .map(|amp| amp.real_valued_hint())
2089 .collect::<Vec<_>>();
2090 let expression_ir = ir::compile_expression_ir_with_real_hints(
2091 &self.tree,
2092 &active_amplitudes,
2093 &litude_dependencies,
2094 &litude_realness,
2095 );
2096 CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
2097 }
2098
2099 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
2101 self.registry.resources.fix_parameter(name, value)
2102 }
2103
2104 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
2106 self.registry.resources.free_parameter(name)
2107 }
2108
2109 pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
2111 self.registry.resources.rename_parameter(old, new)
2112 }
2113
2114 pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
2116 self.registry.resources.rename_parameters(mapping)
2117 }
2118
2119 pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
2121 let mut resources = self.registry.resources.clone();
2122 let metadata = dataset.metadata();
2123 resources.reserve_cache(dataset.n_events_local());
2124 resources.refresh_active_indices();
2125 let parameter_map = resources.parameter_map.clone();
2126 let mut amplitudes: Vec<Box<dyn Amplitude>> = self
2127 .registry
2128 .amplitudes
2129 .iter()
2130 .map(|amp| dyn_clone::clone_box(&**amp))
2131 .collect();
2132 {
2133 for amplitude in amplitudes.iter_mut() {
2134 amplitude.bind(metadata)?;
2135 amplitude.precompute_all(dataset, &mut resources);
2136 }
2137 }
2138 let ir_compile_start = Instant::now();
2139 let expression_ir = {
2140 let mut active_amplitudes = vec![false; amplitudes.len()];
2141 for &index in resources.active_indices() {
2142 active_amplitudes[index] = true;
2143 }
2144 let amplitude_dependencies = amplitudes
2145 .iter()
2146 .map(|amp| ir::DependenceClass::from(amp.dependence_hint()))
2147 .collect::<Vec<_>>();
2148 let amplitude_realness = amplitudes
2149 .iter()
2150 .map(|amp| amp.real_valued_hint())
2151 .collect::<Vec<_>>();
2152 ir::compile_expression_ir_with_real_hints(
2153 &self.tree,
2154 &active_amplitudes,
2155 &litude_dependencies,
2156 &litude_realness,
2157 )
2158 };
2159 let initial_ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
2160 let cached_integrals_start = Instant::now();
2161 let cached_integrals = Evaluator::precompute_cached_integrals_at_load(
2162 &expression_ir,
2163 &litudes,
2164 &resources,
2165 dataset,
2166 parameter_map.free().len(),
2167 )?;
2168 let initial_cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
2169 let lowering_start = Instant::now();
2170 let lowered_artifacts = Arc::new(Evaluator::lower_expression_runtime_artifacts(
2171 &expression_ir,
2172 &cached_integrals,
2173 )?);
2174 let initial_lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
2175 let execution_sets = expression_ir.normalization_execution_sets().clone();
2176 let cached_integral_key =
2177 Evaluator::cached_integral_cache_key(resources.active.clone(), dataset);
2178 let cached_integral_state = Arc::new(CachedIntegralCacheState {
2179 key: cached_integral_key.clone(),
2180 expression_ir,
2181 values: cached_integrals,
2182 execution_sets,
2183 });
2184 let specialization_state = ExpressionSpecializationState {
2185 cached_integrals: cached_integral_state.clone(),
2186 lowered_artifacts: lowered_artifacts.clone(),
2187 };
2188 let specialization_cache = HashMap::from([(cached_integral_key, specialization_state)]);
2189 let lowered_artifact_cache =
2190 HashMap::from([(resources.active.clone(), lowered_artifacts.clone())]);
2191 Ok(Evaluator {
2192 amplitudes,
2193 resources: Arc::new(RwLock::new(resources)),
2194 dataset: dataset.clone(),
2195 expression: self.tree.clone(),
2196 ir_planning: ExpressionIrPlanningState {
2197 cached_integrals: Arc::new(RwLock::new(Some(cached_integral_state))),
2198 specialization_cache: Arc::new(RwLock::new(specialization_cache)),
2199 specialization_metrics: Arc::new(RwLock::new(ExpressionSpecializationMetrics {
2200 cache_hits: 0,
2201 cache_misses: 1,
2202 })),
2203 lowered_artifact_cache: Arc::new(RwLock::new(lowered_artifact_cache)),
2204 active_lowered_artifacts: Arc::new(RwLock::new(Some(lowered_artifacts.clone()))),
2205 specialization_status: Arc::new(RwLock::new(Some(
2206 ExpressionSpecializationStatus {
2207 origin: ExpressionSpecializationOrigin::InitialLoad,
2208 },
2209 ))),
2210 compile_metrics: Arc::new(RwLock::new(ExpressionCompileMetrics {
2211 initial_ir_compile_nanos,
2212 initial_cached_integrals_nanos,
2213 initial_lowering_nanos,
2214 specialization_lowering_cache_misses: 1,
2215 ..Default::default()
2216 })),
2217 },
2218 registry: self.registry.clone(),
2219 })
2220 }
2221
2222 pub fn real(&self) -> Self {
2224 Self::unary_op(self, ExpressionNode::Real)
2225 }
2226 pub fn imag(&self) -> Self {
2228 Self::unary_op(self, ExpressionNode::Imag)
2229 }
2230 pub fn conj(&self) -> Self {
2232 Self::unary_op(self, ExpressionNode::Conj)
2233 }
2234 pub fn norm_sqr(&self) -> Self {
2236 Self::unary_op(self, ExpressionNode::NormSqr)
2237 }
2238 pub fn sqrt(&self) -> Self {
2240 Self::unary_op(self, ExpressionNode::Sqrt)
2241 }
2242 pub fn pow(&self, power: &Expression) -> Self {
2244 Self::binary_op(self, power, ExpressionNode::Pow)
2245 }
2246 pub fn powi(&self, power: i32) -> Self {
2248 Self::unary_op(self, |input| ExpressionNode::PowI(input, power))
2249 }
2250 pub fn powf(&self, power: f64) -> Self {
2252 Self::unary_op(self, |input| ExpressionNode::PowF(input, power))
2253 }
2254 pub fn exp(&self) -> Self {
2256 Self::unary_op(self, ExpressionNode::Exp)
2257 }
2258 pub fn sin(&self) -> Self {
2260 Self::unary_op(self, ExpressionNode::Sin)
2261 }
2262 pub fn cos(&self) -> Self {
2264 Self::unary_op(self, ExpressionNode::Cos)
2265 }
2266 pub fn log(&self) -> Self {
2268 Self::unary_op(self, ExpressionNode::Log)
2269 }
2270 pub fn cis(&self) -> Self {
2272 Self::unary_op(self, ExpressionNode::Cis)
2273 }
2274
2275 fn write_tree(
2277 &self,
2278 t: &ExpressionNode,
2279 f: &mut std::fmt::Formatter<'_>,
2280 parent_prefix: &str,
2281 immediate_prefix: &str,
2282 parent_suffix: &str,
2283 ) -> std::fmt::Result {
2284 let display_string = match t {
2285 ExpressionNode::Amp(idx) => {
2286 let name = self
2287 .registry
2288 .amplitude_names
2289 .get(*idx)
2290 .cloned()
2291 .unwrap_or_else(|| "<unregistered>".to_string());
2292 format!("{name}(id={idx})")
2293 }
2294 ExpressionNode::Add(_, _) => "+".to_string(),
2295 ExpressionNode::Sub(_, _) => "-".to_string(),
2296 ExpressionNode::Mul(_, _) => "×".to_string(),
2297 ExpressionNode::Div(_, _) => "÷".to_string(),
2298 ExpressionNode::Neg(_) => "-".to_string(),
2299 ExpressionNode::Real(_) => "Re".to_string(),
2300 ExpressionNode::Imag(_) => "Im".to_string(),
2301 ExpressionNode::Conj(_) => "*".to_string(),
2302 ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
2303 ExpressionNode::Zero => "0 (exact)".to_string(),
2304 ExpressionNode::One => "1 (exact)".to_string(),
2305 ExpressionNode::Constant(v) => v.to_string(),
2306 ExpressionNode::ComplexConstant(v) => v.to_string(),
2307 ExpressionNode::Sqrt(_) => "Sqrt".to_string(),
2308 ExpressionNode::Pow(_, _) => "Pow".to_string(),
2309 ExpressionNode::PowI(_, power) => format!("PowI({power})"),
2310 ExpressionNode::PowF(_, power) => format!("PowF({power})"),
2311 ExpressionNode::Exp(_) => "Exp".to_string(),
2312 ExpressionNode::Sin(_) => "Sin".to_string(),
2313 ExpressionNode::Cos(_) => "Cos".to_string(),
2314 ExpressionNode::Log(_) => "Log".to_string(),
2315 ExpressionNode::Cis(_) => "Cis".to_string(),
2316 };
2317 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
2318 match t {
2319 ExpressionNode::Amp(_)
2320 | ExpressionNode::Zero
2321 | ExpressionNode::One
2322 | ExpressionNode::Constant(_)
2323 | ExpressionNode::ComplexConstant(_) => {}
2324 ExpressionNode::Add(a, b)
2325 | ExpressionNode::Sub(a, b)
2326 | ExpressionNode::Mul(a, b)
2327 | ExpressionNode::Div(a, b)
2328 | ExpressionNode::Pow(a, b) => {
2329 let terms = [a, b];
2330 let mut it = terms.iter().peekable();
2331 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2332 while let Some(child) = it.next() {
2333 match it.peek() {
2334 Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│ "),
2335 None => self.write_tree(child, f, &child_prefix, "└─ ", " "),
2336 }?;
2337 }
2338 }
2339 ExpressionNode::Neg(a)
2340 | ExpressionNode::Real(a)
2341 | ExpressionNode::Imag(a)
2342 | ExpressionNode::Conj(a)
2343 | ExpressionNode::NormSqr(a)
2344 | ExpressionNode::Sqrt(a)
2345 | ExpressionNode::PowI(a, _)
2346 | ExpressionNode::PowF(a, _)
2347 | ExpressionNode::Exp(a)
2348 | ExpressionNode::Sin(a)
2349 | ExpressionNode::Cos(a)
2350 | ExpressionNode::Log(a)
2351 | ExpressionNode::Cis(a) => {
2352 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2353 self.write_tree(a, f, &child_prefix, "└─ ", " ")?;
2354 }
2355 }
2356 Ok(())
2357 }
2358}
2359
2360impl Debug for Expression {
2361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2362 self.write_tree(&self.tree, f, "", "", "")
2363 }
2364}
2365
impl Display for Expression {
    /// Renders the expression as an indented box-drawing tree diagram.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.write_tree(&self.tree, f, "", "", "")
    }
}
2371
// Arithmetic operator overloads (via `auto_ops::impl_op_ex!`): each overload
// merges the two operands' registries through `Expression::binary_op` and
// wraps the remapped trees in the matching `ExpressionNode`. Scalar operands
// (`f64` / `Complex64`) are first lifted with `Expression::from`.

// --- Addition ---
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Add)
});
#[rustfmt::skip]
impl_op_ex!(+ |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Add)
});

// --- Subtraction ---
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Complex64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Sub)
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Complex64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Sub)
});

// --- Multiplication ---
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Mul)
});
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &f64| -> Expression {
    Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
});
#[rustfmt::skip]
impl_op_ex!(* |a: &f64, b: &Expression| -> Expression {
    Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
});
2426#[rustfmt::skip]
2427impl_op_ex!(* |a: &Expression, b: &Complex64| -> Expression {
2428 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Mul)
2429});
2430#[rustfmt::skip]
2431impl_op_ex!(* |a: &Complex64, b: &Expression| -> Expression {
2432 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Mul)
2433});
2434
2435#[rustfmt::skip]
2436impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
2437 Expression::binary_op(a, b, ExpressionNode::Div)
2438});
2439#[rustfmt::skip]
2440impl_op_ex!(/ |a: &Expression, b: &f64| -> Expression {
2441 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
2442});
2443#[rustfmt::skip]
2444impl_op_ex!(/ |a: &f64, b: &Expression| -> Expression {
2445 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
2446});
2447#[rustfmt::skip]
2448impl_op_ex!(/ |a: &Expression, b: &Complex64| -> Expression {
2449 Expression::binary_op(a, &Expression::from(b), ExpressionNode::Div)
2450});
2451#[rustfmt::skip]
2452impl_op_ex!(/ |a: &Complex64, b: &Expression| -> Expression {
2453 Expression::binary_op(&Expression::from(a), b, ExpressionNode::Div)
2454});
2455
2456#[rustfmt::skip]
2457impl_op_ex!(- |a: &Expression| -> Expression {
2458 Expression::unary_op(a, ExpressionNode::Neg)
2459});
/// Opaque snapshot of a lowered value-only program, produced and consumed by
/// the hidden `expression_value_program_snapshot*` helpers on [`Evaluator`].
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct ExpressionValueProgramSnapshot {
    // Captured value program; evaluated via `evaluate_into` with caller scratch.
    lowered_program: lowered::LoweredProgram,
}
2467
/// A single node in a [`CompiledExpression`]: a display-friendly mirror of the
/// IR node kinds, with operators carried as human-readable labels.
#[derive(Clone, Debug, PartialEq)]
pub enum CompiledExpressionNode {
    /// A literal complex constant.
    Constant(Complex64),
    /// A reference to a registered amplitude.
    Amplitude {
        /// Amplitude index within the evaluator's registry.
        index: usize,
        /// Registered name, or `"<unregistered>"` if the index was unknown.
        name: String,
    },
    /// A unary operation applied to one child node.
    Unary {
        /// Operator label (e.g. `"Exp"`, `"-"`).
        op: String,
        /// Index of the operand node in the node arena.
        input: usize,
    },
    /// A binary operation combining two child nodes.
    Binary {
        /// Operator label (e.g. `"+"`, `"×"`).
        op: String,
        /// Index of the left operand node.
        left: usize,
        /// Index of the right operand node.
        right: usize,
    },
}
2497
/// A flattened, display-oriented view of a compiled expression: an arena of
/// nodes plus the index of the root node within that arena.
#[derive(Clone, Debug, PartialEq)]
pub struct CompiledExpression {
    /// Node arena; child links are indices into this vector.
    nodes: Vec<CompiledExpressionNode>,
    /// Index of the root node in `nodes`.
    root: usize,
}
2508
2509impl CompiledExpression {
2510 fn from_ir(ir: &ir::ExpressionIR, amplitude_names: &[String]) -> Self {
2511 let nodes = ir
2512 .nodes()
2513 .iter()
2514 .map(|node| match node {
2515 ir::IrNode::Constant(value) => CompiledExpressionNode::Constant(*value),
2516 ir::IrNode::Amp(index) => CompiledExpressionNode::Amplitude {
2517 index: *index,
2518 name: amplitude_names
2519 .get(*index)
2520 .cloned()
2521 .unwrap_or_else(|| "<unregistered>".to_string()),
2522 },
2523 ir::IrNode::Unary { op, input } => CompiledExpressionNode::Unary {
2524 op: compiled_unary_op_label(*op),
2525 input: *input,
2526 },
2527 ir::IrNode::Binary { op, left, right } => CompiledExpressionNode::Binary {
2528 op: compiled_binary_op_label(*op),
2529 left: *left,
2530 right: *right,
2531 },
2532 })
2533 .collect();
2534 Self {
2535 nodes,
2536 root: ir.root(),
2537 }
2538 }
2539
2540 pub fn nodes(&self) -> &[CompiledExpressionNode] {
2542 &self.nodes
2543 }
2544
2545 pub fn root(&self) -> usize {
2547 self.root
2548 }
2549
2550 fn node_label(&self, index: usize) -> String {
2551 let Some(node) = self.nodes.get(index) else {
2552 return format!("#{index} <missing>");
2553 };
2554 let label = match node {
2555 CompiledExpressionNode::Constant(value) => format!("const {value}"),
2556 CompiledExpressionNode::Amplitude { index, name } => {
2557 format!("{name}(id={index})")
2558 }
2559 CompiledExpressionNode::Unary { op, .. }
2560 | CompiledExpressionNode::Binary { op, .. } => op.clone(),
2561 };
2562 format!("#{index} {label}")
2563 }
2564
2565 fn write_tree(
2567 &self,
2568 index: usize,
2569 f: &mut std::fmt::Formatter<'_>,
2570 parent_prefix: &str,
2571 immediate_prefix: &str,
2572 parent_suffix: &str,
2573 expanded: &mut [bool],
2574 ) -> std::fmt::Result {
2575 let already_expanded = expanded.get(index).copied().unwrap_or(false);
2576 if let Some(slot) = expanded.get_mut(index) {
2577 *slot = true;
2578 }
2579 let ref_suffix = if already_expanded { " (ref)" } else { "" };
2580 writeln!(
2581 f,
2582 "{}{}{}{}",
2583 parent_prefix,
2584 immediate_prefix,
2585 self.node_label(index),
2586 ref_suffix
2587 )?;
2588 if already_expanded {
2589 return Ok(());
2590 }
2591 let Some(node) = self.nodes.get(index) else {
2592 return Ok(());
2593 };
2594 match node {
2595 CompiledExpressionNode::Constant(_) | CompiledExpressionNode::Amplitude { .. } => {}
2596 CompiledExpressionNode::Unary { input, .. } => {
2597 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2598 self.write_tree(*input, f, &child_prefix, "└─ ", " ", expanded)?;
2599 }
2600 CompiledExpressionNode::Binary { left, right, .. } => {
2601 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
2602 self.write_tree(*left, f, &child_prefix, "├─ ", "│ ", expanded)?;
2603 self.write_tree(*right, f, &child_prefix, "└─ ", " ", expanded)?;
2604 }
2605 }
2606 Ok(())
2607 }
2608}
2609
2610impl Display for CompiledExpression {
2611 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2612 if self.nodes.is_empty() {
2613 return writeln!(f, "<empty>");
2614 }
2615 let mut expanded = vec![false; self.nodes.len()];
2616 self.write_tree(self.root, f, "", "", "", &mut expanded)
2617 }
2618}
2619
2620fn compiled_unary_op_label(op: ir::IrUnaryOp) -> String {
2621 match op {
2622 ir::IrUnaryOp::Neg => "-".to_string(),
2623 ir::IrUnaryOp::Real => "Re".to_string(),
2624 ir::IrUnaryOp::Imag => "Im".to_string(),
2625 ir::IrUnaryOp::Conj => "*".to_string(),
2626 ir::IrUnaryOp::NormSqr => "NormSqr".to_string(),
2627 ir::IrUnaryOp::Sqrt => "Sqrt".to_string(),
2628 ir::IrUnaryOp::PowI(power) => format!("PowI({power})"),
2629 ir::IrUnaryOp::PowF(bits) => format!("PowF({})", f64::from_bits(bits)),
2630 ir::IrUnaryOp::Exp => "Exp".to_string(),
2631 ir::IrUnaryOp::Sin => "Sin".to_string(),
2632 ir::IrUnaryOp::Cos => "Cos".to_string(),
2633 ir::IrUnaryOp::Log => "Log".to_string(),
2634 ir::IrUnaryOp::Cis => "Cis".to_string(),
2635 }
2636}
2637
2638fn compiled_binary_op_label(op: ir::IrBinaryOp) -> String {
2639 match op {
2640 ir::IrBinaryOp::Add => "+".to_string(),
2641 ir::IrBinaryOp::Sub => "-".to_string(),
2642 ir::IrBinaryOp::Mul => "×".to_string(),
2643 ir::IrBinaryOp::Div => "÷".to_string(),
2644 ir::IrBinaryOp::Pow => "Pow".to_string(),
2645 }
2646}
/// How the currently-active expression specialization was obtained.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpressionSpecializationOrigin {
    /// First specialization ever built for this evaluator.
    InitialLoad,
    /// Built from scratch because no cached specialization matched the key.
    CacheMissRebuild,
    /// Restored from the specialization cache without rebuilding.
    CacheHitRestore,
}
/// Status record describing the active expression specialization.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ExpressionSpecializationStatus {
    /// How the active specialization was produced.
    pub origin: ExpressionSpecializationOrigin,
}
/// Point-in-time snapshot of the expression runtime state, as reported by
/// [`Evaluator::expression_runtime_diagnostics`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExpressionRuntimeDiagnostics {
    /// Whether IR planning is in effect (always `true` in the current code).
    pub ir_planning_enabled: bool,
    /// Whether a lowered value program exists (always `true` currently).
    pub lowered_value_program_present: bool,
    /// Whether a lowered gradient program exists (always `true` currently).
    pub lowered_gradient_program_present: bool,
    /// Whether a lowered value+gradient program exists (always `true` currently).
    pub lowered_value_gradient_program_present: bool,
    /// Number of precomputed cached-integral values in the installed state.
    pub cached_parameter_factor_count: usize,
    /// Number of parameter factors that lowered successfully (`Some` entries).
    pub lowered_cached_parameter_factor_count: usize,
    /// Whether the active artifacts carry a residual runtime.
    pub residual_runtime_present: bool,
    /// Entry count of the specialization cache.
    pub specialization_cache_entries: usize,
    /// Entry count of the lowered-artifact cache.
    pub lowered_artifact_cache_entries: usize,
    /// Origin of the active specialization, if one has been installed.
    pub specialization_status: Option<ExpressionSpecializationStatus>,
}
/// Shared mutable state backing IR-driven expression planning. Every field is
/// behind an `Arc`, so cloned [`Evaluator`]s observe the same caches/metrics.
#[derive(Clone)]
struct ExpressionIrPlanningState {
    /// Cached-integral state of the installed specialization, if any.
    cached_integrals: Arc<RwLock<Option<Arc<CachedIntegralCacheState>>>>,
    /// Full specializations keyed by active mask plus dataset identity.
    specialization_cache:
        Arc<RwLock<HashMap<CachedIntegralCacheKey, ExpressionSpecializationState>>>,
    /// Hit/miss counters for the specialization cache.
    specialization_metrics: Arc<RwLock<ExpressionSpecializationMetrics>>,
    /// Lowered artifacts keyed by active-amplitude mask alone.
    lowered_artifact_cache: Arc<RwLock<HashMap<Vec<bool>, Arc<LoweredArtifactCacheState>>>>,
    /// Artifacts belonging to the specialization currently in use.
    active_lowered_artifacts: Arc<RwLock<Option<Arc<LoweredArtifactCacheState>>>>,
    /// How the active specialization was obtained, if one is installed.
    specialization_status: Arc<RwLock<Option<ExpressionSpecializationStatus>>>,
    /// Nanosecond timings and counters for the compilation phases.
    compile_metrics: Arc<RwLock<ExpressionCompileMetrics>>,
}
/// Evaluates a registered amplitude expression over a dataset, caching IR
/// specializations and lowered runtimes per active-amplitude mask.
#[allow(missing_docs)]
#[derive(Clone)]
pub struct Evaluator {
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    pub resources: Arc<RwLock<Resources>>,
    pub dataset: Arc<Dataset>,
    pub expression: ExpressionNode,
    // Shared caches/metrics for IR planning; inner `Arc`s make clones share state.
    ir_planning: ExpressionIrPlanningState,
    // Name registry paired with the tree (see `expression()`/`compiled_expression()`).
    registry: ExpressionRegistry,
}
2715
2716#[allow(missing_docs)]
2717impl Evaluator {
    /// Current hit/miss counters for the specialization cache.
    pub fn expression_specialization_metrics(&self) -> ExpressionSpecializationMetrics {
        *self.ir_planning.specialization_metrics.read()
    }
    /// Reset the specialization cache counters to their defaults.
    pub fn reset_expression_specialization_metrics(&self) {
        *self.ir_planning.specialization_metrics.write() =
            ExpressionSpecializationMetrics::default();
    }
    /// Current compile-phase timing/counter metrics.
    pub fn expression_compile_metrics(&self) -> ExpressionCompileMetrics {
        *self.ir_planning.compile_metrics.read()
    }
2731 pub fn expression_runtime_diagnostics(&self) -> ExpressionRuntimeDiagnostics {
2733 let active_artifacts = self.active_lowered_artifacts();
2734 let cached_parameter_factor_count = self
2735 .ir_planning
2736 .cached_integrals
2737 .read()
2738 .as_ref()
2739 .map(|state| state.values.len())
2740 .unwrap_or(0);
2741 let lowered_cached_parameter_factor_count = active_artifacts
2742 .as_ref()
2743 .map(|artifacts| {
2744 artifacts
2745 .lowered_parameter_factors
2746 .iter()
2747 .filter(|factor| factor.is_some())
2748 .count()
2749 })
2750 .unwrap_or(0);
2751 let residual_runtime_present = active_artifacts
2752 .as_ref()
2753 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
2754 .is_some();
2755 ExpressionRuntimeDiagnostics {
2756 ir_planning_enabled: true,
2757 lowered_value_program_present: true,
2758 lowered_gradient_program_present: true,
2759 lowered_value_gradient_program_present: true,
2760 cached_parameter_factor_count,
2761 lowered_cached_parameter_factor_count,
2762 residual_runtime_present,
2763 specialization_cache_entries: self.ir_planning.specialization_cache.read().len(),
2764 lowered_artifact_cache_entries: self.ir_planning.lowered_artifact_cache.read().len(),
2765 specialization_status: *self.ir_planning.specialization_status.read(),
2766 }
2767 }
    /// Zero the specialization-related compile metrics. Only the fields listed
    /// here are reset; any other fields of the metrics struct are left untouched.
    pub fn reset_expression_compile_metrics(&self) {
        let mut metrics = self.ir_planning.compile_metrics.write();
        metrics.specialization_cache_hits = 0;
        metrics.specialization_cache_misses = 0;
        metrics.specialization_ir_compile_nanos = 0;
        metrics.specialization_cached_integrals_nanos = 0;
        metrics.specialization_lowering_nanos = 0;
        metrics.specialization_lowering_cache_hits = 0;
        metrics.specialization_lowering_cache_misses = 0;
        metrics.specialization_cache_restore_nanos = 0;
    }
    /// Test-only access to the IR of the installed specialization.
    ///
    /// # Panics
    /// Panics if no cached-integral state has been installed yet.
    #[cfg(test)]
    fn expression_ir(&self) -> ir::ExpressionIR {
        self.ir_planning
            .cached_integrals
            .read()
            .as_ref()
            .map(|state| state.expression_ir.clone())
            .expect("cached integral state should exist for evaluator IR access")
    }
    /// The lowered runtime of the active specialization.
    ///
    /// # Panics
    /// Panics if no specialization has been installed yet.
    fn lowered_runtime(&self) -> lowered::LoweredExpressionRuntime {
        self.active_lowered_artifacts()
            .expect("active lowered artifacts should exist for the current specialization")
            .lowered_runtime
            .clone()
    }
    /// Currently-installed lowered artifacts, if any.
    fn active_lowered_artifacts(&self) -> Option<Arc<LoweredArtifactCacheState>> {
        self.ir_planning.active_lowered_artifacts.read().clone()
    }
2798 fn lowered_runtime_slot_count(&self) -> usize {
2799 let runtime = self.lowered_runtime();
2800 [
2801 runtime.value_program().scratch_slots(),
2802 runtime.gradient_program().scratch_slots(),
2803 runtime.value_gradient_program().scratch_slots(),
2804 ]
2805 .into_iter()
2806 .max()
2807 .unwrap_or(0)
2808 }
2809 fn lowered_value_runtime_slot_count(&self) -> usize {
2810 self.lowered_runtime().value_program().scratch_slots()
2811 }
2812
    /// Capture the active specialization's value program as an opaque snapshot.
    #[doc(hidden)]
    pub fn expression_value_program_snapshot(&self) -> ExpressionValueProgramSnapshot {
        ExpressionValueProgramSnapshot {
            lowered_program: self.lowered_runtime().value_program().clone(),
        }
    }

    /// Compile and lower a value-only program for an arbitrary active mask,
    /// bypassing the specialization caches.
    #[doc(hidden)]
    pub fn expression_value_program_snapshot_for_active_mask(
        &self,
        active_mask: &[bool],
    ) -> LadduResult<ExpressionValueProgramSnapshot> {
        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
        let lowered_program =
            lowered::LoweredProgram::from_ir_value_only(&expression_ir).map_err(|error| {
                LadduError::Custom(format!(
                    "Failed to lower value-only active-mask runtime: {error:?}"
                ))
            })?;
        Ok(ExpressionValueProgramSnapshot { lowered_program })
    }

    /// Scratch-slot requirement of a captured snapshot (does not read `self`).
    #[doc(hidden)]
    pub fn expression_value_program_snapshot_slot_count(
        &self,
        snapshot: &ExpressionValueProgramSnapshot,
    ) -> usize {
        let _ = self;
        snapshot.lowered_program.scratch_slots()
    }
2843
    /// Compiled (IR-level) view of the expression for the current active mask,
    /// with amplitude indices resolved to registered names.
    pub fn compiled_expression(&self) -> CompiledExpression {
        let expression_ir = self.compile_expression_ir_for_active_mask(&self.active_mask());
        CompiledExpression::from_ir(&expression_ir, &self.registry.amplitude_names)
    }

    /// The original expression tree paired with its name registry.
    pub fn expression(&self) -> Expression {
        Expression {
            tree: self.expression.clone(),
            registry: self.registry.clone(),
        }
    }
    /// Scratch-slot requirement of the gradient program.
    fn lowered_gradient_runtime_slot_count(&self) -> usize {
        self.lowered_runtime().gradient_program().scratch_slots()
    }
    /// Scratch-slot requirement of the combined value+gradient program.
    fn lowered_value_gradient_runtime_slot_count(&self) -> usize {
        self.lowered_runtime()
            .value_gradient_program()
            .scratch_slots()
    }

    // Thin aliases so expression-level call sites don't reference the lowered
    // runtime directly.
    fn expression_value_slot_count(&self) -> usize {
        self.lowered_value_runtime_slot_count()
    }
    fn expression_gradient_slot_count(&self) -> usize {
        self.lowered_gradient_runtime_slot_count()
    }
    fn expression_value_gradient_slot_count(&self) -> usize {
        self.lowered_value_gradient_runtime_slot_count()
    }

    /// Hidden public wrapper around [`Self::expression_value_gradient_slot_count`].
    #[doc(hidden)]
    pub fn expression_value_gradient_slot_count_public(&self) -> usize {
        self.expression_value_gradient_slot_count()
    }
    /// Test-only: number of entries in the specialization cache.
    #[cfg(test)]
    fn specialization_cache_len(&self) -> usize {
        self.ir_planning.specialization_cache.read().len()
    }
    /// Test-only: number of entries in the lowered-artifact cache.
    #[cfg(test)]
    fn lowered_artifact_cache_len(&self) -> usize {
        self.ir_planning.lowered_artifact_cache.read().len()
    }
    /// Publish `specialization` as the active one: install its cached-integral
    /// state and lowered artifacts. Debug assertions verify the artifacts match
    /// the cached-integral signature and that installation actually took effect.
    fn install_expression_specialization(&self, specialization: &ExpressionSpecializationState) {
        debug_assert!(Self::lowered_artifact_signature_matches(
            &specialization.lowered_artifacts,
            &specialization.cached_integrals.values,
        ));
        *self.ir_planning.cached_integrals.write() = Some(specialization.cached_integrals.clone());
        *self.ir_planning.active_lowered_artifacts.write() =
            Some(specialization.lowered_artifacts.clone());
        // Sanity-check: the active slot now points at exactly these artifacts.
        debug_assert_eq!(
            self.active_lowered_artifacts()
                .as_ref()
                .map(|artifacts| Arc::ptr_eq(artifacts, &specialization.lowered_artifacts)),
            Some(true)
        );
        // Sanity-check: the runtime now served matches the installed programs.
        debug_assert_eq!(
            self.lowered_runtime().value_program().scratch_slots(),
            specialization
                .lowered_artifacts
                .lowered_runtime
                .value_program()
                .scratch_slots()
        );
    }
2912 fn lower_expression_runtime_artifacts(
2913 expression_ir: &ir::ExpressionIR,
2914 values: &[PrecomputedCachedIntegral],
2915 ) -> LadduResult<LoweredArtifactCacheState> {
2916 let parameter_node_indices = values
2917 .iter()
2918 .map(|value| value.parameter_node_index)
2919 .collect();
2920 let mul_node_indices = values.iter().map(|value| value.mul_node_index).collect();
2921 let lowered_parameter_factors = Self::lower_cached_parameter_factors(expression_ir);
2922 let residual_runtime = Self::lower_residual_runtime(expression_ir, values);
2923 let lowered_runtime = lowered::LoweredExpressionRuntime::from_ir_value_gradient(
2924 expression_ir,
2925 )
2926 .map_err(|error| {
2927 LadduError::Custom(format!(
2928 "Failed to lower expression runtime for specialized IR: {error:?}"
2929 ))
2930 })?;
2931 Ok(LoweredArtifactCacheState {
2932 parameter_node_indices,
2933 mul_node_indices,
2934 lowered_parameter_factors,
2935 residual_runtime,
2936 lowered_runtime,
2937 })
2938 }
2939 fn lowered_artifact_signature_matches(
2940 artifacts: &LoweredArtifactCacheState,
2941 values: &[PrecomputedCachedIntegral],
2942 ) -> bool {
2943 artifacts.parameter_node_indices.len() == values.len()
2944 && artifacts.mul_node_indices.len() == values.len()
2945 && artifacts
2946 .parameter_node_indices
2947 .iter()
2948 .copied()
2949 .eq(values.iter().map(|value| value.parameter_node_index))
2950 && artifacts
2951 .mul_node_indices
2952 .iter()
2953 .copied()
2954 .eq(values.iter().map(|value| value.mul_node_index))
2955 }
2956 fn build_expression_specialization(
2957 &self,
2958 resources: &Resources,
2959 key: CachedIntegralCacheKey,
2960 ) -> LadduResult<ExpressionSpecializationState> {
2961 let ir_compile_start = Instant::now();
2962 let expression_ir = self.compile_expression_ir_for_active_mask(&resources.active);
2963 let ir_compile_nanos = ir_compile_start.elapsed().as_nanos() as u64;
2964 let cached_integrals_start = Instant::now();
2965 let values = Self::precompute_cached_integrals_at_load(
2966 &expression_ir,
2967 &self.amplitudes,
2968 resources,
2969 &self.dataset,
2970 self.resources.read().n_free_parameters(),
2971 )?;
2972 let cached_integrals_nanos = cached_integrals_start.elapsed().as_nanos() as u64;
2973 let execution_sets = expression_ir.normalization_execution_sets().clone();
2974 let active_mask_key = resources.active.clone();
2975 let cached_lowered_artifacts = {
2976 let lowered_artifact_cache = self.ir_planning.lowered_artifact_cache.read();
2977 lowered_artifact_cache
2978 .get(&active_mask_key)
2979 .cloned()
2980 .filter(|artifacts| Self::lowered_artifact_signature_matches(artifacts, &values))
2981 };
2982 let lowered_artifacts = if let Some(artifacts) = cached_lowered_artifacts {
2983 self.ir_planning
2984 .compile_metrics
2985 .write()
2986 .specialization_lowering_cache_hits += 1;
2987 artifacts
2988 } else {
2989 let lowering_start = Instant::now();
2990 let artifacts = Arc::new(
2991 Self::lower_expression_runtime_artifacts(&expression_ir, &values)
2992 .expect("specialized lowered runtime should build"),
2993 );
2994 let lowering_nanos = lowering_start.elapsed().as_nanos() as u64;
2995 self.ir_planning
2996 .lowered_artifact_cache
2997 .write()
2998 .insert(active_mask_key, artifacts.clone());
2999 let mut compile_metrics = self.ir_planning.compile_metrics.write();
3000 compile_metrics.specialization_lowering_cache_misses += 1;
3001 compile_metrics.specialization_lowering_nanos += lowering_nanos;
3002 artifacts
3003 };
3004 let mut compile_metrics = self.ir_planning.compile_metrics.write();
3005 compile_metrics.specialization_cache_misses += 1;
3006 compile_metrics.specialization_ir_compile_nanos += ir_compile_nanos;
3007 compile_metrics.specialization_cached_integrals_nanos += cached_integrals_nanos;
3008 Ok(ExpressionSpecializationState {
3009 cached_integrals: Arc::new(CachedIntegralCacheState {
3010 key,
3011 expression_ir,
3012 values,
3013 execution_sets,
3014 }),
3015 lowered_artifacts,
3016 })
3017 }
    /// Obtain the specialization for the current (active mask, dataset) key,
    /// trying in order: the already-installed state, the specialization cache,
    /// and finally a full rebuild. Updates both specialization and compile
    /// metrics accordingly.
    fn ensure_expression_specialization(
        &self,
        resources: &Resources,
    ) -> LadduResult<ExpressionSpecializationState> {
        let key = Self::cached_integral_cache_key(resources.active.clone(), &self.dataset);
        // Fast path: the installed specialization already matches the key.
        if let Some(state) = self.ir_planning.cached_integrals.read().as_ref() {
            if state.key == key {
                return Ok(ExpressionSpecializationState {
                    cached_integrals: state.clone(),
                    lowered_artifacts: self
                        .active_lowered_artifacts()
                        .expect("active lowered artifacts should exist for cached specialization"),
                });
            }
        }
        // Second chance: a previously built specialization for this key.
        let cached_specialization = {
            let specialization_cache = self.ir_planning.specialization_cache.read();
            specialization_cache.get(&key).cloned()
        };
        if let Some(specialization) = cached_specialization {
            let restore_start = Instant::now();
            self.ir_planning.specialization_metrics.write().cache_hits += 1;
            self.install_expression_specialization(&specialization);
            *self.ir_planning.specialization_status.write() =
                Some(ExpressionSpecializationStatus {
                    origin: ExpressionSpecializationOrigin::CacheHitRestore,
                });
            let restore_nanos = restore_start.elapsed().as_nanos() as u64;
            let mut compile_metrics = self.ir_planning.compile_metrics.write();
            compile_metrics.specialization_cache_hits += 1;
            compile_metrics.specialization_cache_restore_nanos += restore_nanos;
            return Ok(specialization);
        }
        // Slow path: build, cache, and install a fresh specialization.
        let specialization = self.build_expression_specialization(resources, key.clone())?;
        self.ir_planning.specialization_metrics.write().cache_misses += 1;
        self.ir_planning
            .specialization_cache
            .write()
            .insert(key, specialization.clone());
        self.install_expression_specialization(&specialization);
        // A single cache entry means this was the evaluator's first build.
        let origin = if self.ir_planning.specialization_cache.read().len() == 1 {
            ExpressionSpecializationOrigin::InitialLoad
        } else {
            ExpressionSpecializationOrigin::CacheMissRebuild
        };
        *self.ir_planning.specialization_status.write() =
            Some(ExpressionSpecializationStatus { origin });
        Ok(specialization)
    }
    /// Best-effort rebuild of the specialization for `resources`; any error is
    /// deliberately discarded here and will resurface at the next fallible call.
    fn rebuild_runtime_specializations(&self, resources: &Resources) {
        let _ = self.ensure_expression_specialization(resources);
    }
    /// Refresh specializations against the evaluator's own resources.
    fn refresh_runtime_specializations(&self) {
        let resources = self.resources.read();
        self.rebuild_runtime_specializations(&resources);
    }
    /// Build the cache key identifying a (active mask, dataset) specialization.
    /// Dataset identity is approximated by event counts, the bit pattern of the
    /// weighted event sum, and the events buffer address.
    /// NOTE(review): the pointer component makes this an identity key — an
    /// equal dataset at a different allocation will miss the cache (presumed
    /// intentional; confirm).
    fn cached_integral_cache_key(
        active_mask: Vec<bool>,
        dataset: &Dataset,
    ) -> CachedIntegralCacheKey {
        CachedIntegralCacheKey {
            active_mask,
            n_events_local: dataset.n_events_local(),
            events_local_len: dataset.events_local().len(),
            weighted_sum_bits: dataset.n_events_weighted_local().to_bits(),
            events_ptr: dataset.events_local().as_ptr() as usize,
        }
    }
3086 fn precompute_cached_integrals_at_load(
3087 expression_ir: &ir::ExpressionIR,
3088 amplitudes: &[Box<dyn Amplitude>],
3089 resources: &Resources,
3090 dataset: &Dataset,
3091 n_free_parameters: usize,
3092 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
3093 let descriptors = expression_ir.cached_integral_descriptors();
3094 if descriptors.is_empty() {
3095 return Ok(Vec::new());
3096 }
3097 let execution_sets = expression_ir.normalization_execution_sets();
3098 let seed_parameters = vec![0.0; n_free_parameters];
3099 let parameters = resources.parameter_map.assemble(&seed_parameters)?;
3100 let mut amplitude_values = vec![Complex64::ZERO; amplitudes.len()];
3101 let mut value_slots = vec![Complex64::ZERO; expression_ir.node_count()];
3102 let active_set = resources.active_indices();
3103 let cache_active_indices = execution_sets
3104 .cached_cache_amplitudes
3105 .iter()
3106 .copied()
3107 .filter(|index| active_set.binary_search(index).is_ok())
3108 .collect::<Vec<_>>();
3109 let mut weighted_cache_sums = vec![Complex64::ZERO; descriptors.len()];
3110 for (cache, event) in resources.caches.iter().zip(dataset.events_local().iter()) {
3111 amplitude_values.fill(Complex64::ZERO);
3112 for &_idx in &cache_active_indices {
3113 amplitude_values[amp_idx] = amplitudes[amp_idx].compute(¶meters, cache);
3114 }
3115 expression_ir.evaluate_into(&litude_values, &mut value_slots);
3116 let weight = event.weight();
3117 for (descriptor_index, descriptor) in descriptors.iter().enumerate() {
3118 weighted_cache_sums[descriptor_index] +=
3119 value_slots[descriptor.cache_node_index] * weight;
3120 }
3121 }
3122 Ok(descriptors
3123 .iter()
3124 .zip(weighted_cache_sums)
3125 .map(
3126 |(descriptor, weighted_cache_sum)| PrecomputedCachedIntegral {
3127 mul_node_index: descriptor.mul_node_index,
3128 parameter_node_index: descriptor.parameter_node_index,
3129 cache_node_index: descriptor.cache_node_index,
3130 coefficient: descriptor.coefficient,
3131 weighted_cache_sum,
3132 },
3133 )
3134 .collect())
3135 }
    /// Lower a standalone value/gradient program for each cached-integral
    /// parameter factor; a factor that fails to lower becomes `None` (the
    /// lowering error is discarded via `.ok()`).
    fn lower_cached_parameter_factors(
        expression_ir: &ir::ExpressionIR,
    ) -> Vec<Option<lowered::LoweredFactorRuntime>> {
        expression_ir
            .cached_integral_descriptors()
            .iter()
            .map(|descriptor| {
                lowered::LoweredFactorRuntime::from_ir_root_value_gradient(
                    expression_ir,
                    descriptor.parameter_node_index,
                )
                .ok()
            })
            .collect()
    }
    /// Lower a runtime in which every cached mul node is forced to zero,
    /// leaving only the residual (non-precomputed) part of the expression.
    /// Returns `None` if lowering fails.
    fn lower_residual_runtime(
        expression_ir: &ir::ExpressionIR,
        descriptors: &[PrecomputedCachedIntegral],
    ) -> Option<lowered::LoweredExpressionRuntime> {
        let mut zeroed_nodes = vec![false; expression_ir.node_count()];
        for descriptor in descriptors {
            // Out-of-range indices are skipped rather than panicking.
            if descriptor.mul_node_index < zeroed_nodes.len() {
                zeroed_nodes[descriptor.mul_node_index] = true;
            }
        }
        lowered::LoweredExpressionRuntime::from_ir_zeroed_value_gradient(
            expression_ir,
            &zeroed_nodes,
        )
        .ok()
    }
3167
3168 #[inline]
3169 fn fill_amplitude_values(
3170 &self,
3171 amplitude_values: &mut [Complex64],
3172 active_indices: &[usize],
3173 parameters: &Parameters,
3174 cache: &Cache,
3175 ) {
3176 amplitude_values.fill(Complex64::ZERO);
3177 for &_idx in active_indices {
3178 amplitude_values[amp_idx] = self.amplitudes[amp_idx].compute(parameters, cache);
3179 }
3180 }
3181
3182 #[inline]
3183 fn fill_amplitude_gradients(
3184 &self,
3185 gradient_values: &mut [DVector<Complex64>],
3186 active_mask: &[bool],
3187 parameters: &Parameters,
3188 cache: &Cache,
3189 ) {
3190 for ((amp, active), grad) in self
3191 .amplitudes
3192 .iter()
3193 .zip(active_mask.iter())
3194 .zip(gradient_values.iter_mut())
3195 {
3196 grad.fill(Complex64::ZERO);
3197 if *active {
3198 amp.compute_gradient(parameters, cache, grad);
3199 }
3200 }
3201 }
3202
    /// Fill both amplitude values and per-amplitude gradients for one event
    /// cache, delegating to the two single-purpose helpers.
    #[inline]
    fn fill_amplitude_values_and_gradients(
        &self,
        amplitude_values: &mut [Complex64],
        gradient_values: &mut [DVector<Complex64>],
        active_indices: &[usize],
        active_mask: &[bool],
        parameters: &Parameters,
        cache: &Cache,
    ) {
        self.fill_amplitude_values(amplitude_values, active_indices, parameters, cache);
        self.fill_amplitude_gradients(gradient_values, active_mask, parameters, cache);
    }

    /// Hidden public wrapper around [`Self::fill_amplitude_values_and_gradients`].
    #[doc(hidden)]
    pub fn fill_amplitude_values_and_gradients_public(
        &self,
        amplitude_values: &mut [Complex64],
        gradient_values: &mut [DVector<Complex64>],
        active_indices: &[usize],
        active_mask: &[bool],
        parameters: &Parameters,
        cache: &Cache,
    ) {
        self.fill_amplitude_values_and_gradients(
            amplitude_values,
            gradient_values,
            active_indices,
            active_mask,
            parameters,
            cache,
        );
    }
3236
    /// Feature-gated: fill amplitude values/gradients for one cache, then
    /// evaluate the expression gradient using caller-provided scratch buffers.
    #[cfg(feature = "execution-context-prototype")]
    #[inline]
    fn evaluate_cache_gradient_with_scratch(
        &self,
        amplitude_values: &mut [Complex64],
        gradient_values: &mut [DVector<Complex64>],
        value_slots: &mut [Complex64],
        gradient_slots: &mut [DVector<Complex64>],
        active_indices: &[usize],
        active_mask: &[bool],
        parameters: &Parameters,
        cache: &Cache,
    ) -> DVector<Complex64> {
        self.fill_amplitude_values_and_gradients(
            amplitude_values,
            gradient_values,
            active_indices,
            active_mask,
            parameters,
            cache,
        );
        self.evaluate_expression_gradient_with_scratch(
            amplitude_values,
            gradient_values,
            value_slots,
            gradient_slots,
        )
    }

    /// Feature-gated: like [`Self::evaluate_cache_gradient_with_scratch`] but
    /// returning the value together with the gradient.
    #[cfg(feature = "execution-context-prototype")]
    #[allow(dead_code)]
    #[inline]
    fn evaluate_cache_value_gradient_with_scratch(
        &self,
        amplitude_values: &mut [Complex64],
        gradient_values: &mut [DVector<Complex64>],
        value_slots: &mut [Complex64],
        gradient_slots: &mut [DVector<Complex64>],
        active_indices: &[usize],
        active_mask: &[bool],
        parameters: &Parameters,
        cache: &Cache,
    ) -> (Complex64, DVector<Complex64>) {
        self.fill_amplitude_values_and_gradients(
            amplitude_values,
            gradient_values,
            active_indices,
            active_mask,
            parameters,
            cache,
        );
        self.evaluate_expression_value_gradient_with_scratch(
            amplitude_values,
            gradient_values,
            value_slots,
            gradient_slots,
        )
    }
3295
    /// Scratch-slot requirement covering all three lowered programs.
    pub fn expression_slot_count(&self) -> usize {
        self.lowered_runtime_slot_count()
    }
3299 fn compile_expression_ir_for_active_mask(&self, active_mask: &[bool]) -> ir::ExpressionIR {
3300 let amplitude_dependencies = self
3301 .amplitudes
3302 .iter()
3303 .map(|amp| ir::DependenceClass::from(amp.dependence_hint()))
3304 .collect::<Vec<_>>();
3305 let amplitude_realness = self
3306 .amplitudes
3307 .iter()
3308 .map(|amp| amp.real_valued_hint())
3309 .collect::<Vec<_>>();
3310 ir::compile_expression_ir_with_real_hints(
3311 &self.expression,
3312 active_mask,
3313 &litude_dependencies,
3314 &litude_realness,
3315 )
3316 }
    /// Compile and lower a full value/gradient runtime for an arbitrary active
    /// mask, bypassing the specialization caches.
    fn lower_expression_runtime_for_active_mask(
        &self,
        active_mask: &[bool],
    ) -> LadduResult<lowered::LoweredExpressionRuntime> {
        let expression_ir = self.compile_expression_ir_for_active_mask(active_mask);
        lowered::LoweredExpressionRuntime::from_ir_value_gradient(&expression_ir).map_err(|error| {
            LadduError::Custom(format!(
                "Failed to lower active-mask runtime specialization: {error:?}"
            ))
        })
    }
    /// Ensure a specialization exists and return just its cached-integral state.
    fn ensure_cached_integral_cache_state(
        &self,
        resources: &Resources,
    ) -> LadduResult<Arc<CachedIntegralCacheState>> {
        Ok(self
            .ensure_expression_specialization(resources)?
            .cached_integrals)
    }
3336
    /// Evaluate the active value program using caller-provided scratch storage.
    fn evaluate_expression_runtime_value_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        scratch: &mut [Complex64],
    ) -> Complex64 {
        let lowered_runtime = self.lowered_runtime();
        lowered_runtime
            .value_program()
            .evaluate_into(amplitude_values, scratch)
    }

    /// Evaluate a previously captured value-program snapshot (hidden API);
    /// does not touch the active specialization.
    #[doc(hidden)]
    pub fn evaluate_expression_value_with_program_snapshot(
        &self,
        program_snapshot: &ExpressionValueProgramSnapshot,
        amplitude_values: &[Complex64],
        scratch: &mut [Complex64],
    ) -> Complex64 {
        program_snapshot
            .lowered_program
            .evaluate_into(amplitude_values, scratch)
    }
3359
    /// Evaluate the gradient program with caller-provided value/gradient scratch.
    fn evaluate_expression_runtime_gradient_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_scratch: &mut [Complex64],
        gradient_scratch: &mut [DVector<Complex64>],
    ) -> DVector<Complex64> {
        let lowered_runtime = self.lowered_runtime();
        lowered_runtime.gradient_program().evaluate_gradient_into(
            amplitude_values,
            gradient_values,
            value_scratch,
            gradient_scratch,
        )
    }

    /// Evaluate the combined value+gradient program with caller-provided scratch.
    fn evaluate_expression_runtime_value_gradient_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_scratch: &mut [Complex64],
        gradient_scratch: &mut [DVector<Complex64>],
    ) -> (Complex64, DVector<Complex64>) {
        let lowered_runtime = self.lowered_runtime();
        lowered_runtime
            .value_gradient_program()
            .evaluate_value_gradient_into(
                amplitude_values,
                gradient_values,
                value_scratch,
                gradient_scratch,
            )
    }
3393
3394 fn evaluate_expression_runtime_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
3395 let lowered_runtime = self.lowered_runtime();
3396 let program = lowered_runtime.value_program();
3397 let mut scratch = vec![Complex64::ZERO; program.scratch_slots()];
3398 program.evaluate_into(amplitude_values, &mut scratch)
3399 }
3400
3401 fn evaluate_expression_runtime_gradient(
3402 &self,
3403 amplitude_values: &[Complex64],
3404 gradient_values: &[DVector<Complex64>],
3405 ) -> DVector<Complex64> {
3406 let lowered_runtime = self.lowered_runtime();
3407 let program = lowered_runtime.gradient_program();
3408 let mut value_scratch = vec![Complex64::ZERO; program.scratch_slots()];
3409 let grad_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
3410 let mut gradient_scratch = vec![Complex64::ZERO; program.scratch_slots() * grad_dim];
3411 program.evaluate_gradient_into_flat(
3412 amplitude_values,
3413 gradient_values,
3414 &mut value_scratch,
3415 &mut gradient_scratch,
3416 grad_dim,
3417 )
3418 }
3419 pub fn expression_root_dependence(&self) -> LadduResult<ExpressionDependence> {
3421 let resources = self.resources.read();
3422 Ok(self
3423 .ensure_cached_integral_cache_state(&resources)?
3424 .expression_ir
3425 .root_dependence()
3426 .into())
3427 }
3428 pub fn expression_node_dependence_annotations(&self) -> LadduResult<Vec<ExpressionDependence>> {
3430 let resources = self.resources.read();
3431 Ok(self
3432 .ensure_cached_integral_cache_state(&resources)?
3433 .expression_ir
3434 .node_dependence_annotations()
3435 .iter()
3436 .copied()
3437 .map(Into::into)
3438 .collect())
3439 }
3440 pub fn expression_dependence_warnings(&self) -> LadduResult<Vec<String>> {
3442 let resources = self.resources.read();
3443 Ok(self
3444 .ensure_cached_integral_cache_state(&resources)?
3445 .expression_ir
3446 .dependence_warnings()
3447 .to_vec())
3448 }
3449 pub fn expression_normalization_plan_explain(&self) -> LadduResult<NormalizationPlanExplain> {
3451 let resources = self.resources.read();
3452 Ok(self
3453 .ensure_cached_integral_cache_state(&resources)?
3454 .expression_ir
3455 .normalization_plan_explain()
3456 .into())
3457 }
3458 pub fn expression_normalization_execution_sets(
3460 &self,
3461 ) -> LadduResult<NormalizationExecutionSetsExplain> {
3462 let resources = self.resources.read();
3463 Ok(self
3464 .ensure_cached_integral_cache_state(&resources)?
3465 .execution_sets
3466 .clone()
3467 .into())
3468 }
3469 pub fn expression_precomputed_cached_integrals(
3471 &self,
3472 ) -> LadduResult<Vec<PrecomputedCachedIntegral>> {
3473 let resources = self.resources.read();
3474 Ok(self
3475 .ensure_cached_integral_cache_state(&resources)?
3476 .values
3477 .clone())
3478 }
3479 pub fn expression_precomputed_cached_integral_gradient_terms(
3484 &self,
3485 parameters: &[f64],
3486 ) -> LadduResult<Vec<PrecomputedCachedIntegralGradientTerm>> {
3487 let resources = self.resources.read();
3488 let state = self.ensure_cached_integral_cache_state(&resources)?;
3489 if state.values.is_empty() {
3490 return Ok(Vec::new());
3491 }
3492
3493 let Some(cache) = resources.caches.first() else {
3494 return Ok(state
3495 .values
3496 .iter()
3497 .map(|descriptor| PrecomputedCachedIntegralGradientTerm {
3498 mul_node_index: descriptor.mul_node_index,
3499 parameter_node_index: descriptor.parameter_node_index,
3500 cache_node_index: descriptor.cache_node_index,
3501 coefficient: descriptor.coefficient,
3502 weighted_gradient: DVector::zeros(parameters.len()),
3503 })
3504 .collect());
3505 };
3506
3507 let parameter_values = resources.parameter_map.assemble(parameters)?;
3508 let mut amplitude_values = vec![Complex64::ZERO; self.amplitudes.len()];
3509 self.fill_amplitude_values(
3510 &mut amplitude_values,
3511 resources.active_indices(),
3512 ¶meter_values,
3513 cache,
3514 );
3515 let mut amplitude_gradients = (0..self.amplitudes.len())
3516 .map(|_| DVector::zeros(parameters.len()))
3517 .collect::<Vec<_>>();
3518 self.fill_amplitude_gradients(
3519 &mut amplitude_gradients,
3520 &resources.active,
3521 ¶meter_values,
3522 cache,
3523 );
3524 let lowered_artifacts = self.active_lowered_artifacts();
3525 let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
3526 let mut gradient_slots = (0..state.expression_ir.node_count())
3527 .map(|_| DVector::zeros(parameters.len()))
3528 .collect::<Vec<_>>();
3529 let max_lowered_slots = lowered_artifacts
3530 .as_ref()
3531 .map(|artifacts| {
3532 artifacts
3533 .lowered_parameter_factors
3534 .iter()
3535 .filter_map(|runtime| {
3536 runtime
3537 .as_ref()
3538 .and_then(|runtime| runtime.gradient_program())
3539 .map(|program| program.scratch_slots())
3540 })
3541 .max()
3542 .unwrap_or(0)
3543 })
3544 .unwrap_or(0);
3545 let mut lowered_value_slots = vec![Complex64::ZERO; max_lowered_slots];
3546 let mut lowered_gradient_slots = vec![DVector::zeros(parameters.len()); max_lowered_slots];
3547 let use_lowered = lowered_artifacts.as_ref().is_some_and(|artifacts| {
3548 artifacts.lowered_parameter_factors.len() == state.values.len()
3549 && artifacts.lowered_parameter_factors.iter().all(|runtime| {
3550 runtime
3551 .as_ref()
3552 .and_then(|runtime| runtime.gradient_program())
3553 .is_some()
3554 })
3555 });
3556
3557 if !use_lowered {
3558 let _ = state.expression_ir.evaluate_gradient_into(
3559 &litude_values,
3560 &litude_gradients,
3561 &mut value_slots,
3562 &mut gradient_slots,
3563 );
3564 }
3565
3566 if use_lowered {
3567 let lowered_artifacts = lowered_artifacts.expect("lowered artifacts should exist");
3568 Ok(state
3569 .values
3570 .iter()
3571 .cloned()
3572 .zip(lowered_artifacts.lowered_parameter_factors.iter())
3573 .map(|(descriptor, runtime)| {
3574 let parameter_gradient = runtime
3575 .as_ref()
3576 .and_then(|runtime| runtime.gradient_program())
3577 .map(|program| {
3578 program.evaluate_gradient_into(
3579 &litude_values,
3580 &litude_gradients,
3581 &mut lowered_value_slots[..program.scratch_slots()],
3582 &mut lowered_gradient_slots[..program.scratch_slots()],
3583 )
3584 })
3585 .unwrap_or_else(|| DVector::zeros(parameters.len()));
3586 let weighted_gradient = parameter_gradient.map(|value| {
3587 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
3588 });
3589 PrecomputedCachedIntegralGradientTerm {
3590 mul_node_index: descriptor.mul_node_index,
3591 parameter_node_index: descriptor.parameter_node_index,
3592 cache_node_index: descriptor.cache_node_index,
3593 coefficient: descriptor.coefficient,
3594 weighted_gradient,
3595 }
3596 })
3597 .collect())
3598 } else {
3599 Ok(state
3600 .values
3601 .iter()
3602 .map(|descriptor| {
3603 let parameter_gradient = gradient_slots
3604 .get(descriptor.parameter_node_index)
3605 .cloned()
3606 .unwrap_or_else(|| DVector::zeros(parameters.len()));
3607 let weighted_gradient = parameter_gradient.map(|value| {
3608 value * descriptor.weighted_cache_sum * descriptor.coefficient as f64
3609 });
3610 PrecomputedCachedIntegralGradientTerm {
3611 mul_node_index: descriptor.mul_node_index,
3612 parameter_node_index: descriptor.parameter_node_index,
3613 cache_node_index: descriptor.cache_node_index,
3614 coefficient: descriptor.coefficient,
3615 weighted_gradient,
3616 }
3617 })
3618 .collect())
3619 }
3620 }
    /// Sums the real parts of all cached-integral terms using full expression-IR
    /// evaluation: each term is the value at the descriptor's parameter node,
    /// times the precomputed weighted cache sum and integer coefficient.
    fn evaluate_cached_weighted_value_sum_ir(
        &self,
        state: &CachedIntegralCacheState,
        amplitude_values: &[Complex64],
    ) -> f64 {
        // One slot per IR node; evaluate_into populates all of them.
        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
        let _ = state
            .expression_ir
            .evaluate_into(amplitude_values, &mut value_slots);
        state
            .values
            .iter()
            .map(|descriptor| {
                let parameter_factor = value_slots[descriptor.parameter_node_index];
                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
                    .re
            })
            .sum()
    }
    /// Like [`Self::evaluate_cached_weighted_value_sum_ir`], but uses the
    /// lowered per-factor value programs. Returns `None` if any descriptor's
    /// runtime lacks a value program (the `?` inside the loop short-circuits),
    /// signalling the caller to fall back to the IR path.
    fn evaluate_cached_weighted_value_sum_lowered(
        &self,
        state: &CachedIntegralCacheState,
        lowered_artifacts: &LoweredArtifactCacheState,
        amplitude_values: &[Complex64],
    ) -> Option<f64> {
        // One shared scratch buffer sized to the largest program; each program
        // uses only its own leading slice.
        let max_slots = lowered_artifacts
            .lowered_parameter_factors
            .iter()
            .filter_map(|runtime| {
                runtime
                    .as_ref()
                    .and_then(|runtime| runtime.value_program())
                    .map(|program| program.scratch_slots())
            })
            .max()
            .unwrap_or(0);
        let mut value_slots = vec![Complex64::ZERO; max_slots];
        let mut total = 0.0;
        for (descriptor, runtime) in state
            .values
            .iter()
            .zip(lowered_artifacts.lowered_parameter_factors.iter())
        {
            let parameter_factor = runtime
                .as_ref()
                .and_then(|runtime| runtime.value_program())
                .map(|program| {
                    program.evaluate_into(
                        amplitude_values,
                        &mut value_slots[..program.scratch_slots()],
                    )
                })?;
            total +=
                (parameter_factor * descriptor.weighted_cache_sum * descriptor.coefficient as f64)
                    .re;
        }
        Some(total)
    }
    /// Accumulates the real parts of the cached-integral gradient terms using
    /// full expression-IR gradient evaluation. The result has `grad_dim`
    /// entries, one per free parameter.
    fn evaluate_cached_weighted_gradient_sum_ir(
        &self,
        state: &CachedIntegralCacheState,
        amplitude_values: &[Complex64],
        amplitude_gradients: &[DVector<Complex64>],
        grad_dim: usize,
    ) -> DVector<f64> {
        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
        let _ = state.expression_ir.evaluate_gradient_into(
            amplitude_values,
            amplitude_gradients,
            &mut value_slots,
            &mut gradient_slots,
        );
        state
            .values
            .iter()
            .fold(DVector::zeros(grad_dim), |mut accum, descriptor| {
                // Weight the parameter-node gradient by the cached sum and
                // coefficient, keeping only the real part component-wise.
                let parameter_gradient = &gradient_slots[descriptor.parameter_node_index];
                let coefficient = descriptor.coefficient as f64;
                for (accum_item, gradient_item) in accum.iter_mut().zip(parameter_gradient.iter()) {
                    *accum_item +=
                        (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
                }
                accum
            })
    }
    /// Lowered-program counterpart of
    /// [`Self::evaluate_cached_weighted_gradient_sum_ir`]. Returns `None` if
    /// any descriptor's runtime lacks a gradient program (the `?` inside the
    /// loop short-circuits), so the caller can fall back to the IR path.
    fn evaluate_cached_weighted_gradient_sum_lowered(
        &self,
        state: &CachedIntegralCacheState,
        lowered_artifacts: &LoweredArtifactCacheState,
        amplitude_values: &[Complex64],
        amplitude_gradients: &[DVector<Complex64>],
        grad_dim: usize,
    ) -> Option<DVector<f64>> {
        // Size shared scratch to the largest gradient program; flat gradient
        // scratch holds scratch_slots * grad_dim entries.
        let max_value_slots = lowered_artifacts
            .lowered_parameter_factors
            .iter()
            .filter_map(|runtime| {
                runtime
                    .as_ref()
                    .and_then(|runtime| runtime.gradient_program())
                    .map(|program| program.scratch_slots())
            })
            .max()
            .unwrap_or(0);
        let mut value_slots = vec![Complex64::ZERO; max_value_slots];
        let mut gradient_slots = vec![Complex64::ZERO; max_value_slots * grad_dim];
        let mut total = DVector::zeros(grad_dim);
        for (descriptor, runtime) in state
            .values
            .iter()
            .zip(lowered_artifacts.lowered_parameter_factors.iter())
        {
            let parameter_gradient = runtime
                .as_ref()
                .and_then(|runtime| runtime.gradient_program())
                .map(|program| {
                    program.evaluate_gradient_into_flat(
                        amplitude_values,
                        amplitude_gradients,
                        &mut value_slots[..program.scratch_slots()],
                        &mut gradient_slots[..program.scratch_slots() * grad_dim],
                        grad_dim,
                    )
                })?;
            let coefficient = descriptor.coefficient as f64;
            for (accum_item, gradient_item) in total.iter_mut().zip(parameter_gradient.iter()) {
                *accum_item += (*gradient_item * descriptor.weighted_cache_sum * coefficient).re;
            }
        }
        Some(total)
    }
    /// Evaluates the expression value with every cached-integral mul node
    /// zeroed out, leaving only the residual (non-precomputed) contribution.
    fn evaluate_residual_value_ir(
        &self,
        state: &CachedIntegralCacheState,
        amplitude_values: &[Complex64],
    ) -> Complex64 {
        // Mark each descriptor's mul node so the evaluator forces it to zero.
        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
        for descriptor in &state.values {
            if descriptor.mul_node_index < zeroed_nodes.len() {
                zeroed_nodes[descriptor.mul_node_index] = true;
            }
        }
        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
        state.expression_ir.evaluate_into_with_zeroed_nodes(
            amplitude_values,
            &mut value_slots,
            &zeroed_nodes,
        )
    }
    /// Gradient counterpart of [`Self::evaluate_residual_value_ir`]: evaluates
    /// the expression gradient with every cached-integral mul node zeroed out.
    fn evaluate_residual_gradient_ir(
        &self,
        state: &CachedIntegralCacheState,
        amplitude_values: &[Complex64],
        amplitude_gradients: &[DVector<Complex64>],
        grad_dim: usize,
    ) -> DVector<Complex64> {
        // Mark each descriptor's mul node so the evaluator forces it to zero.
        let mut zeroed_nodes = vec![false; state.expression_ir.node_count()];
        for descriptor in &state.values {
            if descriptor.mul_node_index < zeroed_nodes.len() {
                zeroed_nodes[descriptor.mul_node_index] = true;
            }
        }
        let mut value_slots = vec![Complex64::ZERO; state.expression_ir.node_count()];
        let mut gradient_slots = vec![DVector::zeros(grad_dim); state.expression_ir.node_count()];
        state
            .expression_ir
            .evaluate_gradient_into_with_zeroed_nodes(
                amplitude_values,
                amplitude_gradients,
                &mut value_slots,
                &mut gradient_slots,
                &zeroed_nodes,
            )
    }
3796
3797 fn evaluate_weighted_value_sum_local_components(
3798 &self,
3799 parameters: &[f64],
3800 ) -> LadduResult<(f64, f64)> {
3801 let resources = self.resources.read();
3802 let parameters = resources.parameter_map.assemble(parameters)?;
3803 let amplitude_len = self.amplitudes.len();
3804 let state = self.ensure_cached_integral_cache_state(&resources)?;
3805 let lowered_artifacts = self.active_lowered_artifacts();
3806 let residual_value_slot_count = lowered_artifacts
3807 .as_ref()
3808 .and_then(|artifacts| {
3809 artifacts
3810 .residual_runtime
3811 .as_ref()
3812 .map(|runtime| runtime.value_program())
3813 .map(|program| program.scratch_slots())
3814 })
3815 .unwrap_or_else(|| self.expression_slot_count());
3816 let residual_value_program = lowered_artifacts
3817 .as_ref()
3818 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
3819 .map(|runtime| runtime.value_program());
3820 let cached_parameter_indices = &state.execution_sets.cached_parameter_amplitudes;
3821 let residual_active_indices = &state.execution_sets.residual_amplitudes;
3822 debug_assert!(cached_parameter_indices.iter().all(|&index| resources
3823 .active
3824 .get(index)
3825 .copied()
3826 .unwrap_or(false)));
3827 debug_assert!(residual_active_indices.iter().all(|&index| resources
3828 .active
3829 .get(index)
3830 .copied()
3831 .unwrap_or(false)));
3832 let cached_value_sum = {
3833 if let Some(cache) = resources.caches.first() {
3834 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3835 self.fill_amplitude_values(
3836 &mut amplitude_values,
3837 cached_parameter_indices,
3838 ¶meters,
3839 cache,
3840 );
3841 lowered_artifacts
3842 .as_ref()
3843 .and_then(|artifacts| {
3844 self.evaluate_cached_weighted_value_sum_lowered(
3845 &state,
3846 artifacts,
3847 &litude_values,
3848 )
3849 })
3850 .unwrap_or_else(|| {
3851 self.evaluate_cached_weighted_value_sum_ir(&state, &litude_values)
3852 })
3853 } else {
3854 0.0
3855 }
3856 };
3857
3858 #[cfg(feature = "rayon")]
3859 let residual_sum: f64 = {
3860 resources
3861 .caches
3862 .par_iter()
3863 .zip(self.dataset.events_local().par_iter())
3864 .map_init(
3865 || {
3866 (
3867 vec![Complex64::ZERO; amplitude_len],
3868 vec![Complex64::ZERO; residual_value_slot_count],
3869 )
3870 },
3871 |(amplitude_values, value_slots), (cache, event)| {
3872 self.fill_amplitude_values(
3873 amplitude_values,
3874 residual_active_indices,
3875 ¶meters,
3876 cache,
3877 );
3878 {
3879 let value = residual_value_program
3880 .as_ref()
3881 .map(|program| {
3882 program.evaluate_into(
3883 amplitude_values,
3884 &mut value_slots[..program.scratch_slots()],
3885 )
3886 })
3887 .unwrap_or_else(|| {
3888 self.evaluate_residual_value_ir(&state, amplitude_values)
3889 });
3890 event.weight * value.re
3891 }
3892 },
3893 )
3894 .sum()
3895 };
3896
3897 #[cfg(not(feature = "rayon"))]
3898 let residual_sum: f64 = {
3899 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
3900 let mut value_slots = vec![Complex64::ZERO; residual_value_slot_count];
3901 resources
3902 .caches
3903 .iter()
3904 .zip(self.dataset.events_local().iter())
3905 .map(|(cache, event)| {
3906 self.fill_amplitude_values(
3907 &mut amplitude_values,
3908 &residual_active_indices,
3909 ¶meters,
3910 cache,
3911 );
3912 {
3913 let value = residual_value_program
3914 .as_ref()
3915 .map(|program| {
3916 program.evaluate_into(
3917 &litude_values,
3918 &mut value_slots[..program.scratch_slots()],
3919 )
3920 })
3921 .unwrap_or_else(|| {
3922 self.evaluate_residual_value_ir(&state, &litude_values)
3923 });
3924 event.weight * value.re
3925 }
3926 })
3927 .sum()
3928 };
3929 Ok((residual_sum, cached_value_sum))
3930 }
3931
3932 pub fn evaluate_weighted_value_sum_local(&self, parameters: &[f64]) -> LadduResult<f64> {
3936 let (residual_sum, cached_value_sum) =
3937 self.evaluate_weighted_value_sum_local_components(parameters)?;
3938 Ok(residual_sum + cached_value_sum)
3939 }
3940
3941 #[cfg(feature = "mpi")]
3942 pub fn evaluate_weighted_value_sum_mpi(
3946 &self,
3947 parameters: &[f64],
3948 world: &SimpleCommunicator,
3949 ) -> LadduResult<f64> {
3950 let (residual_sum_local, cached_value_sum_local) =
3951 self.evaluate_weighted_value_sum_local_components(parameters)?;
3952 let mut residual_sum = 0.0;
3953 world.all_reduce_into(
3954 &residual_sum_local,
3955 &mut residual_sum,
3956 mpi::collective::SystemOperation::sum(),
3957 );
3958 let mut cached_value_sum = 0.0;
3959 world.all_reduce_into(
3960 &cached_value_sum_local,
3961 &mut cached_value_sum,
3962 mpi::collective::SystemOperation::sum(),
3963 );
3964 Ok(residual_sum + cached_value_sum)
3965 }
3966
3967 fn evaluate_weighted_gradient_sum_local_components(
3971 &self,
3972 parameters: &[f64],
3973 ) -> LadduResult<(DVector<f64>, DVector<f64>)> {
3974 let resources = self.resources.read();
3975 let parameters = resources.parameter_map.assemble(parameters)?;
3976 let amplitude_len = self.amplitudes.len();
3977 let grad_dim = parameters.len();
3978 let state = self.ensure_cached_integral_cache_state(&resources)?;
3979 let lowered_artifacts = self.active_lowered_artifacts();
3980 let active_index_set = resources.active_indices();
3981 let cached_parameter_indices = state
3982 .execution_sets
3983 .cached_parameter_amplitudes
3984 .iter()
3985 .copied()
3986 .filter(|index| active_index_set.binary_search(index).is_ok())
3987 .collect::<Vec<_>>();
3988 let residual_active_indices = state
3989 .execution_sets
3990 .residual_amplitudes
3991 .iter()
3992 .copied()
3993 .filter(|index| active_index_set.binary_search(index).is_ok())
3994 .collect::<Vec<_>>();
3995 let mut cached_parameter_mask = vec![false; amplitude_len];
3996 for &index in &cached_parameter_indices {
3997 cached_parameter_mask[index] = true;
3998 }
3999 let mut residual_active_mask = vec![false; amplitude_len];
4000 for &index in &residual_active_indices {
4001 residual_active_mask[index] = true;
4002 }
4003 let residual_gradient_program = lowered_artifacts
4004 .as_ref()
4005 .and_then(|artifacts| artifacts.residual_runtime.as_ref())
4006 .map(|runtime| runtime.gradient_program());
4007 let residual_gradient_slot_count = residual_gradient_program
4008 .as_ref()
4009 .map(|program| program.scratch_slots())
4010 .unwrap_or_else(|| state.expression_ir.node_count());
4011 let cached_term_sum = {
4012 if let Some(cache) = resources.caches.first() {
4013 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4014 self.fill_amplitude_values(
4015 &mut amplitude_values,
4016 &cached_parameter_indices,
4017 ¶meters,
4018 cache,
4019 );
4020 let mut amplitude_gradients = (0..amplitude_len)
4021 .map(|_| DVector::zeros(grad_dim))
4022 .collect::<Vec<_>>();
4023 self.fill_amplitude_gradients(
4024 &mut amplitude_gradients,
4025 &cached_parameter_mask,
4026 ¶meters,
4027 cache,
4028 );
4029 lowered_artifacts
4030 .as_ref()
4031 .and_then(|artifacts| {
4032 self.evaluate_cached_weighted_gradient_sum_lowered(
4033 &state,
4034 artifacts,
4035 &litude_values,
4036 &litude_gradients,
4037 grad_dim,
4038 )
4039 })
4040 .unwrap_or_else(|| {
4041 self.evaluate_cached_weighted_gradient_sum_ir(
4042 &state,
4043 &litude_values,
4044 &litude_gradients,
4045 grad_dim,
4046 )
4047 })
4048 } else {
4049 DVector::zeros(grad_dim)
4050 }
4051 };
4052
4053 #[cfg(feature = "rayon")]
4054 let residual_sum = {
4055 resources
4056 .caches
4057 .par_iter()
4058 .zip(self.dataset.events_local().par_iter())
4059 .map_init(
4060 || {
4061 (
4062 vec![Complex64::ZERO; amplitude_len],
4063 vec![DVector::zeros(grad_dim); amplitude_len],
4064 vec![Complex64::ZERO; residual_gradient_slot_count],
4065 vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim],
4066 )
4067 },
4068 |(amplitude_values, gradient_values, value_slots, gradient_slots),
4069 (cache, event)| {
4070 self.fill_amplitude_values_and_gradients(
4071 amplitude_values,
4072 gradient_values,
4073 &residual_active_indices,
4074 &residual_active_mask,
4075 ¶meters,
4076 cache,
4077 );
4078 let gradient = residual_gradient_program
4079 .as_ref()
4080 .map(|program| {
4081 program.evaluate_gradient_into_flat(
4082 amplitude_values,
4083 gradient_values,
4084 value_slots,
4085 gradient_slots,
4086 grad_dim,
4087 )
4088 })
4089 .unwrap_or_else(|| {
4090 self.evaluate_residual_gradient_ir(
4091 &state,
4092 amplitude_values,
4093 gradient_values,
4094 grad_dim,
4095 )
4096 });
4097 gradient.map(|value| value.re).scale(event.weight)
4098 },
4099 )
4100 .reduce(
4101 || DVector::zeros(grad_dim),
4102 |mut accum, value| {
4103 accum += value;
4104 accum
4105 },
4106 )
4107 };
4108
4109 #[cfg(not(feature = "rayon"))]
4110 let residual_sum = {
4111 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4112 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4113 let mut value_slots = vec![Complex64::ZERO; residual_gradient_slot_count];
4114 let mut gradient_slots = vec![Complex64::ZERO; residual_gradient_slot_count * grad_dim];
4115 resources
4116 .caches
4117 .iter()
4118 .zip(self.dataset.events_local().iter())
4119 .map(|(cache, event)| {
4120 self.fill_amplitude_values_and_gradients(
4121 &mut amplitude_values,
4122 &mut gradient_values,
4123 &residual_active_indices,
4124 &residual_active_mask,
4125 ¶meters,
4126 cache,
4127 );
4128 let gradient = residual_gradient_program
4129 .as_ref()
4130 .map(|program| {
4131 program.evaluate_gradient_into_flat(
4132 &litude_values,
4133 &gradient_values,
4134 &mut value_slots,
4135 &mut gradient_slots,
4136 grad_dim,
4137 )
4138 })
4139 .unwrap_or_else(|| {
4140 self.evaluate_residual_gradient_ir(
4141 &state,
4142 &litude_values,
4143 &gradient_values,
4144 grad_dim,
4145 )
4146 });
4147 gradient.map(|value| value.re).scale(event.weight)
4148 })
4149 .sum()
4150 };
4151 Ok((residual_sum, cached_term_sum))
4152 }
4153
4154 pub fn evaluate_weighted_gradient_sum_local(
4158 &self,
4159 parameters: &[f64],
4160 ) -> LadduResult<DVector<f64>> {
4161 let (residual_sum, cached_term_sum) =
4162 self.evaluate_weighted_gradient_sum_local_components(parameters)?;
4163 Ok(residual_sum + cached_term_sum)
4164 }
4165
    /// MPI variant of [`Self::evaluate_weighted_gradient_sum_local`]: computes
    /// the local residual and cached gradient components, all-reduces each
    /// element-wise across the communicator, and returns the total vector.
    #[cfg(feature = "mpi")]
    pub fn evaluate_weighted_gradient_sum_mpi(
        &self,
        parameters: &[f64],
        world: &SimpleCommunicator,
    ) -> LadduResult<DVector<f64>> {
        let (residual_sum_local, cached_term_sum_local) =
            self.evaluate_weighted_gradient_sum_local_components(parameters)?;
        // all_reduce_into works on plain slices, so round-trip through Vec<f64>.
        let mut residual_sum = vec![0.0; residual_sum_local.len()];
        world.all_reduce_into(
            residual_sum_local.as_slice(),
            &mut residual_sum,
            mpi::collective::SystemOperation::sum(),
        );
        let mut cached_term_sum = vec![0.0; cached_term_sum_local.len()];
        world.all_reduce_into(
            cached_term_sum_local.as_slice(),
            &mut cached_term_sum,
            mpi::collective::SystemOperation::sum(),
        );
        let mut total = DVector::from_vec(residual_sum);
        total += DVector::from_vec(cached_term_sum);
        Ok(total)
    }
4193
    /// Public wrapper over the runtime value evaluation with a caller-provided
    /// scratch buffer (size it with [`Self::expression_slot_count`]).
    pub fn evaluate_expression_value_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        scratch: &mut [Complex64],
    ) -> Complex64 {
        self.evaluate_expression_runtime_value_with_scratch(amplitude_values, scratch)
    }
4201
    /// Public wrapper over the runtime gradient evaluation with
    /// caller-provided value and gradient scratch buffers.
    pub fn evaluate_expression_gradient_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_scratch: &mut [Complex64],
        gradient_scratch: &mut [DVector<Complex64>],
    ) -> DVector<Complex64> {
        self.evaluate_expression_runtime_gradient_with_scratch(
            amplitude_values,
            gradient_values,
            value_scratch,
            gradient_scratch,
        )
    }
4216
    /// Public wrapper over the fused value-and-gradient evaluation with
    /// caller-provided scratch buffers.
    pub fn evaluate_expression_value_gradient_with_scratch(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
        value_scratch: &mut [Complex64],
        gradient_scratch: &mut [DVector<Complex64>],
    ) -> (Complex64, DVector<Complex64>) {
        self.evaluate_expression_runtime_value_gradient_with_scratch(
            amplitude_values,
            gradient_values,
            value_scratch,
            gradient_scratch,
        )
    }
4231
    /// Evaluates the expression value over `amplitude_values`, allocating
    /// scratch internally.
    pub fn evaluate_expression_value(&self, amplitude_values: &[Complex64]) -> Complex64 {
        self.evaluate_expression_runtime_value(amplitude_values)
    }
4235
    /// Evaluates the expression gradient over the given per-amplitude values
    /// and gradients, allocating scratch internally.
    pub fn evaluate_expression_gradient(
        &self,
        amplitude_values: &[Complex64],
        gradient_values: &[DVector<Complex64>],
    ) -> DVector<Complex64> {
        self.evaluate_expression_runtime_gradient(amplitude_values, gradient_values)
    }
4243
4244 pub fn parameters(&self) -> Vec<String> {
4247 self.resources.read().parameter_names()
4248 }
4249
4250 pub fn free_parameters(&self) -> Vec<String> {
4252 self.resources.read().free_parameter_names()
4253 }
4254
4255 pub fn fixed_parameters(&self) -> Vec<String> {
4257 self.resources.read().fixed_parameter_names()
4258 }
4259
4260 pub fn n_free(&self) -> usize {
4262 self.resources.read().n_free_parameters()
4263 }
4264
4265 pub fn n_fixed(&self) -> usize {
4267 self.resources.read().n_fixed_parameters()
4268 }
4269
4270 pub fn n_parameters(&self) -> usize {
4272 self.resources.read().n_parameters()
4273 }
4274
4275 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
4276 self.resources.read().fix_parameter(name, value)
4277 }
4278
4279 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
4280 self.resources.read().free_parameter(name)
4281 }
4282
4283 pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<()> {
4284 self.resources.write().rename_parameter(old, new)
4285 }
4286
4287 pub fn rename_parameters(&self, mapping: &HashMap<String, String>) -> LadduResult<()> {
4288 self.resources.write().rename_parameters(mapping)
4289 }
4290
4291 pub fn activate<T: AsRef<str>>(&self, name: T) {
4293 self.resources.write().activate(name);
4294 self.refresh_runtime_specializations();
4295 }
4296 pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
4298 self.resources.write().activate_strict(name)?;
4299 self.refresh_runtime_specializations();
4300 Ok(())
4301 }
4302
4303 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
4305 self.resources.write().activate_many(names);
4306 self.refresh_runtime_specializations();
4307 }
4308 pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
4310 self.resources.write().activate_many_strict(names)?;
4311 self.refresh_runtime_specializations();
4312 Ok(())
4313 }
4314
4315 pub fn activate_all(&self) {
4317 self.resources.write().activate_all();
4318 self.refresh_runtime_specializations();
4319 }
4320
4321 pub fn deactivate<T: AsRef<str>>(&self, name: T) {
4323 self.resources.write().deactivate(name);
4324 self.refresh_runtime_specializations();
4325 }
4326
4327 pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
4329 self.resources.write().deactivate_strict(name)?;
4330 self.refresh_runtime_specializations();
4331 Ok(())
4332 }
4333
4334 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
4336 self.resources.write().deactivate_many(names);
4337 self.refresh_runtime_specializations();
4338 }
4339 pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
4341 self.resources.write().deactivate_many_strict(names)?;
4342 self.refresh_runtime_specializations();
4343 Ok(())
4344 }
4345
4346 pub fn deactivate_all(&self) {
4348 self.resources.write().deactivate_all();
4349 self.refresh_runtime_specializations();
4350 }
4351
4352 pub fn isolate<T: AsRef<str>>(&self, name: T) {
4354 self.resources.write().isolate(name);
4355 self.refresh_runtime_specializations();
4356 }
4357
4358 pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
4360 self.resources.write().isolate_strict(name)?;
4361 self.refresh_runtime_specializations();
4362 Ok(())
4363 }
4364
4365 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
4367 self.resources.write().isolate_many(names);
4368 self.refresh_runtime_specializations();
4369 }
4370
4371 pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
4373 self.resources.write().isolate_many_strict(names)?;
4374 self.refresh_runtime_specializations();
4375 Ok(())
4376 }
4377
4378 pub fn active_mask(&self) -> Vec<bool> {
4380 self.resources.read().active.clone()
4381 }
4382
4383 pub fn set_active_mask(&self, mask: &[bool]) -> LadduResult<()> {
4385 let resources = {
4386 let mut resources = self.resources.write();
4387 if mask.len() != resources.active.len() {
4388 return Err(LadduError::LengthMismatch {
4389 context: "active amplitude mask".to_string(),
4390 expected: resources.active.len(),
4391 actual: mask.len(),
4392 });
4393 }
4394 resources.active.clone_from_slice(mask);
4395 resources.refresh_active_indices();
4396 resources.clone()
4397 };
4398 self.rebuild_runtime_specializations(&resources);
4399 Ok(())
4400 }
4401
4402 pub fn evaluate_local(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
4410 let resources = self.resources.read();
4411 let parameters = resources.parameter_map.assemble(parameters)?;
4412 let amplitude_len = self.amplitudes.len();
4413 let active_indices = resources.active_indices().to_vec();
4414 let slot_count = self.expression_value_slot_count();
4415 let program_snapshot = self.expression_value_program_snapshot();
4416 #[cfg(feature = "rayon")]
4417 {
4418 Ok(resources
4419 .caches
4420 .par_iter()
4421 .map_init(
4422 || {
4423 (
4424 vec![Complex64::ZERO; amplitude_len],
4425 vec![Complex64::ZERO; slot_count],
4426 )
4427 },
4428 |(amplitude_values, expr_slots), cache| {
4429 self.fill_amplitude_values(
4430 amplitude_values,
4431 &active_indices,
4432 ¶meters,
4433 cache,
4434 );
4435 self.evaluate_expression_value_with_program_snapshot(
4436 &program_snapshot,
4437 amplitude_values,
4438 expr_slots,
4439 )
4440 },
4441 )
4442 .collect())
4443 }
4444 #[cfg(not(feature = "rayon"))]
4445 {
4446 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4447 let mut expr_slots = vec![Complex64::ZERO; slot_count];
4448 Ok(resources
4449 .caches
4450 .iter()
4451 .map(|cache| {
4452 self.fill_amplitude_values(
4453 &mut amplitude_values,
4454 &active_indices,
4455 ¶meters,
4456 cache,
4457 );
4458 self.evaluate_expression_value_with_program_snapshot(
4459 &program_snapshot,
4460 &litude_values,
4461 &mut expr_slots,
4462 )
4463 })
4464 .collect())
4465 }
4466 }
4467
4468 pub fn evaluate_local_with_active_mask(
4470 &self,
4471 parameters: &[f64],
4472 active_mask: &[bool],
4473 ) -> LadduResult<Vec<Complex64>> {
4474 let resources = self.resources.read();
4475 if active_mask.len() != resources.active.len() {
4476 return Err(LadduError::LengthMismatch {
4477 context: "active amplitude mask".to_string(),
4478 expected: resources.active.len(),
4479 actual: active_mask.len(),
4480 });
4481 }
4482 let parameters = resources.parameter_map.assemble(parameters)?;
4483 let amplitude_len = self.amplitudes.len();
4484 let active_indices = active_mask
4485 .iter()
4486 .enumerate()
4487 .filter_map(|(index, &active)| if active { Some(index) } else { None })
4488 .collect::<Vec<_>>();
4489 let program_snapshot =
4490 self.expression_value_program_snapshot_for_active_mask(active_mask)?;
4491 let slot_count = self.expression_value_program_snapshot_slot_count(&program_snapshot);
4492 #[cfg(feature = "rayon")]
4493 {
4494 Ok(resources
4495 .caches
4496 .par_iter()
4497 .map_init(
4498 || {
4499 (
4500 vec![Complex64::ZERO; amplitude_len],
4501 vec![Complex64::ZERO; slot_count],
4502 )
4503 },
4504 |(amplitude_values, expr_slots), cache| {
4505 self.fill_amplitude_values(
4506 amplitude_values,
4507 &active_indices,
4508 ¶meters,
4509 cache,
4510 );
4511 self.evaluate_expression_value_with_program_snapshot(
4512 &program_snapshot,
4513 amplitude_values,
4514 expr_slots,
4515 )
4516 },
4517 )
4518 .collect())
4519 }
4520 #[cfg(not(feature = "rayon"))]
4521 {
4522 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4523 let mut expr_slots = vec![Complex64::ZERO; slot_count];
4524 Ok(resources
4525 .caches
4526 .iter()
4527 .map(|cache| {
4528 self.fill_amplitude_values(
4529 &mut amplitude_values,
4530 &active_indices,
4531 ¶meters,
4532 cache,
4533 );
4534 self.evaluate_expression_value_with_program_snapshot(
4535 &program_snapshot,
4536 &litude_values,
4537 &mut expr_slots,
4538 )
4539 })
4540 .collect())
4541 }
4542 }
4543
4544 #[cfg(feature = "execution-context-prototype")]
4546 pub fn evaluate_local_with_ctx(
4547 &self,
4548 parameters: &[f64],
4549 execution_context: &ExecutionContext,
4550 ) -> Vec<Complex64> {
4551 let resources = self.resources.read();
4552 let parameters = resources
4553 .parameter_map
4554 .assemble(parameters)
4555 .expect("parameter slice must match evaluator resources");
4556 let amplitude_len = self.amplitudes.len();
4557 let active_indices = resources.active_indices().to_vec();
4558 let slot_count = self.expression_value_slot_count();
4559 let program_snapshot = self.expression_value_program_snapshot();
4560 #[cfg(feature = "rayon")]
4561 {
4562 if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
4563 return execution_context.install(|| {
4564 resources
4565 .caches
4566 .par_iter()
4567 .map_init(
4568 || {
4569 (
4570 vec![Complex64::ZERO; amplitude_len],
4571 vec![Complex64::ZERO; slot_count],
4572 )
4573 },
4574 |(amplitude_values, expr_slots), cache| {
4575 self.fill_amplitude_values(
4576 amplitude_values,
4577 &active_indices,
4578 ¶meters,
4579 cache,
4580 );
4581 self.evaluate_expression_value_with_program_snapshot(
4582 &program_snapshot,
4583 amplitude_values,
4584 expr_slots,
4585 )
4586 },
4587 )
4588 .collect()
4589 });
4590 }
4591 }
4592 execution_context.with_scratch(|scratch| {
4593 let (amplitude_values, expr_slots) =
4594 scratch.reserve_value_workspaces(amplitude_len, slot_count);
4595 resources
4596 .caches
4597 .iter()
4598 .map(|cache| {
4599 self.fill_amplitude_values(
4600 amplitude_values,
4601 &active_indices,
4602 ¶meters,
4603 cache,
4604 );
4605 self.evaluate_expression_value_with_program_snapshot(
4606 &program_snapshot,
4607 amplitude_values,
4608 expr_slots,
4609 )
4610 })
4611 .collect()
4612 })
4613 }
4614
    /// Evaluates on this MPI rank's local events, then all-gathers the
    /// per-event results from every rank into a full-dataset buffer.
    #[cfg(feature = "mpi")]
    fn evaluate_mpi(
        &self,
        parameters: &[f64],
        world: &SimpleCommunicator,
    ) -> LadduResult<Vec<Complex64>> {
        let local_evaluation = self.evaluate_local(parameters)?;
        let n_events = self.dataset.n_events();
        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scope the partitioned view so `buffer` can be returned afterwards.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
        }
        Ok(buffer)
    }
4640
    /// MPI variant of [`Self::evaluate_local_with_ctx`]: evaluates local events
    /// with the given execution context and all-gathers results across ranks.
    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
    fn evaluate_mpi_with_ctx(
        &self,
        parameters: &[f64],
        world: &SimpleCommunicator,
        execution_context: &ExecutionContext,
    ) -> Vec<Complex64> {
        let local_evaluation = self.evaluate_local_with_ctx(parameters, execution_context);
        let n_events = self.dataset.n_events();
        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
        let (counts, displs) = world.get_counts_displs(n_events);
        {
            // Scope the partitioned view so `buffer` can be returned afterwards.
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
        }
        buffer
    }
4660
    /// Evaluates the expression over the dataset, dispatching to the MPI path
    /// when an MPI world is available and to the local path otherwise.
    pub fn evaluate(&self, parameters: &[f64]) -> LadduResult<Vec<Complex64>> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.evaluate_mpi(parameters, &world);
            }
        }
        self.evaluate_local(parameters)
    }
4672
    /// Context-aware variant of [`Self::evaluate`]: dispatches to the MPI path
    /// when available, otherwise evaluates locally with the given context.
    #[cfg(feature = "execution-context-prototype")]
    pub fn evaluate_with_ctx(
        &self,
        parameters: &[f64],
        execution_context: &ExecutionContext,
    ) -> Vec<Complex64> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.evaluate_mpi_with_ctx(parameters, &world, execution_context);
            }
        }
        self.evaluate_local_with_ctx(parameters, execution_context)
    }
4692
4693 pub fn evaluate_batch_local(
4696 &self,
4697 parameters: &[f64],
4698 indices: &[usize],
4699 ) -> LadduResult<Vec<Complex64>> {
4700 let resources = self.resources.read();
4701 let parameters = resources.parameter_map.assemble(parameters)?;
4702 let amplitude_len = self.amplitudes.len();
4703 let active_indices = resources.active_indices().to_vec();
4704 let slot_count = self.expression_value_slot_count();
4705 let program_snapshot = self.expression_value_program_snapshot();
4706 #[cfg(feature = "rayon")]
4707 {
4708 Ok(indices
4709 .par_iter()
4710 .map_init(
4711 || {
4712 (
4713 vec![Complex64::ZERO; amplitude_len],
4714 vec![Complex64::ZERO; slot_count],
4715 )
4716 },
4717 |(amplitude_values, expr_slots), &idx| {
4718 let cache = &resources.caches[idx];
4719 self.fill_amplitude_values(
4720 amplitude_values,
4721 &active_indices,
4722 ¶meters,
4723 cache,
4724 );
4725 self.evaluate_expression_value_with_program_snapshot(
4726 &program_snapshot,
4727 amplitude_values,
4728 expr_slots,
4729 )
4730 },
4731 )
4732 .collect())
4733 }
4734 #[cfg(not(feature = "rayon"))]
4735 {
4736 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4737 let mut expr_slots = vec![Complex64::ZERO; slot_count];
4738 Ok(indices
4739 .iter()
4740 .map(|&idx| {
4741 let cache = &resources.caches[idx];
4742 self.fill_amplitude_values(
4743 &mut amplitude_values,
4744 &active_indices,
4745 ¶meters,
4746 cache,
4747 );
4748 self.evaluate_expression_value_with_program_snapshot(
4749 &program_snapshot,
4750 &litude_values,
4751 &mut expr_slots,
4752 )
4753 })
4754 .collect())
4755 }
4756 }
4757
    /// Batch evaluation under MPI: maps global indices to rank-local ones,
    /// evaluates locally, then gathers the batched results across ranks.
    #[cfg(feature = "mpi")]
    fn evaluate_batch_mpi(
        &self,
        parameters: &[f64],
        indices: &[usize],
        world: &SimpleCommunicator,
    ) -> LadduResult<Vec<Complex64>> {
        let total = self.dataset.n_events();
        let locals = world.locals_from_globals(indices, total);
        let local_evaluation = self.evaluate_batch_local(parameters, &locals)?;
        Ok(world.all_gather_batched_partitioned(&local_evaluation, indices, total, None))
    }
4772
    /// Batch evaluation entry point: dispatches to the MPI path when an MPI
    /// world is available, otherwise evaluates the batch locally.
    pub fn evaluate_batch(
        &self,
        parameters: &[f64],
        indices: &[usize],
    ) -> LadduResult<Vec<Complex64>> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.evaluate_batch_mpi(parameters, indices, &world);
            }
        }
        self.evaluate_batch_local(parameters, indices)
    }
4788
4789 pub fn evaluate_gradient_local(
4797 &self,
4798 parameters: &[f64],
4799 ) -> LadduResult<Vec<DVector<Complex64>>> {
4800 let resources = self.resources.read();
4801 let parameters = resources.parameter_map.assemble(parameters)?;
4802 let amplitude_len = self.amplitudes.len();
4803 let grad_dim = parameters.len();
4804 let active_indices = resources.active_indices().to_vec();
4805 let lowered_runtime = self.lowered_runtime();
4806 let gradient_program = lowered_runtime.gradient_program();
4807 let slot_count = self.expression_gradient_slot_count();
4808 #[cfg(feature = "rayon")]
4809 {
4810 Ok(resources
4811 .caches
4812 .par_iter()
4813 .map_init(
4814 || {
4815 (
4816 vec![Complex64::ZERO; amplitude_len],
4817 vec![DVector::zeros(grad_dim); amplitude_len],
4818 vec![Complex64::ZERO; slot_count],
4819 vec![Complex64::ZERO; slot_count * grad_dim],
4820 )
4821 },
4822 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
4823 self.fill_amplitude_values_and_gradients(
4824 amplitude_values,
4825 gradient_values,
4826 &active_indices,
4827 &resources.active,
4828 ¶meters,
4829 cache,
4830 );
4831 gradient_program.evaluate_gradient_into_flat(
4832 amplitude_values,
4833 gradient_values,
4834 value_slots,
4835 gradient_slots,
4836 grad_dim,
4837 )
4838 },
4839 )
4840 .collect())
4841 }
4842 #[cfg(not(feature = "rayon"))]
4843 {
4844 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
4845 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
4846 let mut value_slots = vec![Complex64::ZERO; slot_count];
4847 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
4848 Ok(resources
4849 .caches
4850 .iter()
4851 .map(|cache| {
4852 self.fill_amplitude_values_and_gradients(
4853 &mut amplitude_values,
4854 &mut gradient_values,
4855 &active_indices,
4856 &resources.active,
4857 ¶meters,
4858 cache,
4859 );
4860 gradient_program.evaluate_gradient_into_flat(
4861 &litude_values,
4862 &gradient_values,
4863 &mut value_slots,
4864 &mut gradient_slots,
4865 grad_dim,
4866 )
4867 })
4868 .collect())
4869 }
4870 }
4871
4872 #[cfg(feature = "execution-context-prototype")]
4874 pub fn evaluate_gradient_local_with_ctx(
4875 &self,
4876 parameters: &[f64],
4877 execution_context: &ExecutionContext,
4878 ) -> Vec<DVector<Complex64>> {
4879 let resources = self.resources.read();
4880 let parameters = resources
4881 .parameter_map
4882 .assemble(parameters)
4883 .expect("parameter slice must match evaluator resources");
4884 let amplitude_len = self.amplitudes.len();
4885 let grad_dim = parameters.len();
4886 let active_indices = resources.active_indices().to_vec();
4887 let slot_count = self.expression_slot_count();
4888 #[cfg(feature = "rayon")]
4889 {
4890 if !matches!(execution_context.thread_policy(), ThreadPolicy::Single) {
4891 return execution_context.install(|| {
4892 resources
4893 .caches
4894 .par_iter()
4895 .map_init(
4896 || {
4897 (
4898 vec![Complex64::ZERO; amplitude_len],
4899 vec![DVector::zeros(grad_dim); amplitude_len],
4900 vec![Complex64::ZERO; slot_count],
4901 vec![DVector::zeros(grad_dim); slot_count],
4902 )
4903 },
4904 |(amplitude_values, gradient_values, value_slots, gradient_slots),
4905 cache| {
4906 self.evaluate_cache_gradient_with_scratch(
4907 amplitude_values,
4908 gradient_values,
4909 value_slots,
4910 gradient_slots,
4911 &active_indices,
4912 &resources.active,
4913 ¶meters,
4914 cache,
4915 )
4916 },
4917 )
4918 .collect()
4919 });
4920 }
4921 }
4922 execution_context.with_scratch(|scratch| {
4923 let (amplitude_values, value_slots, gradient_values, gradient_slots) =
4924 scratch.reserve_gradient_workspaces(amplitude_len, slot_count, grad_dim);
4925 resources
4926 .caches
4927 .iter()
4928 .map(|cache| {
4929 self.evaluate_cache_gradient_with_scratch(
4930 amplitude_values,
4931 gradient_values,
4932 value_slots,
4933 gradient_slots,
4934 &active_indices,
4935 &resources.active,
4936 ¶meters,
4937 cache,
4938 )
4939 })
4940 .collect()
4941 })
4942 }
4943
    /// Gradient evaluation under MPI: computes local gradients, flattens them,
    /// all-gathers across ranks, and rebuilds one `DVector` per event.
    #[cfg(feature = "mpi")]
    fn evaluate_gradient_mpi(
        &self,
        parameters: &[f64],
        world: &SimpleCommunicator,
    ) -> LadduResult<Vec<DVector<Complex64>>> {
        let local_evaluation = self.evaluate_gradient_local(parameters)?;
        let n_events = self.dataset.n_events();
        // Flattened layout: one chunk of `parameters.len()` entries per event.
        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(
                &local_evaluation
                    .iter()
                    .flat_map(|v| v.data.as_vec())
                    .copied()
                    .collect::<Vec<_>>(),
                &mut partitioned_buffer,
            );
        }
        Ok(buffer
            .chunks(parameters.len())
            .map(DVector::from_row_slice)
            .collect())
    }
4979
    /// MPI + execution-context variant of gradient evaluation: computes local
    /// gradients with the given context, then gathers flattened gradients
    /// across ranks and rebuilds per-event `DVector`s.
    #[cfg(all(feature = "mpi", feature = "execution-context-prototype"))]
    fn evaluate_gradient_mpi_with_ctx(
        &self,
        parameters: &[f64],
        world: &SimpleCommunicator,
        execution_context: &ExecutionContext,
    ) -> Vec<DVector<Complex64>> {
        let local_evaluation = self.evaluate_gradient_local_with_ctx(parameters, execution_context);
        let n_events = self.dataset.n_events();
        // Flattened layout: one chunk of `parameters.len()` entries per event.
        let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
        let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
        {
            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
            world.all_gather_varcount_into(
                &local_evaluation
                    .iter()
                    .flat_map(|v| v.data.as_vec())
                    .copied()
                    .collect::<Vec<_>>(),
                &mut partitioned_buffer,
            );
        }
        buffer
            .chunks(parameters.len())
            .map(DVector::from_row_slice)
            .collect()
    }
5009
    /// Gradient evaluation entry point: dispatches to the MPI path when an
    /// MPI world is available, otherwise evaluates gradients locally.
    pub fn evaluate_gradient(&self, parameters: &[f64]) -> LadduResult<Vec<DVector<Complex64>>> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.evaluate_gradient_mpi(parameters, &world);
            }
        }
        self.evaluate_gradient_local(parameters)
    }
5021
    /// Context-aware gradient entry point: dispatches to the MPI path when
    /// available, otherwise computes gradients locally with the given context.
    #[cfg(feature = "execution-context-prototype")]
    pub fn evaluate_gradient_with_ctx(
        &self,
        parameters: &[f64],
        execution_context: &ExecutionContext,
    ) -> Vec<DVector<Complex64>> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.evaluate_gradient_mpi_with_ctx(parameters, &world, execution_context);
            }
        }
        self.evaluate_gradient_local_with_ctx(parameters, execution_context)
    }
5041
5042 pub fn evaluate_gradient_batch_local(
5045 &self,
5046 parameters: &[f64],
5047 indices: &[usize],
5048 ) -> LadduResult<Vec<DVector<Complex64>>> {
5049 let resources = self.resources.read();
5050 let parameters = resources.parameter_map.assemble(parameters)?;
5051 let amplitude_len = self.amplitudes.len();
5052 let grad_dim = parameters.len();
5053 let active_indices = resources.active_indices().to_vec();
5054 let lowered_runtime = self.lowered_runtime();
5055 let gradient_program = lowered_runtime.gradient_program();
5056 let slot_count = self.expression_gradient_slot_count();
5057 #[cfg(feature = "rayon")]
5058 {
5059 Ok(indices
5060 .par_iter()
5061 .map_init(
5062 || {
5063 (
5064 vec![Complex64::ZERO; amplitude_len],
5065 vec![DVector::zeros(grad_dim); amplitude_len],
5066 vec![Complex64::ZERO; slot_count],
5067 vec![Complex64::ZERO; slot_count * grad_dim],
5068 )
5069 },
5070 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
5071 let cache = &resources.caches[idx];
5072 self.fill_amplitude_values_and_gradients(
5073 amplitude_values,
5074 gradient_values,
5075 &active_indices,
5076 &resources.active,
5077 ¶meters,
5078 cache,
5079 );
5080 gradient_program.evaluate_gradient_into_flat(
5081 amplitude_values,
5082 gradient_values,
5083 value_slots,
5084 gradient_slots,
5085 grad_dim,
5086 )
5087 },
5088 )
5089 .collect())
5090 }
5091 #[cfg(not(feature = "rayon"))]
5092 {
5093 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5094 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5095 let mut value_slots = vec![Complex64::ZERO; slot_count];
5096 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5097 Ok(indices
5098 .iter()
5099 .map(|&idx| {
5100 let cache = &resources.caches[idx];
5101 self.fill_amplitude_values_and_gradients(
5102 &mut amplitude_values,
5103 &mut gradient_values,
5104 &active_indices,
5105 &resources.active,
5106 ¶meters,
5107 cache,
5108 );
5109 gradient_program.evaluate_gradient_into_flat(
5110 &litude_values,
5111 &gradient_values,
5112 &mut value_slots,
5113 &mut gradient_slots,
5114 grad_dim,
5115 )
5116 })
5117 .collect())
5118 }
5119 }
5120
    /// Batched gradient evaluation under MPI: maps global indices to local
    /// ones, flattens local gradients, gathers across ranks, and rebuilds
    /// one `DVector` per requested event.
    #[cfg(feature = "mpi")]
    fn evaluate_gradient_batch_mpi(
        &self,
        parameters: &[f64],
        indices: &[usize],
        world: &SimpleCommunicator,
    ) -> LadduResult<Vec<DVector<Complex64>>> {
        let total = self.dataset.n_events();
        let locals = world.locals_from_globals(indices, total);
        let flattened_local_evaluation = self
            .evaluate_gradient_batch_local(parameters, &locals)?
            .iter()
            .flat_map(|g| g.data.as_vec().to_vec())
            .collect::<Vec<Complex64>>();
        Ok(world
            .all_gather_batched_partitioned(
                &flattened_local_evaluation,
                indices,
                total,
                Some(parameters.len()),
            )
            .chunks(parameters.len())
            .map(DVector::from_row_slice)
            .collect())
    }
5148
    /// Batched gradient entry point: dispatches to the MPI path when an MPI
    /// world is available, otherwise evaluates the batch locally.
    pub fn evaluate_gradient_batch(
        &self,
        parameters: &[f64],
        indices: &[usize],
    ) -> LadduResult<Vec<DVector<Complex64>>> {
        #[cfg(feature = "mpi")]
        {
            if let Some(world) = crate::mpi::get_world() {
                return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
            }
        }
        self.evaluate_gradient_batch_local(parameters, indices)
    }
5165
5166 pub fn evaluate_with_gradient_local(
5168 &self,
5169 parameters: &[f64],
5170 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
5171 let resources = self.resources.read();
5172 let parameters = resources.parameter_map.assemble(parameters)?;
5173 let amplitude_len = self.amplitudes.len();
5174 let grad_dim = parameters.len();
5175 let active_indices = resources.active_indices().to_vec();
5176 let lowered_runtime = self.lowered_runtime();
5177 let value_gradient_program = lowered_runtime.value_gradient_program();
5178 let slot_count = self.expression_value_gradient_slot_count();
5179 #[cfg(feature = "rayon")]
5180 {
5181 Ok(resources
5182 .caches
5183 .par_iter()
5184 .map_init(
5185 || {
5186 (
5187 vec![Complex64::ZERO; amplitude_len],
5188 vec![DVector::zeros(grad_dim); amplitude_len],
5189 vec![Complex64::ZERO; slot_count],
5190 vec![Complex64::ZERO; slot_count * grad_dim],
5191 )
5192 },
5193 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
5194 self.fill_amplitude_values_and_gradients(
5195 amplitude_values,
5196 gradient_values,
5197 &active_indices,
5198 &resources.active,
5199 ¶meters,
5200 cache,
5201 );
5202 value_gradient_program.evaluate_value_gradient_into_flat(
5203 amplitude_values,
5204 gradient_values,
5205 value_slots,
5206 gradient_slots,
5207 grad_dim,
5208 )
5209 },
5210 )
5211 .collect())
5212 }
5213 #[cfg(not(feature = "rayon"))]
5214 {
5215 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5216 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5217 let mut value_slots = vec![Complex64::ZERO; slot_count];
5218 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5219 Ok(resources
5220 .caches
5221 .iter()
5222 .map(|cache| {
5223 self.fill_amplitude_values_and_gradients(
5224 &mut amplitude_values,
5225 &mut gradient_values,
5226 &active_indices,
5227 &resources.active,
5228 ¶meters,
5229 cache,
5230 );
5231 value_gradient_program.evaluate_value_gradient_into_flat(
5232 &litude_values,
5233 &gradient_values,
5234 &mut value_slots,
5235 &mut gradient_slots,
5236 grad_dim,
5237 )
5238 })
5239 .collect())
5240 }
5241 }
5242
5243 pub fn evaluate_with_gradient_local_with_active_mask(
5245 &self,
5246 parameters: &[f64],
5247 active_mask: &[bool],
5248 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
5249 let resources = self.resources.read();
5250 if active_mask.len() != resources.active.len() {
5251 return Err(LadduError::LengthMismatch {
5252 context: "active amplitude mask".to_string(),
5253 expected: resources.active.len(),
5254 actual: active_mask.len(),
5255 });
5256 }
5257 let parameters = resources.parameter_map.assemble(parameters)?;
5258 let amplitude_len = self.amplitudes.len();
5259 let grad_dim = parameters.len();
5260 let active_indices = active_mask
5261 .iter()
5262 .enumerate()
5263 .filter_map(|(index, &active)| if active { Some(index) } else { None })
5264 .collect::<Vec<_>>();
5265 let lowered_runtime = self.lower_expression_runtime_for_active_mask(active_mask)?;
5266 let slot_count = lowered_runtime.value_gradient_program().scratch_slots();
5267 #[cfg(feature = "rayon")]
5268 {
5269 Ok(resources
5270 .caches
5271 .par_iter()
5272 .map_init(
5273 || {
5274 (
5275 vec![Complex64::ZERO; amplitude_len],
5276 vec![DVector::zeros(grad_dim); amplitude_len],
5277 vec![Complex64::ZERO; slot_count],
5278 vec![Complex64::ZERO; slot_count * grad_dim],
5279 )
5280 },
5281 |(amplitude_values, gradient_values, value_slots, gradient_slots), cache| {
5282 self.fill_amplitude_values_and_gradients(
5283 amplitude_values,
5284 gradient_values,
5285 &active_indices,
5286 active_mask,
5287 ¶meters,
5288 cache,
5289 );
5290 lowered_runtime
5291 .value_gradient_program()
5292 .evaluate_value_gradient_into_flat(
5293 amplitude_values,
5294 gradient_values,
5295 value_slots,
5296 gradient_slots,
5297 grad_dim,
5298 )
5299 },
5300 )
5301 .collect())
5302 }
5303 #[cfg(not(feature = "rayon"))]
5304 {
5305 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5306 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5307 let mut value_slots = vec![Complex64::ZERO; slot_count];
5308 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5309 Ok(resources
5310 .caches
5311 .iter()
5312 .map(|cache| {
5313 self.fill_amplitude_values_and_gradients(
5314 &mut amplitude_values,
5315 &mut gradient_values,
5316 &active_indices,
5317 active_mask,
5318 ¶meters,
5319 cache,
5320 );
5321 lowered_runtime
5322 .value_gradient_program()
5323 .evaluate_value_gradient_into_flat(
5324 &litude_values,
5325 &gradient_values,
5326 &mut value_slots,
5327 &mut gradient_slots,
5328 grad_dim,
5329 )
5330 })
5331 .collect())
5332 }
5333 }
5334
5335 pub fn evaluate_with_gradient_batch_local(
5337 &self,
5338 parameters: &[f64],
5339 indices: &[usize],
5340 ) -> LadduResult<Vec<(Complex64, DVector<Complex64>)>> {
5341 let resources = self.resources.read();
5342 let parameters = resources.parameter_map.assemble(parameters)?;
5343 let amplitude_len = self.amplitudes.len();
5344 let grad_dim = parameters.len();
5345 let active_indices = resources.active_indices().to_vec();
5346 let lowered_runtime = self.lowered_runtime();
5347 let value_gradient_program = lowered_runtime.value_gradient_program();
5348 let slot_count = self.expression_value_gradient_slot_count();
5349 #[cfg(feature = "rayon")]
5350 {
5351 Ok(indices
5352 .par_iter()
5353 .map_init(
5354 || {
5355 (
5356 vec![Complex64::ZERO; amplitude_len],
5357 vec![DVector::zeros(grad_dim); amplitude_len],
5358 vec![Complex64::ZERO; slot_count],
5359 vec![Complex64::ZERO; slot_count * grad_dim],
5360 )
5361 },
5362 |(amplitude_values, gradient_values, value_slots, gradient_slots), &idx| {
5363 let cache = &resources.caches[idx];
5364 self.fill_amplitude_values_and_gradients(
5365 amplitude_values,
5366 gradient_values,
5367 &active_indices,
5368 &resources.active,
5369 ¶meters,
5370 cache,
5371 );
5372 value_gradient_program.evaluate_value_gradient_into_flat(
5373 amplitude_values,
5374 gradient_values,
5375 value_slots,
5376 gradient_slots,
5377 grad_dim,
5378 )
5379 },
5380 )
5381 .collect())
5382 }
5383 #[cfg(not(feature = "rayon"))]
5384 {
5385 let mut amplitude_values = vec![Complex64::ZERO; amplitude_len];
5386 let mut gradient_values = vec![DVector::zeros(grad_dim); amplitude_len];
5387 let mut value_slots = vec![Complex64::ZERO; slot_count];
5388 let mut gradient_slots = vec![Complex64::ZERO; slot_count * grad_dim];
5389 Ok(indices
5390 .iter()
5391 .map(|&idx| {
5392 let cache = &resources.caches[idx];
5393 self.fill_amplitude_values_and_gradients(
5394 &mut amplitude_values,
5395 &mut gradient_values,
5396 &active_indices,
5397 &resources.active,
5398 ¶meters,
5399 cache,
5400 );
5401 value_gradient_program.evaluate_value_gradient_into_flat(
5402 &litude_values,
5403 &gradient_values,
5404 &mut value_slots,
5405 &mut gradient_slots,
5406 grad_dim,
5407 )
5408 })
5409 .collect())
5410 }
5411 }
5412}
5413
/// Test amplitude: a free complex constant `re + i*im` scaled by the beam
/// energy cached per event during `precompute`.
#[derive(Clone, Serialize, Deserialize)]
pub struct TestAmplitude {
    name: String,
    re: Parameter,
    // Resolved parameter slot for `re`; assigned during `register`.
    pid_re: ParameterID,
    im: Parameter,
    // Resolved parameter slot for `im`; assigned during `register`.
    pid_im: ParameterID,
    // Cached scalar slot holding the beam energy (filled in `precompute`).
    beam_energy: crate::ScalarID,
}
5424
5425impl TestAmplitude {
5426 #[allow(clippy::new_ret_no_self)]
5428 pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
5429 Self {
5430 name: name.to_string(),
5431 re,
5432 pid_re: Default::default(),
5433 im,
5434 pid_im: Default::default(),
5435 beam_energy: Default::default(),
5436 }
5437 .into_expression()
5438 }
5439}
5440
#[typetag::serde]
impl Amplitude for TestAmplitude {
    /// Registers both free parameters plus a named cached scalar for the beam
    /// energy, then registers this amplitude under its name.
    fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
        self.pid_re = resources.register_parameter(&self.re)?;
        self.pid_im = resources.register_parameter(&self.im)?;
        self.beam_energy = resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
        resources.register_amplitude(&self.name)
    }

    /// Caches the energy of the first four-momentum (the beam) for this event.
    fn precompute(&self, event: &NamedEventView<'_>, cache: &mut Cache) {
        let beam = event.p4_at(0);
        cache.store_scalar(self.beam_energy, beam.e());
    }

    /// Returns `(re + i*im) * beam_energy`.
    fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64 {
        Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
            * cache.get_scalar(self.beam_energy)
    }

    /// Analytic gradient: `d/d(re) = beam_energy`, `d/d(im) = i * beam_energy`.
    /// Only parameters with a `free_index` receive gradient entries.
    fn compute_gradient(
        &self,
        parameters: &Parameters,
        cache: &Cache,
        gradient: &mut DVector<Complex64>,
    ) {
        let beam_energy = cache.get_scalar(self.beam_energy);
        if let Some(ind) = parameters.free_index(self.pid_re) {
            gradient[ind] = Complex64::ONE * beam_energy;
        }
        if let Some(ind) = parameters.free_index(self.pid_im) {
            gradient[ind] = Complex64::I * beam_energy;
        }
    }
}
5475
5476#[cfg(test)]
5477mod tests {
5478 use crate::data::{test_dataset, test_event, DatasetMetadata, EventData};
5479
5480 use super::*;
5481 use crate::resources::{Cache, ParameterID, Parameters, Resources, ScalarID};
5482 use crate::utils::vectors::Vec4;
5483 use approx::assert_relative_eq;
5484 #[cfg(feature = "mpi")]
5485 use mpi_test::mpi_test;
5486 use serde::{Deserialize, Serialize};
5487
    /// Test amplitude returning the free complex constant `re + i*im`,
    /// independent of the event.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: Parameter,
        // Resolved parameter slot for `re`; assigned during `register`.
        pid_re: ParameterID,
        im: Parameter,
        // Resolved parameter slot for `im`; assigned during `register`.
        pid_im: ParameterID,
    }
5496
5497 impl ComplexScalar {
5498 #[allow(clippy::new_ret_no_self)]
5499 pub fn new(name: &str, re: Parameter, im: Parameter) -> LadduResult<Expression> {
5500 Self {
5501 name: name.to_string(),
5502 re,
5503 pid_re: Default::default(),
5504 im,
5505 pid_im: Default::default(),
5506 }
5507 .into_expression()
5508 }
5509 }
5510
    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        /// Registers both free parameters, then the amplitude itself.
        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
            self.pid_re = resources.register_parameter(&self.re)?;
            self.pid_im = resources.register_parameter(&self.im)?;
            resources.register_amplitude(&self.name)
        }

        /// Returns `re + i*im`; the cache is unused.
        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
            Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
        }

        /// Analytic gradient: `d/d(re) = 1`, `d/d(im) = i`. Only parameters
        /// with a `free_index` receive gradient entries.
        fn compute_gradient(
            &self,
            parameters: &Parameters,
            _cache: &Cache,
            gradient: &mut DVector<Complex64>,
        ) {
            if let Some(ind) = parameters.free_index(self.pid_re) {
                gradient[ind] = Complex64::ONE;
            }
            if let Some(ind) = parameters.free_index(self.pid_im) {
                gradient[ind] = Complex64::I;
            }
        }
    }
5537
    /// Test amplitude whose value is a single real free parameter; used to
    /// exercise the `ParameterOnly` dependence class.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ParameterOnlyScalar {
        name: String,
        value: Parameter,
        // Resolved parameter slot for `value`; assigned during `register`.
        pid: ParameterID,
    }
5544
5545 impl ParameterOnlyScalar {
5546 #[allow(clippy::new_ret_no_self)]
5547 pub fn new(name: &str, value: Parameter) -> LadduResult<Expression> {
5548 Self {
5549 name: name.to_string(),
5550 value,
5551 pid: Default::default(),
5552 }
5553 .into_expression()
5554 }
5555 }
5556
    #[typetag::serde]
    impl Amplitude for ParameterOnlyScalar {
        /// Registers the single free parameter, then the amplitude itself.
        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
            self.pid = resources.register_parameter(&self.value)?;
            resources.register_amplitude(&self.name)
        }

        /// Hints that this amplitude depends only on parameters, not caches.
        fn dependence_hint(&self) -> ExpressionDependence {
            ExpressionDependence::ParameterOnly
        }

        /// The value is always real (imaginary part is zero below).
        fn real_valued_hint(&self) -> bool {
            true
        }

        /// Returns the parameter value as a real complex number.
        fn compute(&self, parameters: &Parameters, _cache: &Cache) -> Complex64 {
            Complex64::new(parameters.get(self.pid), 0.0)
        }
    }
5576
    /// Test amplitude whose value is the cached per-event beam energy; used to
    /// exercise the `CacheOnly` dependence class.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct CacheOnlyScalar {
        name: String,
        // Cached scalar slot holding the beam energy (filled in `precompute`).
        beam_energy: ScalarID,
    }
5582
5583 impl CacheOnlyScalar {
5584 #[allow(clippy::new_ret_no_self)]
5585 pub fn new(name: &str) -> LadduResult<Expression> {
5586 Self {
5587 name: name.to_string(),
5588 beam_energy: Default::default(),
5589 }
5590 .into_expression()
5591 }
5592 }
5593
    #[typetag::serde]
    impl Amplitude for CacheOnlyScalar {
        /// Registers a named cached scalar for the beam energy, then the
        /// amplitude itself.
        fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
            self.beam_energy =
                resources.register_scalar(Some(&format!("{}.beam_energy", self.name)));
            resources.register_amplitude(&self.name)
        }

        /// Hints that this amplitude depends only on cached values, not
        /// parameters.
        fn dependence_hint(&self) -> ExpressionDependence {
            ExpressionDependence::CacheOnly
        }

        /// The value is always real (imaginary part is zero below).
        fn real_valued_hint(&self) -> bool {
            true
        }

        /// Caches the energy of the first four-momentum (the beam).
        fn precompute(&self, event: &NamedEventView<'_>, cache: &mut Cache) {
            cache.store_scalar(self.beam_energy, event.p4_at(0).e());
        }

        /// Returns the cached beam energy as a real complex number.
        fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
            Complex64::new(cache.get_scalar(self.beam_energy), 0.0)
        }
    }
5618
    /// Selects which expression shape a deterministic fixture builds
    /// (see `make_deterministic_fixture`).
    #[derive(Clone, Copy)]
    enum DeterministicFixtureKind {
        // (p1 * c1) + (p2 * c2): parameter-only times cache-only products.
        Separable,
        // (p * c) + m: one separable product plus a mixed-dependence term.
        Partial,
        // Built from mixed-dependence TestAmplitudes only.
        NonSeparable,
    }
5625
    /// Bundles an expression, its dataset, and a concrete parameter vector for
    /// deterministic test evaluation.
    struct DeterministicFixture {
        expression: Expression,
        dataset: Arc<Dataset>,
        // Values for the fixture's free parameters, in registration order.
        parameters: Vec<f64>,
    }
5631
    // Absolute and relative tolerances used when comparing deterministic
    // fixture results against reference values.
    const DETERMINISTIC_STRICT_ABS_TOL: f64 = 1e-12;
    const DETERMINISTIC_STRICT_REL_TOL: f64 = 1e-10;
5634
5635 fn deterministic_fixture_dataset() -> Arc<Dataset> {
5636 let metadata = Arc::new(DatasetMetadata::default());
5637 let events = vec![
5638 Arc::new(EventData {
5639 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
5640 aux: vec![],
5641 weight: 0.5,
5642 }),
5643 Arc::new(EventData {
5644 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
5645 aux: vec![],
5646 weight: -1.25,
5647 }),
5648 Arc::new(EventData {
5649 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
5650 aux: vec![],
5651 weight: 2.0,
5652 }),
5653 Arc::new(EventData {
5654 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
5655 aux: vec![],
5656 weight: 0.75,
5657 }),
5658 ];
5659 Arc::new(Dataset::new_with_metadata(events, metadata))
5660 }
5661
    /// Builds the expression, dataset, and starting parameters for the
    /// requested [`DeterministicFixtureKind`].
    fn make_deterministic_fixture(kind: DeterministicFixtureKind) -> DeterministicFixture {
        let dataset = deterministic_fixture_dataset();
        match kind {
            DeterministicFixtureKind::Separable => {
                // Every term is a parameter-only × cache-only product, so the
                // whole expression is separable.
                let p1 = ParameterOnlyScalar::new("p1", parameter!("p1"))
                    .expect("separable p1 should build");
                let p2 = ParameterOnlyScalar::new("p2", parameter!("p2"))
                    .expect("separable p2 should build");
                let c1 = CacheOnlyScalar::new("c1").expect("separable c1 should build");
                let c2 = CacheOnlyScalar::new("c2").expect("separable c2 should build");
                DeterministicFixture {
                    expression: (&p1 * &c1) + &(&p2 * &c2),
                    dataset,
                    parameters: vec![0.4, -0.3],
                }
            }
            DeterministicFixtureKind::Partial => {
                // One separable product plus a mixed-dependence residual term.
                let p =
                    ParameterOnlyScalar::new("p", parameter!("p")).expect("partial p should build");
                let c = CacheOnlyScalar::new("c").expect("partial c should build");
                let m = TestAmplitude::new("m", parameter!("mr"), parameter!("mi"))
                    .expect("partial m should build");
                DeterministicFixture {
                    expression: (&p * &c) + &m,
                    dataset,
                    parameters: vec![0.55, 0.2, -0.15],
                }
            }
            DeterministicFixtureKind::NonSeparable => {
                // Product of two mixed amplitudes: no cached-integral structure.
                let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i"))
                    .expect("non-separable m1 should build");
                let m2 = TestAmplitude::new("m2", parameter!("m2r"), parameter!("m2i"))
                    .expect("non-separable m2 should build");
                DeterministicFixture {
                    expression: &m1 * &m2,
                    dataset,
                    parameters: vec![0.25, -0.4, 0.6, 0.1],
                }
            }
        }
    }
5703
    /// Asserts that the fused weighted-sum entry points agree with a baseline
    /// computed by evaluating every event and folding the weighted values and
    /// gradients by hand.
    fn assert_weighted_sum_matches_eventwise_baseline(fixture: &DeterministicFixture) {
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");
        // Baseline value: sum of weight * Re(value) over all local events.
        let expected_value = evaluator
            .evaluate_local(&fixture.parameters)
            .expect("evaluation should succeed")
            .iter()
            .zip(fixture.dataset.events_local().iter())
            .fold(0.0, |accum, (value, event)| {
                accum + event.weight() * value.re
            });
        // Baseline gradient: accumulate weighted real parts of the per-event
        // gradient vectors.
        let expected_gradient = evaluator
            .evaluate_gradient_local(&fixture.parameters)
            .expect("evaluation should succeed")
            .iter()
            .zip(fixture.dataset.events_local().iter())
            .fold(
                DVector::zeros(fixture.parameters.len()),
                |mut accum, (gradient, event)| {
                    accum += gradient.map(|value| value.re).scale(event.weight());
                    accum
                },
            );
        let actual_value = evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        let actual_gradient = evaluator
            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        assert_relative_eq!(
            actual_value,
            expected_value,
            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
            max_relative = DETERMINISTIC_STRICT_REL_TOL
        );
        for (actual_item, expected_item) in actual_gradient.iter().zip(expected_gradient.iter()) {
            assert_relative_eq!(
                *actual_item,
                *expected_item,
                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                max_relative = DETERMINISTIC_STRICT_REL_TOL
            );
        }
    }
    /// Asserts that a mixed (cached + residual) fixture produces nonzero
    /// contributions on BOTH sides of the normalization split, and that the
    /// component sums re-assemble into the combined single-path totals.
    fn assert_mixed_normalization_components_match_combined_path(fixture: &DeterministicFixture) {
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");
        // Read the cached-integral state under a short-lived resources lock.
        let state = {
            let resources = evaluator.resources.read();
            evaluator.ensure_cached_integral_cache_state(&resources)
        }
        .expect("state should be available");
        assert!(
            !state.values.is_empty(),
            "fixture should exercise cached normalization terms"
        );
        assert!(
            !state.execution_sets.residual_amplitudes.is_empty(),
            "fixture should exercise residual normalization amplitudes"
        );

        // Both components must be nontrivial so the test actually exercises
        // the split, not a degenerate zero on one side.
        let (residual_value_sum, cached_value_sum) = evaluator
            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
            .expect("evaluation should succeed");
        assert!(residual_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
        assert!(cached_value_sum.abs() > DETERMINISTIC_STRICT_ABS_TOL);
        let combined_value = evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        assert_relative_eq!(
            residual_value_sum + cached_value_sum,
            combined_value,
            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
            max_relative = DETERMINISTIC_STRICT_REL_TOL
        );

        // Same check for the gradients, element by element.
        let (residual_gradient_sum, cached_gradient_sum) = evaluator
            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
            .expect("evaluation should succeed");
        let combined_gradient = evaluator
            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        assert!(residual_gradient_sum
            .iter()
            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
        assert!(cached_gradient_sum
            .iter()
            .any(|value| value.abs() > DETERMINISTIC_STRICT_ABS_TOL));
        for ((residual_item, cached_item), combined_item) in residual_gradient_sum
            .iter()
            .zip(cached_gradient_sum.iter())
            .zip(combined_gradient.iter())
        {
            assert_relative_eq!(
                residual_item + cached_item,
                *combined_item,
                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                max_relative = DETERMINISTIC_STRICT_REL_TOL
            );
        }
    }
5809
    /// Toggling the activation mask away and back must leave the weighted
    /// value sum bit-stable (no stale specialization leaks through).
    #[test]
    fn test_deterministic_fixture_weighted_sums_stable_across_activation_mask_toggle() {
        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");
        let original_mask = evaluator.active_mask();

        let original_value = evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");

        // Change the mask (isolate a subset), then restore it exactly.
        evaluator.isolate_many(&["p", "c"]);
        assert_ne!(evaluator.active_mask(), original_mask);

        evaluator
            .set_active_mask(&original_mask)
            .expect("original fixture active mask should restore");
        assert_eq!(evaluator.active_mask(), original_mask);
        let actual_value = evaluator
            .evaluate_weighted_value_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        assert_relative_eq!(
            actual_value,
            original_value,
            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
            max_relative = DETERMINISTIC_STRICT_REL_TOL
        );
    }
5840
5841 #[test]
5842 fn test_deterministic_fixtures_match_eventwise_weighted_sums() {
5843 let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
5844 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5845 let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
5846
5847 assert_weighted_sum_matches_eventwise_baseline(&separable);
5848 assert_weighted_sum_matches_eventwise_baseline(&partial);
5849 assert_weighted_sum_matches_eventwise_baseline(&non_separable);
5850 }
    /// The three fixture kinds must yield the expected number of precomputed
    /// cached integrals: 2 (fully separable), 1 (partial), 0 (non-separable).
    #[test]
    fn test_deterministic_fixtures_cover_separable_partial_non_separable_models() {
        let separable = make_deterministic_fixture(DeterministicFixtureKind::Separable);
        let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
        let non_separable = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);

        let separable_evaluator = separable
            .expression
            .load(&separable.dataset)
            .expect("separable evaluator should load");
        let partial_evaluator = partial
            .expression
            .load(&partial.dataset)
            .expect("partial evaluator should load");
        let non_separable_evaluator = non_separable
            .expression
            .load(&non_separable.dataset)
            .expect("non-separable evaluator should load");

        // Two separable products -> two cached integrals.
        assert_eq!(
            separable_evaluator
                .expression_precomputed_cached_integrals()
                .expect("integrals should be computed")
                .len(),
            2
        );
        // One separable product plus a residual -> one cached integral.
        assert_eq!(
            partial_evaluator
                .expression_precomputed_cached_integrals()
                .expect("integrals should be computed")
                .len(),
            1
        );
        // Nothing separable -> no cached integrals at all.
        assert!(non_separable_evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should be computed")
            .is_empty());
    }
5889 #[test]
5890 fn test_partial_fixture_combined_normalization_components_match_total() {
5891 let partial = make_deterministic_fixture(DeterministicFixtureKind::Partial);
5892 assert_mixed_normalization_components_match_combined_path(&partial);
5893 }
    /// A non-separable fixture must route everything through the residual
    /// path: cached components stay (numerically) zero, and the residual
    /// component equals the combined result.
    #[test]
    fn test_non_separable_fixture_normalization_components_stay_residual_only() {
        let fixture = make_deterministic_fixture(DeterministicFixtureKind::NonSeparable);
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");
        let resources = evaluator.resources.read();
        let state = evaluator
            .ensure_cached_integral_cache_state(&resources)
            .expect("state should be available");
        // No separable structure means no cached integral values at all.
        assert!(state.values.is_empty());

        let (residual_value_sum, cached_value_sum) = evaluator
            .evaluate_weighted_value_sum_local_components(&fixture.parameters)
            .expect("evaluation should succeed");
        assert_relative_eq!(
            cached_value_sum,
            0.0,
            epsilon = DETERMINISTIC_STRICT_ABS_TOL
        );
        assert_relative_eq!(
            residual_value_sum,
            evaluator
                .evaluate_weighted_value_sum_local(&fixture.parameters)
                .expect("evaluation should succeed"),
            epsilon = DETERMINISTIC_STRICT_ABS_TOL,
            max_relative = DETERMINISTIC_STRICT_REL_TOL
        );

        let (residual_gradient_sum, cached_gradient_sum) = evaluator
            .evaluate_weighted_gradient_sum_local_components(&fixture.parameters)
            .expect("evaluation should succeed");
        // Cached gradient must also vanish element-wise.
        assert!(cached_gradient_sum
            .iter()
            .all(|value| value.abs() <= DETERMINISTIC_STRICT_ABS_TOL));
        let combined_gradient = evaluator
            .evaluate_weighted_gradient_sum_local(&fixture.parameters)
            .expect("evaluation should succeed");
        for (residual_item, combined_item) in
            residual_gradient_sum.iter().zip(combined_gradient.iter())
        {
            assert_relative_eq!(
                *residual_item,
                *combined_item,
                epsilon = DETERMINISTIC_STRICT_ABS_TOL,
                max_relative = DETERMINISTIC_STRICT_REL_TOL
            );
        }
    }
5944
5945 #[test]
5946 fn test_batch_evaluation() {
5947 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag")).unwrap();
5948 let mut event1 = test_event();
5949 event1.p4s[0].t = 10.0;
5950 let mut event2 = test_event();
5951 event2.p4s[0].t = 11.0;
5952 let mut event3 = test_event();
5953 event3.p4s[0].t = 12.0;
5954 let dataset = Arc::new(Dataset::new_with_metadata(
5955 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
5956 Arc::new(DatasetMetadata::default()),
5957 ));
5958 let evaluator = expr.load(&dataset).unwrap();
5959 let result = evaluator
5960 .evaluate_batch(&[1.1, 2.2], &[0, 2])
5961 .expect("evaluation should succeed");
5962 assert_eq!(result.len(), 2);
5963 assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
5964 assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
5965 let result_grad = evaluator
5966 .evaluate_gradient_batch(&[1.1, 2.2], &[0, 2])
5967 .expect("evaluation should succeed");
5968 assert_eq!(result_grad.len(), 2);
5969 assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
5970 assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
5971 assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
5972 assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
5973 }
5974
5975 #[test]
5976 fn test_load_compiles_expression_ir_once() {
5977 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5978 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5979 .norm_sqr();
5980 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5981 let evaluator = expr.load(&dataset).unwrap();
5982 assert!(evaluator.expression_slot_count() > 0);
5983 }
5984 #[test]
5985 fn test_expression_ir_value_matches_lowered_runtime() {
5986 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
5987 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
5988 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
5989 .conj()
5990 .norm_sqr();
5991 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
5992 let evaluator = expr.load(&dataset).unwrap();
5993 let resources = evaluator.resources.read();
5994 let parameters = resources
5995 .parameter_map
5996 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
5997 .expect("parameters should assemble");
5998 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
5999 evaluator.fill_amplitude_values(
6000 &mut amplitude_values,
6001 resources.active_indices(),
6002 ¶meters,
6003 &resources.caches[0],
6004 );
6005 let mut ir_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
6006 let lowered_runtime = evaluator.lowered_runtime();
6007 let lowered_program = lowered_runtime.value_program();
6008 let mut lowered_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
6009 let lowered_value =
6010 evaluator.evaluate_expression_value_with_scratch(&litude_values, &mut ir_slots);
6011 let direct_lowered_value =
6012 lowered_program.evaluate_into(&litude_values, &mut lowered_slots);
6013 let ir_value = evaluator
6014 .expression_ir()
6015 .evaluate_into(&litude_values, &mut ir_slots);
6016 assert_relative_eq!(lowered_value.re, direct_lowered_value.re);
6017 assert_relative_eq!(lowered_value.im, direct_lowered_value.im);
6018 assert_relative_eq!(lowered_value.re, ir_value.re);
6019 assert_relative_eq!(lowered_value.im, ir_value.im);
6020 }
    /// Loading must initialize all three lowered programs (value, gradient,
    /// fused value+gradient) with the matching program kinds.
    #[test]
    fn test_expression_ir_load_initializes_with_lowered_value_runtime() {
        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai"))
            .unwrap()
            .norm_sqr();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let lowered_runtime = evaluator.lowered_runtime();
        assert_eq!(
            lowered_runtime.value_program().kind(),
            lowered::LoweredProgramKind::Value
        );
        assert_eq!(
            lowered_runtime.gradient_program().kind(),
            lowered::LoweredProgramKind::Gradient
        );
        assert_eq!(
            lowered_runtime.value_gradient_program().kind(),
            lowered::LoweredProgramKind::ValueGradient
        );
    }
6042 #[test]
6043 fn test_expression_ir_gradient_matches_lowered_runtime() {
6044 let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
6045 * TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
6046 .norm_sqr();
6047 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6048 let evaluator = expr.load(&dataset).unwrap();
6049 let resources = evaluator.resources.read();
6050 let parameters = resources
6051 .parameter_map
6052 .assemble(&[1.0, 0.25, -0.8, 0.5])
6053 .expect("parameters should assemble");
6054 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6055 evaluator.fill_amplitude_values(
6056 &mut amplitude_values,
6057 resources.active_indices(),
6058 ¶meters,
6059 &resources.caches[0],
6060 );
6061 let mut active_mask = vec![false; evaluator.amplitudes.len()];
6062 for &index in resources.active_indices() {
6063 active_mask[index] = true;
6064 }
6065 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6066 .map(|_| DVector::zeros(parameters.len()))
6067 .collect::<Vec<_>>();
6068 evaluator.fill_amplitude_gradients(
6069 &mut amplitude_gradients,
6070 &active_mask,
6071 ¶meters,
6072 &resources.caches[0],
6073 );
6074 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
6075 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
6076 (0..evaluator.expression_ir().node_count())
6077 .map(|_| DVector::zeros(parameters.len()))
6078 .collect();
6079 let lowered_runtime = evaluator.lowered_runtime();
6080 let lowered_program = lowered_runtime.gradient_program();
6081 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
6082 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
6083 .scratch_slots())
6084 .map(|_| DVector::zeros(parameters.len()))
6085 .collect();
6086 let active_gradient = evaluator.evaluate_expression_gradient_with_scratch(
6087 &litude_values,
6088 &litude_gradients,
6089 &mut ir_value_slots,
6090 &mut ir_gradient_slots,
6091 );
6092 let ir_gradient = evaluator.expression_ir().evaluate_gradient_into(
6093 &litude_values,
6094 &litude_gradients,
6095 &mut ir_value_slots,
6096 &mut ir_gradient_slots,
6097 );
6098 let lowered_gradient = lowered_program.evaluate_gradient_into(
6099 &litude_values,
6100 &litude_gradients,
6101 &mut lowered_value_slots,
6102 &mut lowered_gradient_slots,
6103 );
6104 for (active, lowered) in active_gradient.iter().zip(lowered_gradient.iter()) {
6105 assert_relative_eq!(active.re, lowered.re);
6106 assert_relative_eq!(active.im, lowered.im);
6107 }
6108 for (lowered, ir) in lowered_gradient.iter().zip(ir_gradient.iter()) {
6109 assert_relative_eq!(lowered.re, ir.re);
6110 assert_relative_eq!(lowered.im, ir.im);
6111 }
6112 }
6113 #[test]
6114 fn test_expression_ir_value_gradient_matches_lowered_runtime() {
6115 let expr = ((TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
6116 + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
6117 * TestAmplitude::new("c", parameter!("cr"), parameter!("ci")).unwrap())
6118 .norm_sqr();
6119 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6120 let evaluator = expr.load(&dataset).unwrap();
6121 let resources = evaluator.resources.read();
6122 let parameters = resources
6123 .parameter_map
6124 .assemble(&[1.0, 0.25, -0.8, 0.5, 0.2, -1.1])
6125 .expect("parameters should assemble");
6126 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6127 evaluator.fill_amplitude_values(
6128 &mut amplitude_values,
6129 resources.active_indices(),
6130 ¶meters,
6131 &resources.caches[0],
6132 );
6133 let mut active_mask = vec![false; evaluator.amplitudes.len()];
6134 for &index in resources.active_indices() {
6135 active_mask[index] = true;
6136 }
6137 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6138 .map(|_| DVector::zeros(parameters.len()))
6139 .collect::<Vec<_>>();
6140 evaluator.fill_amplitude_gradients(
6141 &mut amplitude_gradients,
6142 &active_mask,
6143 ¶meters,
6144 &resources.caches[0],
6145 );
6146 let mut ir_value_slots = vec![Complex64::ZERO; evaluator.expression_ir().node_count()];
6147 let mut ir_gradient_slots: Vec<DVector<Complex64>> =
6148 (0..evaluator.expression_ir().node_count())
6149 .map(|_| DVector::zeros(parameters.len()))
6150 .collect();
6151 let lowered_runtime = evaluator.lowered_runtime();
6152 let lowered_program = lowered_runtime.value_gradient_program();
6153 let mut lowered_value_slots = vec![Complex64::ZERO; lowered_program.scratch_slots()];
6154 let mut lowered_gradient_slots: Vec<DVector<Complex64>> = (0..lowered_program
6155 .scratch_slots())
6156 .map(|_| DVector::zeros(parameters.len()))
6157 .collect();
6158
6159 let active_value_gradient = evaluator.evaluate_expression_value_gradient_with_scratch(
6160 &litude_values,
6161 &litude_gradients,
6162 &mut ir_value_slots,
6163 &mut ir_gradient_slots,
6164 );
6165 let ir_value_gradient = evaluator.expression_ir().evaluate_value_gradient_into(
6166 &litude_values,
6167 &litude_gradients,
6168 &mut ir_value_slots,
6169 &mut ir_gradient_slots,
6170 );
6171 let lowered_value_gradient = lowered_program.evaluate_value_gradient_into(
6172 &litude_values,
6173 &litude_gradients,
6174 &mut lowered_value_slots,
6175 &mut lowered_gradient_slots,
6176 );
6177
6178 assert_relative_eq!(active_value_gradient.0.re, lowered_value_gradient.0.re);
6179 assert_relative_eq!(active_value_gradient.0.im, lowered_value_gradient.0.im);
6180 for (active, lowered) in active_value_gradient
6181 .1
6182 .iter()
6183 .zip(lowered_value_gradient.1.iter())
6184 {
6185 assert_relative_eq!(active.re, lowered.re);
6186 assert_relative_eq!(active.im, lowered.im);
6187 }
6188 assert_relative_eq!(lowered_value_gradient.0.re, ir_value_gradient.0.re);
6189 assert_relative_eq!(lowered_value_gradient.0.im, ir_value_gradient.0.im);
6190 for (lowered, ir) in lowered_value_gradient
6191 .1
6192 .iter()
6193 .zip(ir_value_gradient.1.iter())
6194 {
6195 assert_relative_eq!(lowered.re, ir.re);
6196 assert_relative_eq!(lowered.im, ir.im);
6197 }
6198 }
    /// Runtime diagnostics must report IR planning, all three lowered
    /// programs, the residual runtime, and an initial-load specialization.
    #[test]
    fn test_expression_runtime_diagnostics_reports_lowered_programs() {
        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");

        let diagnostics = evaluator.expression_runtime_diagnostics();
        assert!(diagnostics.ir_planning_enabled);
        assert!(diagnostics.lowered_value_program_present);
        assert!(diagnostics.lowered_gradient_program_present);
        assert!(diagnostics.lowered_value_gradient_program_present);
        assert!(diagnostics.residual_runtime_present);
        assert_eq!(
            diagnostics.specialization_status,
            Some(ExpressionSpecializationStatus {
                origin: ExpressionSpecializationOrigin::InitialLoad,
            })
        );
    }
    /// The specialization origin must flip from `InitialLoad` to
    /// `CacheMissRebuild` after the activation mask changes.
    #[test]
    fn test_expression_runtime_diagnostics_reports_specialization_origin() {
        let fixture = make_deterministic_fixture(DeterministicFixtureKind::Partial);
        let evaluator = fixture
            .expression
            .load(&fixture.dataset)
            .expect("fixture evaluator should load");

        assert_eq!(
            evaluator
                .expression_runtime_diagnostics()
                .specialization_status,
            Some(ExpressionSpecializationStatus {
                origin: ExpressionSpecializationOrigin::InitialLoad,
            })
        );

        // Isolating an amplitude invalidates the specialization cache.
        evaluator.isolate_many(&["p"]);
        assert_eq!(
            evaluator
                .expression_runtime_diagnostics()
                .specialization_status,
            Some(ExpressionSpecializationStatus {
                origin: ExpressionSpecializationOrigin::CacheMissRebuild,
            })
        );
    }
    /// A shared sub-expression used twice must display as a DAG reference
    /// ("(ref)") rather than being printed twice.
    #[test]
    fn test_compiled_expression_display_reports_dag_refs() {
        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
        // `term` appears on both sides of the sum, so it is shared in the DAG.
        let term = &a * &b;
        let expr = &term + &term;
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();

        let compiled = evaluator.compiled_expression();
        let display = compiled.to_string();

        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
        assert!(display.contains("#"));
        assert!(display.contains("+"));
        assert!(display.contains("×"));
        assert!(display.contains("a(id=0)"));
        assert!(display.contains("b(id=1)"));
        assert!(display.contains("(ref)"));
    }
6267
    /// Same DAG-reference display behavior, but compiled directly from the
    /// expression without loading a dataset.
    #[test]
    fn test_expression_compiled_expression_display_reports_dag_refs_without_loading() {
        let a = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
        let b = TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
        let term = &a * &b;
        let expr = &term + &term;

        let compiled = expr.compiled_expression();
        let display = compiled.to_string();

        assert_eq!(compiled.root(), compiled.nodes().len() - 1);
        assert!(display.contains("#"));
        assert!(display.contains("+"));
        assert!(display.contains("×"));
        assert!(display.contains("a(id=0)"));
        assert!(display.contains("b(id=1)"));
        assert!(display.contains("(ref)"));
    }
6286
    /// A deactivated amplitude must display as a zero constant instead of
    /// appearing by name in the compiled expression.
    #[test]
    fn test_compiled_expression_display_uses_current_active_mask() {
        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        evaluator.deactivate("b");

        let compiled = evaluator.compiled_expression().to_string();

        assert!(compiled.contains("a(id=0)"));
        assert!(!compiled.contains("b(id=1)"));
        assert!(compiled.contains("const 0"));
    }
6301
    /// Round-trip: the expression recovered from a loaded evaluator must
    /// compile to the same structure as the original expression.
    #[test]
    fn test_evaluator_expression_reconstructs_expression() {
        let expr = TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();

        assert_eq!(
            evaluator.expression().compiled_expression(),
            expr.compiled_expression()
        );
    }
6313
6314 #[test]
6315 fn test_active_mask_override_ignores_current_ir_specialization() {
6316 let expr = ComplexScalar::new("amp", parameter!("scale"), parameter!("amp_im", 0.0))
6317 .unwrap()
6318 .norm_sqr();
6319 let dataset = Arc::new(test_dataset());
6320 let evaluator = expr.load(&dataset).unwrap();
6321 let params = vec![2.0];
6322
6323 evaluator.deactivate("amp");
6324 assert_eq!(
6325 evaluator
6326 .evaluate(¶ms)
6327 .expect("evaluation should succeed")[0],
6328 Complex64::new(0.0, 0.0)
6329 );
6330
6331 let overridden = evaluator
6332 .evaluate_local_with_active_mask(¶ms, &[true])
6333 .unwrap();
6334 assert_eq!(overridden[0], Complex64::new(4.0, 0.0));
6335
6336 let overridden_fused = evaluator
6337 .evaluate_with_gradient_local_with_active_mask(¶ms, &[true])
6338 .unwrap();
6339 assert_eq!(overridden_fused[0].0, Complex64::new(4.0, 0.0));
6340 assert_eq!(overridden_fused[0].1[0], Complex64::new(4.0, 0.0));
6341 }
    /// Mixed-dependence amplitudes must annotate every IR node, and the root,
    /// as `Mixed`.
    #[test]
    fn test_expression_ir_dependence_diagnostics_surface() {
        let expr = (TestAmplitude::new("a", parameter!("ar"), parameter!("ai")).unwrap()
            + TestAmplitude::new("b", parameter!("br"), parameter!("bi")).unwrap())
        .norm_sqr();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let annotations = evaluator
            .expression_node_dependence_annotations()
            .expect("annotations should exist");
        // One annotation per IR node.
        assert_eq!(annotations.len(), evaluator.expression_ir().node_count());
        assert!(annotations
            .iter()
            .all(|dependence| *dependence == ExpressionDependence::Mixed));
        assert_eq!(
            evaluator
                .expression_root_dependence()
                .expect("root dependence should exist"),
            ExpressionDependence::Mixed
        );
    }
    /// An amplitude that does not override `dependence_hint` must default to
    /// `Mixed` at the root.
    #[test]
    fn test_expression_ir_default_dependence_hint_is_mixed() {
        let expr = ComplexScalar::new("c", parameter!("cr"), parameter!("ci")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        assert_eq!(
            evaluator
                .expression_root_dependence()
                .expect("root dependence should exist"),
            ExpressionDependence::Mixed
        );
    }
    /// A `ParameterOnly` hint on the sole amplitude must propagate to the
    /// expression root.
    #[test]
    fn test_expression_ir_parameter_only_dependence_hint_propagates() {
        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        assert_eq!(
            evaluator
                .expression_root_dependence()
                .expect("root dependence should exist"),
            ExpressionDependence::ParameterOnly
        );
    }
    /// A `CacheOnly` hint on the sole amplitude must propagate to the
    /// expression root.
    #[test]
    fn test_expression_ir_cache_only_dependence_hint_propagates() {
        let expr = CacheOnlyScalar::new("k").unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        assert_eq!(
            evaluator
                .expression_root_dependence()
                .expect("root dependence should exist"),
            ExpressionDependence::CacheOnly
        );
    }
    /// Taking `.imag()` of a real-valued-hinted amplitude must constant-fold
    /// the IR root to zero, and evaluation must return exactly zero.
    #[test]
    fn test_expression_ir_real_valued_hint_folds_imag_projection_to_zero() {
        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
            .unwrap()
            .imag();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let ir = evaluator.expression_ir();

        // The optimizer should have replaced the root with Constant(0).
        assert!(matches!(
            ir.nodes()[ir.root()],
            ir::IrNode::Constant(value) if value == Complex64::ZERO
        ));
        assert_eq!(
            evaluator
                .evaluate(&[2.5])
                .expect("evaluation should succeed")[0],
            Complex64::ZERO
        );
    }
    /// Conjugating a real-valued-hinted amplitude must simplify the IR to the
    /// bare amplitude node (conj of a real is a no-op).
    #[test]
    fn test_expression_ir_real_valued_hint_simplifies_conjugation() {
        let expr = ParameterOnlyScalar::new("p", parameter!("p"))
            .unwrap()
            .conj();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let ir = evaluator.expression_ir();

        assert!(matches!(ir.nodes()[ir.root()], ir::IrNode::Amp(0)));
        assert_eq!(
            evaluator
                .evaluate(&[2.5])
                .expect("evaluation should succeed")[0],
            Complex64::new(2.5, 0.0)
        );
    }
    /// Summing a parameter-only and a cache-only amplitude must surface a
    /// dependence warning mentioning both classes.
    #[test]
    fn test_expression_ir_dependence_warnings_surface() {
        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            + &CacheOnlyScalar::new("k").unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        assert!(evaluator
            .expression_dependence_warnings()
            .expect("warnings should exist")
            .iter()
            .any(|warning| warning.contains("both ParameterOnly and CacheOnly")));
    }
    /// A single parameter-only × cache-only product must be planned as one
    /// separable-multiply candidate that is cached and never residual.
    #[test]
    fn test_expression_ir_normalization_plan_explain_surface() {
        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let explain = evaluator
            .expression_normalization_plan_explain()
            .expect("plan should exist");
        assert_eq!(explain.root_dependence, ExpressionDependence::Mixed);
        assert_eq!(explain.separable_mul_candidate_nodes.len(), 1);
        // All candidates were actually cached.
        assert_eq!(
            explain.cached_separable_nodes,
            explain.separable_mul_candidate_nodes
        );
        // No node is both a cached candidate and a residual term.
        assert!(explain.residual_terms.iter().all(|index| {
            !explain
                .separable_mul_candidate_nodes
                .iter()
                .any(|candidate| candidate == index)
        }));
    }
    /// For a fully separable product, both amplitudes must land in the cached
    /// execution sets and none in the residual set.
    #[test]
    fn test_expression_ir_normalization_execution_sets_surface() {
        let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let sets = evaluator
            .expression_normalization_execution_sets()
            .expect("sets should exist");
        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
        assert!(sets.residual_amplitudes.is_empty());
    }
    /// Adding a mixed amplitude to a separable product must put the mixed
    /// amplitude (id 2) alone in the residual set.
    #[test]
    fn test_expression_ir_normalization_execution_sets_partial_surface() {
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();
        let sets = evaluator
            .expression_normalization_execution_sets()
            .expect("sets should exist");
        assert_eq!(sets.cached_parameter_amplitudes, vec![0]);
        assert_eq!(sets.cached_cache_amplitudes, vec![1]);
        assert_eq!(sets.residual_amplitudes, vec![2]);
    }
6497 #[test]
6498 fn test_expression_ir_precomputed_cached_integrals_at_load() {
6499 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6500 * &CacheOnlyScalar::new("k").unwrap();
6501 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6502 let evaluator = expr.load(&dataset).unwrap();
6503 let precomputed = evaluator
6504 .expression_precomputed_cached_integrals()
6505 .expect("integrals should exist");
6506 assert_eq!(precomputed.len(), 1);
6507 let cache_reference = CacheOnlyScalar::new("k_ref")
6508 .unwrap()
6509 .load(&dataset)
6510 .unwrap();
6511 let cache_values = cache_reference
6512 .evaluate_local(&[])
6513 .expect("evaluation should succeed");
6514 let expected_weighted_sum = cache_values
6515 .iter()
6516 .zip(dataset.events_local().iter())
6517 .fold(Complex64::ZERO, |acc, (value, event)| {
6518 acc + (*value * event.weight())
6519 });
6520 assert_relative_eq!(
6521 precomputed[0].weighted_cache_sum.re,
6522 expected_weighted_sum.re
6523 );
6524 assert_relative_eq!(
6525 precomputed[0].weighted_cache_sum.im,
6526 expected_weighted_sum.im
6527 );
6528 }
6529 #[test]
6530 fn test_expression_ir_precomputed_cached_integrals_empty_when_non_separable() {
6531 let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
6532 * &CacheOnlyScalar::new("k").unwrap();
6533 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6534 let evaluator = expr.load(&dataset).unwrap();
6535 assert!(evaluator
6536 .expression_precomputed_cached_integrals()
6537 .expect("integrals should exist")
6538 .is_empty());
6539 }
6540 #[test]
6541 fn test_expression_ir_precomputed_cached_integrals_recompute_on_activation_change() {
6542 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6543 * &CacheOnlyScalar::new("k").unwrap();
6544 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6545 let evaluator = expr.load(&dataset).unwrap();
6546 assert_eq!(
6547 evaluator
6548 .expression_precomputed_cached_integrals()
6549 .expect("integrals should exist")
6550 .len(),
6551 1
6552 );
6553
6554 evaluator.isolate_many(&["p"]);
6555 assert!(evaluator
6556 .expression_precomputed_cached_integrals()
6557 .expect("integrals should exist")
6558 .is_empty());
6559 }
6560 #[test]
6561 fn test_expression_ir_precomputed_cached_integrals_recompute_on_dataset_change() {
6562 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6563 * &CacheOnlyScalar::new("k").unwrap();
6564 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6565 let mut evaluator = expr.load(&dataset).unwrap();
6566 drop(dataset);
6567 let before = evaluator
6568 .expression_precomputed_cached_integrals()
6569 .expect("integrals should exist");
6570 assert_eq!(before.len(), 1);
6571
6572 Arc::get_mut(&mut evaluator.dataset)
6573 .expect("evaluator should own dataset Arc in this test")
6574 .clear_events_local();
6575 let after = evaluator
6576 .expression_precomputed_cached_integrals()
6577 .expect("integrals should exist");
6578 assert_eq!(after.len(), 1);
6579 assert_eq!(after[0].weighted_cache_sum, Complex64::ZERO);
6580 assert!(before[0].weighted_cache_sum != after[0].weighted_cache_sum);
6581 }
6582 #[test]
6583 fn test_expression_ir_precomputed_cached_integral_gradient_terms_scale_by_cache_integrals() {
6584 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6585 * &CacheOnlyScalar::new("k").unwrap();
6586 let dataset = Arc::new(Dataset::new(vec![
6587 Arc::new(test_event()),
6588 Arc::new(test_event()),
6589 ]));
6590 let evaluator = expr.load(&dataset).unwrap();
6591 let cached_integrals = evaluator
6592 .expression_precomputed_cached_integrals()
6593 .expect("integrals should exist");
6594 assert_eq!(cached_integrals.len(), 1);
6595 let gradient_terms = evaluator
6596 .expression_precomputed_cached_integral_gradient_terms(&[1.25])
6597 .expect("evaluation should succeed");
6598 assert_eq!(gradient_terms.len(), 1);
6599 assert_eq!(gradient_terms[0].weighted_gradient.len(), 1);
6600 assert_relative_eq!(
6601 gradient_terms[0].weighted_gradient[0].re,
6602 cached_integrals[0].weighted_cache_sum.re,
6603 epsilon = 1e-6
6604 );
6605 assert_relative_eq!(
6606 gradient_terms[0].weighted_gradient[0].im,
6607 cached_integrals[0].weighted_cache_sum.im,
6608 epsilon = 1e-6
6609 );
6610 }
6611 #[test]
6612 fn test_expression_ir_precomputed_cached_integral_gradient_terms_empty_when_not_separable() {
6613 let expr = TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap()
6614 * &CacheOnlyScalar::new("k").unwrap();
6615 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
6616 let evaluator = expr.load(&dataset).unwrap();
6617 assert!(evaluator
6618 .expression_precomputed_cached_integral_gradient_terms(&[0.1, -0.2])
6619 .expect("evaluation should succeed")
6620 .is_empty());
6621 }
6622 #[test]
6623 fn test_expression_ir_lowered_cached_factor_programs_match_ir_cached_paths() {
6624 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6625 * &CacheOnlyScalar::new("k").unwrap())
6626 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6627 let dataset = Arc::new(test_dataset());
6628 let evaluator = expr.load(&dataset).unwrap();
6629 let resources = evaluator.resources.read();
6630 let state = evaluator
6631 .ensure_cached_integral_cache_state(&resources)
6632 .expect("state should be available");
6633 let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
6634 let parameters = resources
6635 .parameter_map
6636 .assemble(&[0.55, 0.2, -0.15])
6637 .expect("parameters should assemble");
6638
6639 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6640 evaluator.fill_amplitude_values(
6641 &mut amplitude_values,
6642 &state.execution_sets.cached_parameter_amplitudes,
6643 ¶meters,
6644 &resources.caches[0],
6645 );
6646 let cached_value_ir =
6647 evaluator.evaluate_cached_weighted_value_sum_ir(&state, &litude_values);
6648 let cached_value_lowered = evaluator
6649 .evaluate_cached_weighted_value_sum_lowered(
6650 &state,
6651 lowered_artifacts.as_ref(),
6652 &litude_values,
6653 )
6654 .expect("cached value lowering should succeed");
6655 assert_relative_eq!(cached_value_lowered, cached_value_ir, epsilon = 1e-12);
6656
6657 let mut cached_parameter_mask = vec![false; evaluator.amplitudes.len()];
6658 for &index in &state.execution_sets.cached_parameter_amplitudes {
6659 cached_parameter_mask[index] = true;
6660 }
6661 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6662 .map(|_| DVector::zeros(parameters.len()))
6663 .collect::<Vec<_>>();
6664 evaluator.fill_amplitude_gradients(
6665 &mut amplitude_gradients,
6666 &cached_parameter_mask,
6667 ¶meters,
6668 &resources.caches[0],
6669 );
6670 let cached_gradient_ir = evaluator.evaluate_cached_weighted_gradient_sum_ir(
6671 &state,
6672 &litude_values,
6673 &litude_gradients,
6674 parameters.len(),
6675 );
6676 let cached_gradient_lowered = evaluator
6677 .evaluate_cached_weighted_gradient_sum_lowered(
6678 &state,
6679 lowered_artifacts.as_ref(),
6680 &litude_values,
6681 &litude_gradients,
6682 parameters.len(),
6683 )
6684 .expect("cached gradient lowering should succeed");
6685 for (lowered, ir) in cached_gradient_lowered
6686 .iter()
6687 .zip(cached_gradient_ir.iter())
6688 {
6689 assert_relative_eq!(*lowered, *ir, epsilon = 1e-12);
6690 }
6691 }
6692 #[test]
6693 fn test_expression_ir_lowered_residual_runtime_matches_zeroed_node_path() {
6694 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6695 * &CacheOnlyScalar::new("k").unwrap())
6696 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
6697 let dataset = Arc::new(test_dataset());
6698 let evaluator = expr.load(&dataset).unwrap();
6699 let resources = evaluator.resources.read();
6700 let state = evaluator
6701 .ensure_cached_integral_cache_state(&resources)
6702 .expect("state should be available");
6703 let lowered_artifacts = evaluator.active_lowered_artifacts().unwrap();
6704 let parameters = resources
6705 .parameter_map
6706 .assemble(&[0.55, 0.2, -0.15])
6707 .expect("parameters should assemble");
6708
6709 let mut amplitude_values = vec![Complex64::ZERO; evaluator.amplitudes.len()];
6710 evaluator.fill_amplitude_values(
6711 &mut amplitude_values,
6712 &state.execution_sets.residual_amplitudes,
6713 ¶meters,
6714 &resources.caches[0],
6715 );
6716 let residual_value_ir = evaluator.evaluate_residual_value_ir(&state, &litude_values);
6717 let residual_program = lowered_artifacts
6718 .residual_runtime
6719 .as_ref()
6720 .map(|runtime| runtime.value_program())
6721 .expect("residual value lowering should succeed");
6722 let mut value_slots = vec![Complex64::ZERO; residual_program.scratch_slots()];
6723 let residual_value_lowered =
6724 residual_program.evaluate_into(&litude_values, &mut value_slots);
6725 assert_relative_eq!(
6726 residual_value_lowered.re,
6727 residual_value_ir.re,
6728 epsilon = 1e-12
6729 );
6730 assert_relative_eq!(
6731 residual_value_lowered.im,
6732 residual_value_ir.im,
6733 epsilon = 1e-12
6734 );
6735
6736 let mut residual_active_mask = vec![false; evaluator.amplitudes.len()];
6737 for &index in &state.execution_sets.residual_amplitudes {
6738 residual_active_mask[index] = true;
6739 }
6740 let mut amplitude_gradients = (0..evaluator.amplitudes.len())
6741 .map(|_| DVector::zeros(parameters.len()))
6742 .collect::<Vec<_>>();
6743 evaluator.fill_amplitude_gradients(
6744 &mut amplitude_gradients,
6745 &residual_active_mask,
6746 ¶meters,
6747 &resources.caches[0],
6748 );
6749 let residual_gradient_ir = evaluator.evaluate_residual_gradient_ir(
6750 &state,
6751 &litude_values,
6752 &litude_gradients,
6753 parameters.len(),
6754 );
6755
6756 let program = lowered_artifacts
6757 .residual_runtime
6758 .as_ref()
6759 .map(|runtime| runtime.gradient_program())
6760 .expect("gradient lowering should succeed");
6761 let mut value_slots = vec![Complex64::ZERO; program.scratch_slots()];
6762 let mut gradient_slots = vec![Complex64::ZERO; program.scratch_slots() * parameters.len()];
6763 let residual_gradient_lowered = program.evaluate_gradient_into_flat(
6764 &litude_values,
6765 &litude_gradients,
6766 &mut value_slots,
6767 &mut gradient_slots,
6768 parameters.len(),
6769 );
6770
6771 for (lowered, ir) in residual_gradient_lowered
6772 .iter()
6773 .zip(residual_gradient_ir.iter())
6774 {
6775 assert_relative_eq!(lowered.re, ir.re, epsilon = 1e-12);
6776 assert_relative_eq!(lowered.im, ir.im, epsilon = 1e-12);
6777 }
6778 }
    #[test]
    fn test_expression_ir_reuses_lowered_artifacts_when_dataset_key_changes() {
        // A dataset change forces a fresh IR specialization (new cache key),
        // but the lowered programs depend only on the specialization shape —
        // they should be served from the lowered-artifact cache, not rebuilt.
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(test_dataset());
        let mut evaluator = expr.load(&dataset).unwrap();
        // Drop our handle so the evaluator is the sole owner of its Arc below.
        drop(dataset);

        assert_eq!(evaluator.specialization_cache_len(), 1);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);

        // Zero the metrics so the assertions below reflect only the effects
        // of the dataset change.
        evaluator.reset_expression_compile_metrics();
        evaluator.reset_expression_specialization_metrics();

        Arc::get_mut(&mut evaluator.dataset)
            .expect("evaluator should own dataset Arc in this test")
            .clear_events_local();

        // With no events left, the recomputed cached integral sums to zero.
        let cached_integrals = evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist");
        assert_eq!(cached_integrals.len(), 1);
        assert_eq!(cached_integrals[0].weighted_cache_sum, Complex64::ZERO);

        // One new specialization entry (a cache miss), while the
        // lowered-artifact cache still holds the single reused entry.
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 1,
            }
        );

        // Compile metrics: IR recompiled (nonzero nanos) while lowering was
        // served from cache (one lowering hit, zero lowering nanos).
        let compile_metrics = evaluator.expression_compile_metrics();
        assert_eq!(compile_metrics.specialization_cache_hits, 0);
        assert_eq!(compile_metrics.specialization_cache_misses, 1);
        assert_eq!(compile_metrics.specialization_lowering_cache_hits, 1);
        assert_eq!(compile_metrics.specialization_lowering_cache_misses, 0);
        assert!(compile_metrics.specialization_ir_compile_nanos > 0);
        assert!(compile_metrics.specialization_cached_integrals_nanos > 0);
        assert_eq!(compile_metrics.specialization_lowering_nanos, 0);
    }
6823
6824 #[test]
6825 fn test_evaluate_weighted_gradient_sum_local_matches_eventwise_baseline() {
6826 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6827 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6828 let c1 = CacheOnlyScalar::new("c1").unwrap();
6829 let c2 = CacheOnlyScalar::new("c2").unwrap();
6830 let c3 = CacheOnlyScalar::new("c3").unwrap();
6831 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6832 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6833 let dataset = Arc::new(test_dataset());
6834 let evaluator = expr.load(&dataset).unwrap();
6835 assert_eq!(
6836 evaluator
6837 .expression_precomputed_cached_integrals()
6838 .expect("integrals should exist")
6839 .len(),
6840 2
6841 );
6842 let params = vec![0.2, -0.3, 1.1, -0.7];
6843 let expected = evaluator
6844 .evaluate_gradient_local(¶ms)
6845 .expect("evaluation should succeed")
6846 .iter()
6847 .zip(dataset.events_local().iter())
6848 .fold(
6849 DVector::zeros(params.len()),
6850 |mut accum, (gradient, event)| {
6851 accum += gradient.map(|value| value.re).scale(event.weight());
6852 accum
6853 },
6854 );
6855 let actual = evaluator
6856 .evaluate_weighted_gradient_sum_local(¶ms)
6857 .expect("evaluation should succeed");
6858 for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6859 assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6860 }
6861 }
6862
6863 #[test]
6864 fn test_evaluate_weighted_value_sum_local_matches_eventwise_baseline() {
6865 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6866 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6867 let c1 = CacheOnlyScalar::new("c1").unwrap();
6868 let c2 = CacheOnlyScalar::new("c2").unwrap();
6869 let c3 = CacheOnlyScalar::new("c3").unwrap();
6870 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6871 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6872 let dataset = Arc::new(test_dataset());
6873 let evaluator = expr.load(&dataset).unwrap();
6874 assert_eq!(
6875 evaluator
6876 .expression_precomputed_cached_integrals()
6877 .expect("integrals should exist")
6878 .len(),
6879 2
6880 );
6881 let params = vec![0.2, -0.3, 1.1, -0.7];
6882 let expected = evaluator
6883 .evaluate_local(¶ms)
6884 .expect("evaluation should succeed")
6885 .iter()
6886 .zip(dataset.events_local().iter())
6887 .fold(0.0, |accum, (value, event)| {
6888 accum + event.weight() * value.re
6889 });
6890 let actual = evaluator
6891 .evaluate_weighted_value_sum_local(¶ms)
6892 .expect("evaluation should succeed");
6893 assert_relative_eq!(actual, expected, epsilon = 1e-10);
6894 }
6895
6896 #[test]
6897 fn test_weighted_sums_match_hardcoded_reference_values() {
6898 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
6899 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
6900 let c1 = CacheOnlyScalar::new("c1").unwrap();
6901 let c2 = CacheOnlyScalar::new("c2").unwrap();
6902 let c3 = CacheOnlyScalar::new("c3").unwrap();
6903 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
6904 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
6905
6906 let metadata = Arc::new(DatasetMetadata::default());
6907 let dataset = Arc::new(Dataset::new_with_metadata(
6908 vec![
6909 Arc::new(EventData {
6910 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
6911 aux: vec![],
6912 weight: 0.5,
6913 }),
6914 Arc::new(EventData {
6915 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
6916 aux: vec![],
6917 weight: -1.25,
6918 }),
6919 Arc::new(EventData {
6920 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
6921 aux: vec![],
6922 weight: 2.0,
6923 }),
6924 ],
6925 metadata,
6926 ));
6927 let evaluator = expr.load(&dataset).unwrap();
6928 let params = vec![0.7, -1.1, 0.9, -0.4];
6929
6930 let weighted_value_sum = evaluator
6931 .evaluate_weighted_value_sum_local(¶ms)
6932 .expect("evaluation should succeed");
6933 assert_relative_eq!(weighted_value_sum, 22.7725, epsilon = 1e-12);
6934
6935 let weighted_gradient_sum = evaluator
6936 .evaluate_weighted_gradient_sum_local(¶ms)
6937 .expect("evaluation should succeed");
6938 let free_parameters = evaluator
6939 .free_parameters()
6940 .into_iter()
6941 .map(|name| name.to_string())
6942 .collect::<Vec<_>>();
6943 assert_eq!(free_parameters, vec!["p1", "p2", "m1r", "m1i"]);
6944 let expected_gradient = [43.925, 7.25, 28.525, 0.0];
6945 assert_eq!(weighted_gradient_sum.len(), expected_gradient.len());
6946 for (actual, expected) in weighted_gradient_sum.iter().zip(expected_gradient.iter()) {
6947 assert_relative_eq!(*actual, *expected, epsilon = 1e-9);
6948 }
6949 }
6950 #[test]
6951 fn test_evaluate_weighted_gradient_sum_local_respects_signed_cached_terms() {
6952 let expr = Expression::one()
6953 - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6954 * &CacheOnlyScalar::new("k").unwrap());
6955 let dataset = Arc::new(test_dataset());
6956 let evaluator = expr.load(&dataset).unwrap();
6957 assert_eq!(
6958 evaluator
6959 .expression_precomputed_cached_integrals()
6960 .expect("integrals should exist")
6961 .len(),
6962 1
6963 );
6964 assert_eq!(
6965 evaluator
6966 .expression_precomputed_cached_integrals()
6967 .expect("integrals should exist")[0]
6968 .coefficient,
6969 -1
6970 );
6971 let params = vec![0.75];
6972 let expected = evaluator
6973 .evaluate_gradient_local(¶ms)
6974 .expect("evaluation should succeed")
6975 .iter()
6976 .zip(dataset.events_local().iter())
6977 .fold(
6978 DVector::zeros(params.len()),
6979 |mut accum, (gradient, event)| {
6980 accum += gradient.map(|value| value.re).scale(event.weight());
6981 accum
6982 },
6983 );
6984 let actual = evaluator
6985 .evaluate_weighted_gradient_sum_local(¶ms)
6986 .expect("evaluation should succeed");
6987 for (actual_item, expected_item) in actual.iter().zip(expected.iter()) {
6988 assert_relative_eq!(*actual_item, *expected_item, epsilon = 1e-10);
6989 }
6990 }
6991 #[test]
6992 fn test_evaluate_weighted_value_sum_local_respects_signed_cached_terms() {
6993 let expr = Expression::one()
6994 - &(ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
6995 * &CacheOnlyScalar::new("k").unwrap());
6996 let dataset = Arc::new(test_dataset());
6997 let evaluator = expr.load(&dataset).unwrap();
6998 assert_eq!(
6999 evaluator
7000 .expression_precomputed_cached_integrals()
7001 .expect("integrals should exist")
7002 .len(),
7003 1
7004 );
7005 assert_eq!(
7006 evaluator
7007 .expression_precomputed_cached_integrals()
7008 .expect("integrals should exist")[0]
7009 .coefficient,
7010 -1
7011 );
7012 let params = vec![0.75];
7013 let expected = evaluator
7014 .evaluate_local(¶ms)
7015 .expect("evaluation should succeed")
7016 .iter()
7017 .zip(dataset.events_local().iter())
7018 .fold(0.0, |accum, (value, event)| {
7019 accum + event.weight() * value.re
7020 });
7021 let actual = evaluator
7022 .evaluate_weighted_value_sum_local(¶ms)
7023 .expect("evaluation should succeed");
7024 assert_relative_eq!(actual, expected, epsilon = 1e-10);
7025 }
7026 #[test]
7027 fn test_expression_ir_diagnostics_follow_activation_changes() {
7028 let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
7029 * &CacheOnlyScalar::new("k").unwrap())
7030 + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
7031 let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
7032 let evaluator = expr.load(&dataset).unwrap();
7033
7034 let all_active = evaluator
7035 .expression_normalization_plan_explain()
7036 .expect("plan should exist");
7037 assert_eq!(all_active.cached_separable_nodes.len(), 1);
7038 assert_eq!(
7039 evaluator
7040 .expression_root_dependence()
7041 .expect("root dependence should exist"),
7042 ExpressionDependence::Mixed
7043 );
7044
7045 evaluator.isolate_many(&["p"]);
7046 let param_only = evaluator
7047 .expression_normalization_plan_explain()
7048 .expect("plan should exist");
7049 assert!(param_only.cached_separable_nodes.is_empty());
7050 assert_eq!(
7051 evaluator
7052 .expression_root_dependence()
7053 .expect("root dependence should exist"),
7054 ExpressionDependence::ParameterOnly
7055 );
7056 }
    #[test]
    fn test_expression_ir_specialization_cache_reuses_prior_mask_specializations() {
        // Exercises the per-activation-mask specialization cache: toggling
        // back to a previously seen mask must restore the old specialization
        // (a cache hit) rather than recompiling it.
        let expr = (ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
            * &CacheOnlyScalar::new("k").unwrap())
            + &TestAmplitude::new("m", parameter!("mr"), parameter!("mi")).unwrap();
        let dataset = Arc::new(Dataset::new(vec![Arc::new(test_event())]));
        let evaluator = expr.load(&dataset).unwrap();

        // Initial load: all compile stages ran (nonzero nanos); the first
        // specialization and its lowering are both counted as misses.
        let initial_compile_metrics = evaluator.expression_compile_metrics();
        assert!(initial_compile_metrics.initial_ir_compile_nanos > 0);
        assert!(initial_compile_metrics.initial_cached_integrals_nanos > 0);
        assert!(initial_compile_metrics.initial_lowering_nanos > 0);
        assert_eq!(initial_compile_metrics.specialization_cache_hits, 0);
        assert_eq!(initial_compile_metrics.specialization_cache_misses, 0);
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            initial_compile_metrics.specialization_lowering_cache_misses,
            1
        );

        assert_eq!(evaluator.specialization_cache_len(), 1);
        assert_eq!(evaluator.lowered_artifact_cache_len(), 1);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 1,
            }
        );
        // Keep the all-active integrals for comparison after re-activation.
        let all_active_cached_integrals = evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist");

        // New mask ("p" only): a second specialization is compiled (miss),
        // and with no cache factor active there are no cached integrals.
        evaluator.isolate_many(&["p"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 0,
                cache_misses: 2,
            }
        );
        let after_cache_miss_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_miss_metrics.specialization_cache_hits, 0);
        assert_eq!(after_cache_miss_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_miss_metrics.specialization_lowering_cache_misses,
            2
        );
        assert!(after_cache_miss_metrics.specialization_ir_compile_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_cached_integrals_nanos > 0);
        assert!(after_cache_miss_metrics.specialization_lowering_nanos > 0);
        assert!(evaluator
            .expression_precomputed_cached_integrals()
            .expect("integrals should exist")
            .is_empty());

        // Back to the all-active mask: served from cache (one hit, no new
        // entries), restoring the original integrals verbatim.
        evaluator.activate_many(&["k", "m"]);
        assert_eq!(evaluator.specialization_cache_len(), 2);
        assert_eq!(
            evaluator.expression_specialization_metrics(),
            ExpressionSpecializationMetrics {
                cache_hits: 1,
                cache_misses: 2,
            }
        );
        assert_eq!(
            evaluator
                .expression_precomputed_cached_integrals()
                .expect("integrals should exist"),
            all_active_cached_integrals
        );
        let after_cache_hit_metrics = evaluator.expression_compile_metrics();
        assert_eq!(after_cache_hit_metrics.specialization_cache_hits, 1);
        assert_eq!(after_cache_hit_metrics.specialization_cache_misses, 1);
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_hits,
            0
        );
        assert_eq!(
            after_cache_hit_metrics.specialization_lowering_cache_misses,
            2
        );
        // Restoring from cache still takes measurable (nonzero) time.
        assert!(after_cache_hit_metrics.specialization_cache_restore_nanos > 0);
    }
7149
7150 #[test]
7151 fn test_weighted_sums_match_baseline_after_activation_changes() {
7152 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
7153 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
7154 let c1 = CacheOnlyScalar::new("c1").unwrap();
7155 let c2 = CacheOnlyScalar::new("c2").unwrap();
7156 let c3 = CacheOnlyScalar::new("c3").unwrap();
7157 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
7158 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
7159 let dataset = Arc::new(test_dataset());
7160 let evaluator = expr.load(&dataset).unwrap();
7161 let params = vec![0.2, -0.3, 1.1, -0.7];
7162
7163 evaluator.isolate_many(&["p1", "c1", "m1", "c3"]);
7164
7165 let expected_value = evaluator
7166 .evaluate_local(¶ms)
7167 .expect("evaluation should succeed")
7168 .iter()
7169 .zip(dataset.events_local().iter())
7170 .fold(0.0, |accum, (value, event)| {
7171 accum + event.weight() * value.re
7172 });
7173 assert_relative_eq!(
7174 evaluator
7175 .evaluate_weighted_value_sum_local(¶ms)
7176 .expect("evaluation should succeed"),
7177 expected_value,
7178 epsilon = 1e-10
7179 );
7180 }
7181
7182 #[test]
7183 fn test_evaluate_local_does_not_depend_on_dataset_rows() {
7184 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7185 .unwrap()
7186 .norm_sqr();
7187 let mut event1 = test_event();
7188 event1.p4s[0].t = 7.5;
7189 let mut event2 = test_event();
7190 event2.p4s[0].t = 8.25;
7191 let mut event3 = test_event();
7192 event3.p4s[0].t = 9.0;
7193 let dataset = Arc::new(Dataset::new_with_metadata(
7194 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
7195 Arc::new(DatasetMetadata::default()),
7196 ));
7197 let mut evaluator = expr.load(&dataset).unwrap();
7198 drop(dataset);
7199 let expected_len = evaluator.resources.read().caches.len();
7200 Arc::get_mut(&mut evaluator.dataset)
7201 .expect("evaluator should own dataset Arc in this test")
7202 .clear_events_local();
7203 let cached = evaluator
7204 .evaluate_local(&[1.25, -0.75])
7205 .expect("evaluation should succeed");
7206 assert_eq!(cached.len(), expected_len);
7207 }
7208
7209 #[test]
7210 fn test_evaluate_gradient_local_does_not_depend_on_dataset_rows() {
7211 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7212 .unwrap()
7213 .norm_sqr();
7214 let mut event1 = test_event();
7215 event1.p4s[0].t = 7.5;
7216 let mut event2 = test_event();
7217 event2.p4s[0].t = 8.25;
7218 let mut event3 = test_event();
7219 event3.p4s[0].t = 9.0;
7220 let dataset = Arc::new(Dataset::new_with_metadata(
7221 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
7222 Arc::new(DatasetMetadata::default()),
7223 ));
7224 let mut evaluator = expr.load(&dataset).unwrap();
7225 drop(dataset);
7226 let expected_len = evaluator.resources.read().caches.len();
7227 Arc::get_mut(&mut evaluator.dataset)
7228 .expect("evaluator should own dataset Arc in this test")
7229 .clear_events_local();
7230 let cached = evaluator
7231 .evaluate_gradient_local(&[1.25, -0.75])
7232 .expect("evaluation should succeed");
7233 assert_eq!(cached.len(), expected_len);
7234 }
7235
7236 #[test]
7237 fn test_evaluate_with_gradient_local_matches_separate_paths() {
7238 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7239 .unwrap()
7240 .norm_sqr();
7241 let dataset = Arc::new(Dataset::new(vec![
7242 Arc::new(test_event()),
7243 Arc::new(test_event()),
7244 Arc::new(test_event()),
7245 ]));
7246 let evaluator = expr.load(&dataset).unwrap();
7247 let params = [1.25, -0.75];
7248 let values = evaluator
7249 .evaluate_local(¶ms)
7250 .expect("evaluation should succeed");
7251 let gradients = evaluator
7252 .evaluate_gradient_local(¶ms)
7253 .expect("evaluation should succeed");
7254 let fused = evaluator
7255 .evaluate_with_gradient_local(¶ms)
7256 .expect("evaluation should succeed");
7257 assert_eq!(fused.len(), values.len());
7258 assert_eq!(fused.len(), gradients.len());
7259 for ((value_gradient, value), gradient) in
7260 fused.iter().zip(values.iter()).zip(gradients.iter())
7261 {
7262 let (fused_value, fused_gradient) = value_gradient;
7263 assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
7264 assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
7265 assert_eq!(fused_gradient.len(), gradient.len());
7266 for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
7267 assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
7268 assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
7269 }
7270 }
7271 }
7272
7273 #[test]
7274 fn test_evaluate_with_gradient_batch_local_matches_separate_paths() {
7275 let expr = TestAmplitude::new("test", parameter!("real"), parameter!("imag"))
7276 .unwrap()
7277 .norm_sqr();
7278 let dataset = Arc::new(Dataset::new(vec![
7279 Arc::new(test_event()),
7280 Arc::new(test_event()),
7281 Arc::new(test_event()),
7282 Arc::new(test_event()),
7283 ]));
7284 let evaluator = expr.load(&dataset).unwrap();
7285 let params = [0.5, -1.25];
7286 let indices = vec![0, 2, 3];
7287 let values = evaluator
7288 .evaluate_batch_local(¶ms, &indices)
7289 .expect("evaluation should succeed");
7290 let gradients = evaluator
7291 .evaluate_gradient_batch_local(¶ms, &indices)
7292 .expect("evaluation should succeed");
7293 let fused = evaluator
7294 .evaluate_with_gradient_batch_local(¶ms, &indices)
7295 .expect("evaluation should succeed");
7296 assert_eq!(fused.len(), values.len());
7297 assert_eq!(fused.len(), gradients.len());
7298 for ((value_gradient, value), gradient) in
7299 fused.iter().zip(values.iter()).zip(gradients.iter())
7300 {
7301 let (fused_value, fused_gradient) = value_gradient;
7302 assert_relative_eq!(fused_value.re, value.re, epsilon = 1e-12);
7303 assert_relative_eq!(fused_value.im, value.im, epsilon = 1e-12);
7304 assert_eq!(fused_gradient.len(), gradient.len());
7305 for (fused_item, item) in fused_gradient.iter().zip(gradient.iter()) {
7306 assert_relative_eq!(fused_item.re, item.re, epsilon = 1e-12);
7307 assert_relative_eq!(fused_item.im, item.im, epsilon = 1e-12);
7308 }
7309 }
7310 }
7311
7312 #[test]
7313 fn test_precompute_all_columnar_populates_cache() {
7314 let mut event1 = test_event();
7315 event1.p4s[0].t = 7.5;
7316 let mut event2 = test_event();
7317 event2.p4s[0].t = 8.25;
7318 let mut event3 = test_event();
7319 event3.p4s[0].t = 9.0;
7320 let dataset = Dataset::new_with_metadata(
7321 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
7322 Arc::new(DatasetMetadata::default()),
7323 );
7324 let mut amplitude = TestAmplitude {
7325 name: "test".to_string(),
7326 re: parameter!("real"),
7327 pid_re: ParameterID::default(),
7328 im: parameter!("imag"),
7329 pid_im: ParameterID::default(),
7330 beam_energy: Default::default(),
7331 };
7332 let mut resources = Resources::default();
7333 amplitude
7334 .register(&mut resources)
7335 .expect("test amplitude should register");
7336 resources.reserve_cache(dataset.n_events());
7337 amplitude.precompute_all(&dataset, &mut resources);
7338 for cache in &resources.caches {
7339 assert!(cache.get_scalar(amplitude.beam_energy) > 0.0);
7340 }
7341 }
7342
    #[cfg(feature = "mpi")]
    #[mpi_test(np = [2])]
    fn test_load_reserves_local_cache_size_in_mpi() {
        // Under MPI each rank holds only a partition of the dataset; the
        // evaluator's cache must be sized to the rank-local event count, not
        // the global one.
        use crate::mpi::{finalize_mpi, get_world, use_mpi};

        use_mpi(true);
        assert!(get_world().is_some(), "MPI world should be initialized");

        let expr = ComplexScalar::new(
            "constant",
            parameter!("const_re", 2.0),
            parameter!("const_im", 3.0),
        )
        .expect("constant amplitude should construct");
        // Four identical events split across two ranks (np = 2).
        let events = vec![
            Arc::new(test_event()),
            Arc::new(test_event()),
            Arc::new(test_event()),
            Arc::new(test_event()),
        ];
        let dataset = Arc::new(Dataset::new_with_metadata(
            events,
            Arc::new(DatasetMetadata::default()),
        ));
        let evaluator = expr.load(&dataset).expect("evaluator should load");
        let local_events = dataset.n_events_local();
        let cache_len = evaluator.resources.read().caches.len();

        assert_eq!(
            cache_len, local_events,
            "cache length must match local event count under MPI"
        );
        // Tear down MPI so subsequent tests in the process start clean.
        finalize_mpi();
    }
7377
7378 #[cfg(feature = "mpi")]
7379 #[mpi_test(np = [2])]
7380 fn test_expression_ir_cached_integrals_are_rank_local_in_mpi() {
7381 use crate::mpi::{finalize_mpi, get_world, use_mpi};
7382 use mpi::{collective::SystemOperation, topology::Communicator, traits::*};
7383
7384 use_mpi(true);
7385 let world = get_world().expect("MPI world should be initialized");
7386
7387 let expr = ParameterOnlyScalar::new("p", parameter!("p")).unwrap()
7388 * &CacheOnlyScalar::new("k").unwrap();
7389 let events = vec![
7390 Arc::new(EventData {
7391 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
7392 aux: vec![],
7393 weight: 0.5,
7394 }),
7395 Arc::new(EventData {
7396 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
7397 aux: vec![],
7398 weight: 1.0,
7399 }),
7400 Arc::new(EventData {
7401 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
7402 aux: vec![],
7403 weight: 1.5,
7404 }),
7405 Arc::new(EventData {
7406 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
7407 aux: vec![],
7408 weight: 2.0,
7409 }),
7410 Arc::new(EventData {
7411 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
7412 aux: vec![],
7413 weight: 2.5,
7414 }),
7415 Arc::new(EventData {
7416 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
7417 aux: vec![],
7418 weight: 3.0,
7419 }),
7420 Arc::new(EventData {
7421 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
7422 aux: vec![],
7423 weight: 3.5,
7424 }),
7425 Arc::new(EventData {
7426 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
7427 aux: vec![],
7428 weight: 4.0,
7429 }),
7430 ];
7431 let dataset = Arc::new(Dataset::new_with_metadata(
7432 events,
7433 Arc::new(DatasetMetadata::default()),
7434 ));
7435 let evaluator = expr.load(&dataset).expect("evaluator should load");
7436 let cached_integrals = evaluator
7437 .expression_precomputed_cached_integrals()
7438 .expect("integrals should exist");
7439 assert_eq!(cached_integrals.len(), 1);
7440
7441 let local_expected = dataset.events_local().iter().fold(0.0, |acc, event| {
7442 acc + event.weight() * event.data().p4s[0].e()
7443 });
7444 let cached_local = cached_integrals[0].weighted_cache_sum;
7445 assert_relative_eq!(cached_local.re, local_expected, epsilon = 1e-12);
7446 assert_relative_eq!(cached_local.im, 0.0, epsilon = 1e-12);
7447
7448 let weighted_value_sum = evaluator
7449 .evaluate_weighted_value_sum_local(&[2.0])
7450 .expect("evaluate should succeed");
7451 assert_relative_eq!(weighted_value_sum, 2.0 * local_expected, epsilon = 1e-10);
7452
7453 let mut global_expected = 0.0;
7454 world.all_reduce_into(
7455 &local_expected,
7456 &mut global_expected,
7457 SystemOperation::sum(),
7458 );
7459 if world.size() > 1 {
7460 assert!(
7461 (cached_local.re - global_expected).abs() > 1e-12,
7462 "cached integral should remain rank-local before MPI reduction"
7463 );
7464 }
7465 finalize_mpi();
7466 }
7467
7468 #[cfg(feature = "mpi")]
7469 #[mpi_test(np = [2])]
7470 fn test_expression_ir_weighted_sum_mpi_matches_global_eventwise_baseline() {
7471 use crate::mpi::{finalize_mpi, get_world, use_mpi};
7472 use mpi::{collective::SystemOperation, traits::*};
7473
7474 use_mpi(true);
7475 let world = get_world().expect("MPI world should be initialized");
7476
7477 let p1 = ParameterOnlyScalar::new("p1", parameter!("p1")).unwrap();
7478 let p2 = ParameterOnlyScalar::new("p2", parameter!("p2")).unwrap();
7479 let c1 = CacheOnlyScalar::new("c1").unwrap();
7480 let c2 = CacheOnlyScalar::new("c2").unwrap();
7481 let c3 = CacheOnlyScalar::new("c3").unwrap();
7482 let m1 = TestAmplitude::new("m1", parameter!("m1r"), parameter!("m1i")).unwrap();
7483 let expr = (&p1 * &c1) + &(&p2 * &c2) + &(&(&m1 * &p1) * &c3);
7484 let events = vec![
7485 Arc::new(EventData {
7486 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 1.0)],
7487 aux: vec![],
7488 weight: 0.5,
7489 }),
7490 Arc::new(EventData {
7491 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 2.0)],
7492 aux: vec![],
7493 weight: -1.25,
7494 }),
7495 Arc::new(EventData {
7496 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 3.0)],
7497 aux: vec![],
7498 weight: 0.75,
7499 }),
7500 Arc::new(EventData {
7501 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 4.0)],
7502 aux: vec![],
7503 weight: 1.5,
7504 }),
7505 Arc::new(EventData {
7506 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 5.0)],
7507 aux: vec![],
7508 weight: 2.25,
7509 }),
7510 Arc::new(EventData {
7511 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 6.0)],
7512 aux: vec![],
7513 weight: -0.5,
7514 }),
7515 Arc::new(EventData {
7516 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 7.0)],
7517 aux: vec![],
7518 weight: 3.5,
7519 }),
7520 Arc::new(EventData {
7521 p4s: vec![Vec4::new(0.0, 0.0, 0.0, 8.0)],
7522 aux: vec![],
7523 weight: 1.25,
7524 }),
7525 ];
7526 let dataset = Arc::new(Dataset::new_with_metadata(
7527 events,
7528 Arc::new(DatasetMetadata::default()),
7529 ));
7530 let evaluator = expr.load(&dataset).expect("evaluator should load");
7531 let params = vec![0.2, -0.3, 1.1, -0.7];
7532
7533 let local_expected_value = evaluator
7534 .evaluate_local(¶ms)
7535 .expect("evaluate should succeed")
7536 .iter()
7537 .zip(dataset.events_local().iter())
7538 .fold(0.0, |accum, (value, event)| {
7539 accum + event.weight() * value.re
7540 });
7541 let mut global_expected_value = 0.0;
7542 world.all_reduce_into(
7543 &local_expected_value,
7544 &mut global_expected_value,
7545 SystemOperation::sum(),
7546 );
7547 let mpi_value = evaluator
7548 .evaluate_weighted_value_sum_mpi(¶ms, &world)
7549 .expect("evaluate should succeed");
7550 assert_relative_eq!(mpi_value, global_expected_value, epsilon = 1e-10);
7551
7552 let local_expected_gradient = evaluator
7553 .evaluate_gradient_local(¶ms)
7554 .expect("evaluate should succeed")
7555 .iter()
7556 .zip(dataset.events_local().iter())
7557 .fold(
7558 DVector::zeros(params.len()),
7559 |mut accum, (gradient, event)| {
7560 accum += gradient.map(|value| value.re).scale(event.weight());
7561 accum
7562 },
7563 );
7564 let mut global_expected_gradient = vec![0.0; local_expected_gradient.len()];
7565 world.all_reduce_into(
7566 local_expected_gradient.as_slice(),
7567 &mut global_expected_gradient,
7568 SystemOperation::sum(),
7569 );
7570 let mpi_gradient = evaluator
7571 .evaluate_weighted_gradient_sum_mpi(¶ms, &world)
7572 .expect("evaluate should succeed");
7573 for (actual, expected) in mpi_gradient.iter().zip(global_expected_gradient.iter()) {
7574 assert_relative_eq!(*actual, *expected, epsilon = 1e-10);
7575 }
7576
7577 finalize_mpi();
7578 }
7579
7580 #[test]
7581 fn test_evaluate_local_succeeds_for_constant_amplitude() {
7582 let expr = ComplexScalar::new(
7583 "constant",
7584 parameter!("const_re", 2.0),
7585 parameter!("const_im", 3.0),
7586 )
7587 .unwrap();
7588 let dataset = Arc::new(Dataset::new_with_metadata(
7589 vec![Arc::new(test_event())],
7590 Arc::new(DatasetMetadata::default()),
7591 ));
7592 let evaluator = expr.load(&dataset).unwrap();
7593 let values = evaluator
7594 .evaluate_local(&[])
7595 .expect("evaluation should succeed");
7596 assert_eq!(values.len(), 1);
7597 let gradients = evaluator
7598 .evaluate_gradient_local(&[])
7599 .expect("evaluation should succeed");
7600 assert_eq!(gradients.len(), 1);
7601 }
7602
7603 #[test]
7604 fn test_constant_amplitude() {
7605 let expr = ComplexScalar::new(
7606 "constant",
7607 parameter!("const_re", 2.0),
7608 parameter!("const_im", 3.0),
7609 )
7610 .unwrap();
7611 let dataset = Arc::new(Dataset::new_with_metadata(
7612 vec![Arc::new(test_event())],
7613 Arc::new(DatasetMetadata::default()),
7614 ));
7615 let evaluator = expr.load(&dataset).unwrap();
7616 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7617 assert_eq!(result[0], Complex64::new(2.0, 3.0));
7618 }
7619
7620 #[test]
7621 fn test_parametric_amplitude() {
7622 let expr = ComplexScalar::new(
7623 "parametric",
7624 parameter!("test_param_re"),
7625 parameter!("test_param_im"),
7626 )
7627 .unwrap();
7628 let dataset = Arc::new(test_dataset());
7629 let evaluator = expr.load(&dataset).unwrap();
7630 let result = evaluator
7631 .evaluate(&[2.0, 3.0])
7632 .expect("evaluation should succeed");
7633 assert_eq!(result[0], Complex64::new(2.0, 3.0));
7634 }
7635
    #[test]
    fn test_expression_operations() {
        // Exercises every expression combinator on constant amplitudes:
        // expr1 = 2, expr2 = i, expr3 = 3+4i; each result is checked exactly.
        let expr1 = ComplexScalar::new(
            "const1",
            parameter!("const1_re", 2.0),
            parameter!("const1_im", 0.0),
        )
        .unwrap();
        let expr2 = ComplexScalar::new(
            "const2",
            parameter!("const2_re", 0.0),
            parameter!("const2_im", 1.0),
        )
        .unwrap();
        let expr3 = ComplexScalar::new(
            "const3",
            parameter!("const3_re", 3.0),
            parameter!("const3_im", 4.0),
        )
        .unwrap();

        let dataset = Arc::new(test_dataset());

        // Addition: 2 + i.
        let expr_add = &expr1 + &expr2;
        let result_add = expr_add
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_add[0], Complex64::new(2.0, 1.0));

        // Subtraction: 2 - i.
        let expr_sub = &expr1 - &expr2;
        let result_sub = expr_sub
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));

        // Multiplication: 2 * i = 2i.
        let expr_mul = &expr1 * &expr2;
        let result_mul = expr_mul
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));

        // Division: 2 / (3+4i) = (6-8i)/25.
        let expr_div = &expr1 / &expr3;
        let result_div = expr_div
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));

        // Negation: -(3+4i).
        let expr_neg = -&expr3;
        let result_neg = expr_neg
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));

        // Composite addition: (2+i) + 2i = 2+3i.
        let expr_add2 = &expr_add + &expr_mul;
        let result_add2 = expr_add2
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));

        // Composite subtraction: (2+i) - 2i = 2-i.
        let expr_sub2 = &expr_add - &expr_mul;
        let result_sub2 = expr_sub2
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));

        // Composite multiplication: (2+i) * 2i = -2+4i.
        let expr_mul2 = &expr_add * &expr_mul;
        let result_mul2 = expr_mul2
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));

        // Composite division: (2+i) / (2+3i) = (7-4i)/13.
        let expr_div2 = &expr_add / &expr_add2;
        let result_div2 = expr_div2
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));

        // Negation of a composite: -(-2+4i).
        let expr_neg2 = -&expr_mul2;
        let result_neg2 = expr_neg2
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));

        // Real part: Re(3+4i) = 3 (as a complex number with zero imag).
        let expr_real = expr3.real();
        let result_real = expr_real
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_real[0], Complex64::new(3.0, 0.0));

        // Real part of a composite: Re(-2+4i) = -2.
        let expr_mul2_real = expr_mul2.real();
        let result_mul2_real = expr_mul2_real
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));

        // Imaginary part: Im(3+4i) = 4.
        let expr_imag = expr3.imag();
        let result_imag = expr_imag
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));

        // Imaginary part of a composite: Im(-2+4i) = 4.
        let expr_mul2_imag = expr_mul2.imag();
        let result_mul2_imag = expr_mul2_imag
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));

        // Conjugate: conj(3+4i) = 3-4i.
        let expr_conj = expr3.conj();
        let result_conj = expr_conj
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));

        // Conjugate of a composite: conj(-2+4i) = -2-4i.
        let expr_mul2_conj = expr_mul2.conj();
        let result_mul2_conj = expr_mul2_conj
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));

        // Squared magnitude: |2|^2 = 4.
        let expr_norm = expr1.norm_sqr();
        let result_norm = expr_norm
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));

        // Squared magnitude of a composite: |-2+4i|^2 = 20.
        let expr_mul2_norm = expr_mul2.norm_sqr();
        let result_mul2_norm = expr_mul2_norm
            .load(&dataset)
            .unwrap()
            .evaluate(&[])
            .expect("evaluation should succeed");
        assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
    }
7821
7822 #[test]
7823 fn test_amplitude_activation() {
7824 let expr1 = ComplexScalar::new(
7825 "const1",
7826 parameter!("const1_re_act", 1.0),
7827 parameter!("const1_im_act", 0.0),
7828 )
7829 .unwrap();
7830 let expr2 = ComplexScalar::new(
7831 "const2",
7832 parameter!("const2_re_act", 2.0),
7833 parameter!("const2_im_act", 0.0),
7834 )
7835 .unwrap();
7836
7837 let dataset = Arc::new(test_dataset());
7838 let expr = &expr1 + &expr2;
7839 let evaluator = expr.load(&dataset).unwrap();
7840
7841 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7843 assert_eq!(result[0], Complex64::new(3.0, 0.0));
7844
7845 evaluator.deactivate_strict("const1").unwrap();
7847 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7848 assert_eq!(result[0], Complex64::new(2.0, 0.0));
7849
7850 evaluator.isolate_strict("const1").unwrap();
7852 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7853 assert_eq!(result[0], Complex64::new(1.0, 0.0));
7854
7855 evaluator.activate_all();
7857 let result = evaluator.evaluate(&[]).expect("evaluation should succeed");
7858 assert_eq!(result[0], Complex64::new(3.0, 0.0));
7859 }
7860
7861 #[test]
7862 fn test_gradient() {
7863 let expr1 = ComplexScalar::new(
7864 "parametric_1",
7865 parameter!("test_param_re_1"),
7866 parameter!("test_param_im_1"),
7867 )
7868 .unwrap();
7869 let expr2 = ComplexScalar::new(
7870 "parametric_2",
7871 parameter!("test_param_re_2"),
7872 parameter!("test_param_im_2"),
7873 )
7874 .unwrap();
7875
7876 let dataset = Arc::new(test_dataset());
7877 let params = vec![2.0, 3.0, 4.0, 5.0];
7878
7879 let expr = &expr1 + &expr2;
7880 let evaluator = expr.load(&dataset).unwrap();
7881
7882 let gradient = evaluator
7883 .evaluate_gradient(¶ms)
7884 .expect("evaluation should succeed");
7885
7886 assert_relative_eq!(gradient[0][0].re, 1.0);
7887 assert_relative_eq!(gradient[0][0].im, 0.0);
7888 assert_relative_eq!(gradient[0][1].re, 0.0);
7889 assert_relative_eq!(gradient[0][1].im, 1.0);
7890 assert_relative_eq!(gradient[0][2].re, 1.0);
7891 assert_relative_eq!(gradient[0][2].im, 0.0);
7892 assert_relative_eq!(gradient[0][3].re, 0.0);
7893 assert_relative_eq!(gradient[0][3].im, 1.0);
7894
7895 let expr = &expr1 - &expr2;
7896 let evaluator = expr.load(&dataset).unwrap();
7897
7898 let gradient = evaluator
7899 .evaluate_gradient(¶ms)
7900 .expect("evaluation should succeed");
7901
7902 assert_relative_eq!(gradient[0][0].re, 1.0);
7903 assert_relative_eq!(gradient[0][0].im, 0.0);
7904 assert_relative_eq!(gradient[0][1].re, 0.0);
7905 assert_relative_eq!(gradient[0][1].im, 1.0);
7906 assert_relative_eq!(gradient[0][2].re, -1.0);
7907 assert_relative_eq!(gradient[0][2].im, 0.0);
7908 assert_relative_eq!(gradient[0][3].re, 0.0);
7909 assert_relative_eq!(gradient[0][3].im, -1.0);
7910
7911 let expr = &expr1 * &expr2;
7912 let evaluator = expr.load(&dataset).unwrap();
7913
7914 let gradient = evaluator
7915 .evaluate_gradient(¶ms)
7916 .expect("evaluation should succeed");
7917
7918 assert_relative_eq!(gradient[0][0].re, 4.0);
7919 assert_relative_eq!(gradient[0][0].im, 5.0);
7920 assert_relative_eq!(gradient[0][1].re, -5.0);
7921 assert_relative_eq!(gradient[0][1].im, 4.0);
7922 assert_relative_eq!(gradient[0][2].re, 2.0);
7923 assert_relative_eq!(gradient[0][2].im, 3.0);
7924 assert_relative_eq!(gradient[0][3].re, -3.0);
7925 assert_relative_eq!(gradient[0][3].im, 2.0);
7926
7927 let expr = &expr1 / &expr2;
7928 let evaluator = expr.load(&dataset).unwrap();
7929
7930 let gradient = evaluator
7931 .evaluate_gradient(¶ms)
7932 .expect("evaluation should succeed");
7933
7934 assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
7935 assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
7936 assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
7937 assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
7938 assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
7939 assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
7940 assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
7941 assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
7942
7943 let expr = -(&expr1 * &expr2);
7944 let evaluator = expr.load(&dataset).unwrap();
7945
7946 let gradient = evaluator
7947 .evaluate_gradient(¶ms)
7948 .expect("evaluation should succeed");
7949
7950 assert_relative_eq!(gradient[0][0].re, -4.0);
7951 assert_relative_eq!(gradient[0][0].im, -5.0);
7952 assert_relative_eq!(gradient[0][1].re, 5.0);
7953 assert_relative_eq!(gradient[0][1].im, -4.0);
7954 assert_relative_eq!(gradient[0][2].re, -2.0);
7955 assert_relative_eq!(gradient[0][2].im, -3.0);
7956 assert_relative_eq!(gradient[0][3].re, 3.0);
7957 assert_relative_eq!(gradient[0][3].im, -2.0);
7958
7959 let expr = (&expr1 * &expr2).real();
7960 let evaluator = expr.load(&dataset).unwrap();
7961
7962 let gradient = evaluator
7963 .evaluate_gradient(¶ms)
7964 .expect("evaluation should succeed");
7965
7966 assert_relative_eq!(gradient[0][0].re, 4.0);
7967 assert_relative_eq!(gradient[0][0].im, 0.0);
7968 assert_relative_eq!(gradient[0][1].re, -5.0);
7969 assert_relative_eq!(gradient[0][1].im, 0.0);
7970 assert_relative_eq!(gradient[0][2].re, 2.0);
7971 assert_relative_eq!(gradient[0][2].im, 0.0);
7972 assert_relative_eq!(gradient[0][3].re, -3.0);
7973 assert_relative_eq!(gradient[0][3].im, 0.0);
7974
7975 let expr = (&expr1 * &expr2).imag();
7976 let evaluator = expr.load(&dataset).unwrap();
7977
7978 let gradient = evaluator
7979 .evaluate_gradient(¶ms)
7980 .expect("evaluation should succeed");
7981
7982 assert_relative_eq!(gradient[0][0].re, 5.0);
7983 assert_relative_eq!(gradient[0][0].im, 0.0);
7984 assert_relative_eq!(gradient[0][1].re, 4.0);
7985 assert_relative_eq!(gradient[0][1].im, 0.0);
7986 assert_relative_eq!(gradient[0][2].re, 3.0);
7987 assert_relative_eq!(gradient[0][2].im, 0.0);
7988 assert_relative_eq!(gradient[0][3].re, 2.0);
7989 assert_relative_eq!(gradient[0][3].im, 0.0);
7990
7991 let expr = (&expr1 * &expr2).conj();
7992 let evaluator = expr.load(&dataset).unwrap();
7993
7994 let gradient = evaluator
7995 .evaluate_gradient(¶ms)
7996 .expect("evaluation should succeed");
7997
7998 assert_relative_eq!(gradient[0][0].re, 4.0);
7999 assert_relative_eq!(gradient[0][0].im, -5.0);
8000 assert_relative_eq!(gradient[0][1].re, -5.0);
8001 assert_relative_eq!(gradient[0][1].im, -4.0);
8002 assert_relative_eq!(gradient[0][2].re, 2.0);
8003 assert_relative_eq!(gradient[0][2].im, -3.0);
8004 assert_relative_eq!(gradient[0][3].re, -3.0);
8005 assert_relative_eq!(gradient[0][3].im, -2.0);
8006
8007 let expr = (&expr1 * &expr2).norm_sqr();
8008 let evaluator = expr.load(&dataset).unwrap();
8009
8010 let gradient = evaluator
8011 .evaluate_gradient(¶ms)
8012 .expect("evaluation should succeed");
8013
8014 assert_relative_eq!(gradient[0][0].re, 164.0);
8015 assert_relative_eq!(gradient[0][0].im, 0.0);
8016 assert_relative_eq!(gradient[0][1].re, 246.0);
8017 assert_relative_eq!(gradient[0][1].im, 0.0);
8018 assert_relative_eq!(gradient[0][2].re, 104.0);
8019 assert_relative_eq!(gradient[0][2].im, 0.0);
8020 assert_relative_eq!(gradient[0][3].re, 130.0);
8021 assert_relative_eq!(gradient[0][3].im, 0.0);
8022 }
8023
8024 #[test]
8025 fn test_expression_function_gradients() {
8026 let expr1 = ComplexScalar::new(
8027 "function_parametric_1",
8028 parameter!("function_test_param_re_1"),
8029 parameter!("function_test_param_im_1"),
8030 )
8031 .unwrap();
8032 let expr2 = ComplexScalar::new(
8033 "function_parametric_2",
8034 parameter!("function_test_param_re_2"),
8035 parameter!("function_test_param_im_2"),
8036 )
8037 .unwrap();
8038
8039 let sin = expr1.sin();
8040 let cos = expr1.cos();
8041 let trig = &sin * &cos;
8042 let pow = expr1.pow(&expr2);
8043 let mut expr = expr1.sqrt();
8044 expr = &expr + &expr1.exp();
8045 expr = &expr + &expr1.powi(2);
8046 expr = &expr + &expr1.powf(1.7);
8047 expr = &expr + &trig;
8048 expr = &expr + &expr1.log();
8049 expr = &expr + &expr1.cis();
8050 expr = &expr + &pow;
8051
8052 let dataset = Arc::new(test_dataset());
8053 let evaluator = expr.load(&dataset).unwrap();
8054 let params = vec![2.0, 0.5, 1.2, -0.3];
8055 let gradient = evaluator
8056 .evaluate_gradient(¶ms)
8057 .expect("evaluation should succeed");
8058 let eps = 1e-6;
8059
8060 for param_index in 0..params.len() {
8061 let mut plus = params.clone();
8062 plus[param_index] += eps;
8063 let mut minus = params.clone();
8064 minus[param_index] -= eps;
8065 let finite_diff = (evaluator
8066 .evaluate(&plus)
8067 .expect("evaluation should succeed")[0]
8068 - evaluator
8069 .evaluate(&minus)
8070 .expect("evaluation should succeed")[0])
8071 / Complex64::new(2.0 * eps, 0.0);
8072
8073 assert_relative_eq!(
8074 gradient[0][param_index].re,
8075 finite_diff.re,
8076 epsilon = 1e-6,
8077 max_relative = 1e-6
8078 );
8079 assert_relative_eq!(
8080 gradient[0][param_index].im,
8081 finite_diff.im,
8082 epsilon = 1e-6,
8083 max_relative = 1e-6
8084 );
8085 }
8086 }
8087
8088 #[test]
8089 fn test_zeros_and_ones() {
8090 let amp = ComplexScalar::new(
8091 "parametric",
8092 parameter!("test_param_re"),
8093 parameter!("fixed_two", 2.0),
8094 )
8095 .unwrap();
8096 let dataset = Arc::new(test_dataset());
8097 let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
8098 let evaluator = expr.load(&dataset).unwrap();
8099
8100 let params = vec![2.0];
8101 let value = evaluator
8102 .evaluate(¶ms)
8103 .expect("evaluation should succeed");
8104 let gradient = evaluator
8105 .evaluate_gradient(¶ms)
8106 .expect("evaluation should succeed");
8107
8108 assert_relative_eq!(value[0].re, 8.0);
8110 assert_relative_eq!(value[0].im, 0.0);
8111
8112 assert_relative_eq!(gradient[0][0].re, 4.0);
8114 assert_relative_eq!(gradient[0][0].im, 0.0);
8115 }
8116 #[test]
8117 fn test_default_build_uses_lowered_expression_runtime() {
8118 let expr = ComplexScalar::new(
8119 "opt_in_gate",
8120 parameter!("opt_in_gate_re", 2.0),
8121 parameter!("opt_in_gate_im", 0.0),
8122 )
8123 .unwrap()
8124 .norm_sqr();
8125 let dataset = Arc::new(test_dataset());
8126 let evaluator = expr.load(&dataset).unwrap();
8127
8128 let diagnostics = evaluator.expression_runtime_diagnostics();
8129 assert!(diagnostics.ir_planning_enabled);
8130 assert!(diagnostics.lowered_value_program_present);
8131 assert!(diagnostics.lowered_gradient_program_present);
8132 assert!(diagnostics.lowered_value_gradient_program_present);
8133 assert_eq!(
8134 evaluator.evaluate(&[]).expect("evaluation should succeed")[0],
8135 Complex64::new(4.0, 0.0)
8136 );
8137 }
8138
8139 #[test]
8140 fn parameter_name_only_creates_free_parameter() {
8141 let p = parameter!("mass");
8142
8143 assert_eq!(p.name(), "mass");
8144 assert_eq!(p.fixed(), None);
8145 assert_eq!(p.initial(), None);
8146 assert_eq!(p.bounds(), (None, None));
8147 assert_eq!(p.unit(), None);
8148 assert_eq!(p.latex(), None);
8149 assert_eq!(p.description(), None);
8150 assert!(p.is_free());
8151 assert!(!p.is_fixed());
8152 }
8153
8154 #[test]
8155 fn parameter_name_and_value_creates_fixed_parameter() {
8156 let p = parameter!("width", 0.15);
8157
8158 assert_eq!(p.name(), "width");
8159 assert_eq!(p.fixed(), Some(0.15));
8160 assert_eq!(p.initial(), Some(0.15));
8161 assert!(p.is_fixed());
8162 assert!(!p.is_free());
8163 }
8164
8165 #[test]
8166 fn keyword_initial_sets_initial_only() {
8167 let p = parameter!("alpha", initial: 1.25);
8168
8169 assert_eq!(p.name(), "alpha");
8170 assert_eq!(p.fixed(), None);
8171 assert_eq!(p.initial(), Some(1.25));
8172 assert_eq!(p.bounds(), (None, None));
8173 assert!(p.is_free());
8174 }
8175
8176 #[test]
8177 fn keyword_fixed_sets_fixed_and_initial() {
8178 let p = parameter!("beta", fixed: 2.5);
8179
8180 assert_eq!(p.name(), "beta");
8181 assert_eq!(p.fixed(), Some(2.5));
8182 assert_eq!(p.initial(), Some(2.5));
8183 assert!(p.is_fixed());
8184 }
8185
8186 #[test]
8187 fn bounds_accept_plain_numbers() {
8188 let p = parameter!("x", bounds: (0.0, 10.0));
8189
8190 assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
8191 }
8192
8193 #[test]
8194 fn bounds_accept_none_and_number() {
8195 let p = parameter!("x", bounds: (None, 10.0));
8196
8197 assert_eq!(p.bounds(), (None, Some(10.0)));
8198 }
8199
8200 #[test]
8201 fn bounds_accept_number_and_none() {
8202 let p = parameter!("x", bounds: (-1.0, None));
8203
8204 assert_eq!(p.bounds(), (Some(-1.0), None));
8205 }
8206
8207 #[test]
8208 fn bounds_accept_both_none() {
8209 let p = parameter!("x", bounds: (None, None));
8210
8211 assert_eq!(p.bounds(), (None, None));
8212 }
8213
8214 #[test]
8215 fn bounds_accept_arbitrary_expressions() {
8216 let lo = 1.0;
8217 let hi = 2.0 * 3.0;
8218 let p = parameter!("x", bounds: (lo - 0.5, hi));
8219
8220 assert_eq!(p.bounds(), (Some(0.5), Some(6.0)));
8221 }
8222
8223 #[test]
8224 fn multiple_keyword_arguments_work_together() {
8225 let p = parameter!(
8226 "gamma",
8227 initial: 1.0,
8228 bounds: (0.0, 5.0),
8229 unit: "GeV",
8230 latex: r"\gamma",
8231 description: "test parameter",
8232 );
8233
8234 assert_eq!(p.name(), "gamma");
8235 assert_eq!(p.fixed(), None);
8236 assert_eq!(p.initial(), Some(1.0));
8237 assert_eq!(p.bounds(), (Some(0.0), Some(5.0)));
8238 assert_eq!(p.unit().as_deref(), Some("GeV"));
8239 assert_eq!(p.latex().as_deref(), Some(r"\gamma"));
8240 assert_eq!(p.description().as_deref(), Some("test parameter"));
8241 }
8242
8243 #[test]
8244 fn fixed_can_be_combined_with_other_fields() {
8245 let p = parameter!(
8246 "delta",
8247 fixed: 3.0,
8248 bounds: (0.0, 10.0),
8249 unit: "rad",
8250 );
8251
8252 assert_eq!(p.name(), "delta");
8253 assert_eq!(p.fixed(), Some(3.0));
8254 assert_eq!(p.initial(), Some(3.0));
8255 assert_eq!(p.bounds(), (Some(0.0), Some(10.0)));
8256 assert_eq!(p.unit().as_deref(), Some("rad"));
8257 }
8258
8259 #[test]
8260 fn trailing_comma_is_accepted() {
8261 let p = parameter!(
8262 "eps",
8263 initial: 0.5,
8264 bounds: (None, 1.0),
8265 unit: "arb",
8266 );
8267
8268 assert_eq!(p.initial(), Some(0.5));
8269 assert_eq!(p.bounds(), (None, Some(1.0)));
8270 assert_eq!(p.unit().as_deref(), Some("arb"));
8271 }
8272
8273 #[test]
8274 fn test_parameter_registration() {
8275 let expr = ComplexScalar::new(
8276 "parametric",
8277 parameter!("test_param_re"),
8278 parameter!("fixed_two", 2.0),
8279 )
8280 .unwrap();
8281 let parameters = expr.free_parameters();
8282 assert_eq!(parameters.len(), 1);
8283 assert_eq!(parameters[0], "test_param_re");
8284 }
8285
8286 #[test]
8287 #[should_panic(expected = "refers to different underlying amplitudes")]
8288 fn test_duplicate_amplitude_registration() {
8289 let amp1 = ComplexScalar::new(
8290 "same_name",
8291 parameter!("dup_re1", 1.0),
8292 parameter!("dup_im1", 0.0),
8293 )
8294 .unwrap();
8295 let amp2 = ComplexScalar::new(
8296 "same_name",
8297 parameter!("dup_re2", 2.0),
8298 parameter!("dup_im2", 0.0),
8299 )
8300 .unwrap();
8301 let _expr = amp1 + amp2;
8302 }
8303
8304 #[test]
8305 fn test_tree_printing() {
8306 let amp1 = ComplexScalar::new(
8307 "parametric_1",
8308 parameter!("test_param_re_1"),
8309 parameter!("test_param_im_1"),
8310 )
8311 .unwrap();
8312 let amp2 = ComplexScalar::new(
8313 "parametric_2",
8314 parameter!("test_param_re_2"),
8315 parameter!("test_param_im_2"),
8316 )
8317 .unwrap();
8318 let expr =
8319 &1.real() + &2.conj().imag() + Expression::one() * Complex64::new(-1.4, 2.0)
8320 - Expression::zero() / 1.0
8321 + (&1 * &2).norm_sqr();
8322 assert_eq!(
8323 expr.to_string(),
8324 "+
8325├─ -
8326│ ├─ +
8327│ │ ├─ +
8328│ │ │ ├─ Re
8329│ │ │ │ └─ parametric_1(id=0)
8330│ │ │ └─ Im
8331│ │ │ └─ *
8332│ │ │ └─ parametric_2(id=1)
8333│ │ └─ ×
8334│ │ ├─ 1 (exact)
8335│ │ └─ -1.4+2i
8336│ └─ ÷
8337│ ├─ 0 (exact)
8338│ └─ 1 (exact)
8339└─ NormSqr
8340 └─ ×
8341 ├─ parametric_1(id=0)
8342 └─ parametric_2(id=1)
8343"
8344 );
8345 }
8346}