1use std::{
2 fmt::{Debug, Display},
3 sync::{
4 atomic::{AtomicU64, Ordering},
5 Arc,
6 },
7};
8
9use auto_ops::*;
10use dyn_clone::DynClone;
11use nalgebra::{ComplexField, DVector};
12use num::complex::Complex64;
13
14use parking_lot::RwLock;
15#[cfg(feature = "rayon")]
16use rayon::prelude::*;
17use serde::{Deserialize, Serialize};
18
/// Process-wide counter from which amplitude instance identifiers are drawn.
static AMPLITUDE_INSTANCE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns the next amplitude instance identifier.
///
/// The value returned is the counter's value *before* the increment, so the
/// first call in a process yields `0`.
fn next_amplitude_id() -> u64 {
    // Only uniqueness of the returned value matters here, so a relaxed
    // read-modify-write is used.
    const STEP: u64 = 1;
    AMPLITUDE_INSTANCE_COUNTER.fetch_add(STEP, Ordering::Relaxed)
}
24
25use crate::{
26 data::{Dataset, DatasetMetadata, EventData},
27 resources::{Cache, ParameterTransform, Parameters, Resources},
28 LadduError, LadduResult, ParameterID, ReadWrite,
29};
30
31#[cfg(feature = "mpi")]
32use crate::mpi::LadduMPI;
33
34#[cfg(feature = "mpi")]
35use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
36
/// A named model parameter which is either free or fixed to a constant value.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct Parameter {
    /// The name of the parameter.
    pub name: String,
    /// `Some(value)` when the parameter is fixed to `value`, `None` when it is free.
    pub fixed: Option<f64>,
}
45
46impl Parameter {
47 pub fn free(name: impl Into<String>) -> Self {
49 Self {
50 name: name.into(),
51 fixed: None,
52 }
53 }
54
55 pub fn fixed(name: impl Into<String>, value: f64) -> Self {
57 Self {
58 name: name.into(),
59 fixed: Some(value),
60 }
61 }
62
63 pub fn uninit() -> Self {
65 Self {
66 name: String::new(),
67 fixed: None,
68 }
69 }
70
71 pub fn is_free(&self) -> bool {
73 self.fixed.is_none()
74 }
75
76 pub fn is_fixed(&self) -> bool {
78 self.fixed.is_some()
79 }
80
81 pub fn name(&self) -> &str {
83 &self.name
84 }
85}
86
/// Alias for [`Parameter`].
pub type ParameterLike = Parameter;
89
90pub fn parameter(name: &str) -> Parameter {
92 Parameter::free(name)
93}
94
95pub fn constant(name: &str, value: f64) -> Parameter {
97 Parameter::fixed(name, value)
98}
99
/// Convenience macro for building a `Parameter`.
///
/// * `parameter!(name)` expands to `Parameter::free(name)`.
/// * `parameter!(name, value)` expands to `Parameter::fixed(name, value)`.
///
/// The expansion refers to the type through the path `$crate::amplitudes::Parameter`.
#[macro_export]
macro_rules! parameter {
    ($name:expr) => {
        $crate::amplitudes::Parameter::free($name)
    };
    ($name:expr, $value:expr) => {
        $crate::amplitudes::Parameter::fixed($name, $value)
    };
}
111
112#[typetag::serde(tag = "type")]
120pub trait Amplitude: DynClone + Send + Sync {
121 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID>;
130 fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
136 Ok(())
137 }
138 #[allow(unused_variables)]
143 fn precompute(&self, event: &EventData, cache: &mut Cache) {}
144 #[cfg(feature = "rayon")]
146 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
147 dataset
148 .events
149 .par_iter()
150 .zip(resources.caches.par_iter_mut())
151 .for_each(|(event, cache)| {
152 self.precompute(event, cache);
153 })
154 }
155 #[cfg(not(feature = "rayon"))]
157 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
158 dataset
159 .events
160 .iter()
161 .zip(resources.caches.iter_mut())
162 .for_each(|(event, cache)| self.precompute(event, cache))
163 }
164 fn compute(&self, parameters: &Parameters, event: &EventData, cache: &Cache) -> Complex64;
173
174 fn compute_gradient(
190 &self,
191 parameters: &Parameters,
192 event: &EventData,
193 cache: &Cache,
194 gradient: &mut DVector<Complex64>,
195 ) {
196 self.central_difference_with_indices(
197 &Vec::from_iter(0..parameters.len()),
198 parameters,
199 event,
200 cache,
201 gradient,
202 )
203 }
204
205 fn central_difference_with_indices(
211 &self,
212 indices: &[usize],
213 parameters: &Parameters,
214 event: &EventData,
215 cache: &Cache,
216 gradient: &mut DVector<Complex64>,
217 ) {
218 let x = parameters.parameters.to_owned();
219 let constants = parameters.constants.to_owned();
220 let h: DVector<f64> = x
221 .iter()
222 .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
223 .collect::<Vec<_>>()
224 .into();
225 for i in indices {
226 let mut x_plus = x.clone();
227 let mut x_minus = x.clone();
228 x_plus[*i] += h[*i];
229 x_minus[*i] -= h[*i];
230 let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
231 let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
232 gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
233 }
234 }
235
236 fn into_expression(self) -> LadduResult<Expression>
241 where
242 Self: Sized + 'static,
243 {
244 Expression::from_amplitude(Box::new(self))
245 }
246}
247
248pub fn central_difference<F: Fn(&[f64]) -> f64>(parameters: &[f64], func: F) -> DVector<f64> {
250 let mut gradient = DVector::zeros(parameters.len());
251 let x = parameters.to_owned();
252 let h: DVector<f64> = x
253 .iter()
254 .map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
255 .collect::<Vec<_>>()
256 .into();
257 for i in 0..parameters.len() {
258 let mut x_plus = x.clone();
259 let mut x_minus = x.clone();
260 x_plus[i] += h[i];
261 x_minus[i] -= h[i];
262 let f_plus = func(&x_plus);
263 let f_minus = func(&x_minus);
264 gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
265 }
266 gradient
267}
268
// Lets boxed `dyn Amplitude` trait objects be cloned.
dyn_clone::clone_trait_object!(Amplitude);
270
/// Newtype around a vector of complex values.
// NOTE(review): presumably one entry per registered amplitude — confirm against users.
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex64>);
274
/// A `usize` paired with a vector of complex gradient vectors.
// NOTE(review): the `usize` is presumably the number of parameters (the length of each
// gradient vector) — confirm against users.
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex64>>);
278
/// Identifier returned when an amplitude is registered: a name (field `0`) and an index
/// (field `1`).
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
283
284impl Display for AmplitudeID {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{}(id={})", self.0, self.1)
287 }
288}
289
/// A mathematical expression over amplitudes: an expression tree together with the registry
/// of amplitudes that the tree's `Amp` nodes index into.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Expression {
    // Amplitudes (with their names, instance ids and resources) used by `tree`.
    registry: ExpressionRegistry,
    // Root node of the expression tree.
    tree: ExpressionNode,
}
297
298impl ReadWrite for Expression {
299 fn create_null() -> Self {
300 Self {
301 registry: ExpressionRegistry::default(),
302 tree: ExpressionNode::default(),
303 }
304 }
305}
306
/// Bookkeeping for the amplitudes used by an [`Expression`].
///
/// `amplitudes`, `amplitude_names` and `amplitude_ids` are parallel vectors: entry `i` of each
/// describes the same amplitude.
#[derive(Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub struct ExpressionRegistry {
    // The registered amplitudes.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Name of each amplitude.
    amplitude_names: Vec<String>,
    // Instance identifier of each amplitude, drawn from the global counter when the
    // amplitude was first wrapped in an expression.
    amplitude_ids: Vec<u64>,
    // Resources the amplitudes were registered with.
    resources: Resources,
}
316
317impl ExpressionRegistry {
318 fn singleton(mut amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
319 let mut resources = Resources::default();
320 let aid = amplitude.register(&mut resources)?;
321 let amp_id = next_amplitude_id();
322 Ok(Self {
323 amplitudes: vec![amplitude],
324 amplitude_names: vec![aid.0],
325 amplitude_ids: vec![amp_id],
326 resources,
327 })
328 }
329
330 fn merge(&self, other: &Self) -> LadduResult<(Self, Vec<usize>, Vec<usize>)> {
331 let mut resources = Resources::default();
332 let mut amplitudes = Vec::new();
333 let mut amplitude_names = Vec::new();
334 let mut amplitude_ids = Vec::new();
335 let mut name_to_index = std::collections::HashMap::new();
336
337 let mut left_map = Vec::with_capacity(self.amplitudes.len());
338 for ((amp, name), amp_id) in self
339 .amplitudes
340 .iter()
341 .zip(&self.amplitude_names)
342 .zip(&self.amplitude_ids)
343 {
344 let mut cloned_amp = dyn_clone::clone_box(&**amp);
345 let aid = cloned_amp.register(&mut resources)?;
346 amplitudes.push(cloned_amp);
347 amplitude_names.push(name.clone());
348 amplitude_ids.push(*amp_id);
349 name_to_index.insert(name.clone(), aid.1);
350 left_map.push(aid.1);
351 }
352
353 let mut right_map = Vec::with_capacity(other.amplitudes.len());
354 for ((amp, name), amp_id) in other
355 .amplitudes
356 .iter()
357 .zip(&other.amplitude_names)
358 .zip(&other.amplitude_ids)
359 {
360 if let Some(existing) = name_to_index.get(name) {
361 let existing_amp_id = amplitude_ids[*existing];
362 if existing_amp_id != *amp_id {
363 return Err(LadduError::Custom(format!(
364 "Amplitude name \"{name}\" refers to different underlying amplitudes; rename to avoid conflicts"
365 )));
366 }
367 right_map.push(*existing);
368 continue;
369 }
370 let mut cloned_amp = dyn_clone::clone_box(&**amp);
371 let aid = cloned_amp.register(&mut resources)?;
372 amplitudes.push(cloned_amp);
373 amplitude_names.push(name.clone());
374 amplitude_ids.push(*amp_id);
375 name_to_index.insert(name.clone(), aid.1);
376 right_map.push(aid.1);
377 }
378
379 Ok((
380 Self {
381 amplitudes,
382 amplitude_names,
383 amplitude_ids,
384 resources,
385 },
386 left_map,
387 right_map,
388 ))
389 }
390
391 fn rebuild_with_transform(&self, transform: ParameterTransform) -> LadduResult<Self> {
392 let mut resources = Resources::with_transform(transform);
393 let mut amplitudes = Vec::new();
394 let mut amplitude_names = Vec::new();
395 let mut amplitude_ids = Vec::new();
396 for ((amp, name), amp_id) in self
397 .amplitudes
398 .iter()
399 .zip(&self.amplitude_names)
400 .zip(&self.amplitude_ids)
401 {
402 let mut cloned_amp = dyn_clone::clone_box(&**amp);
403 let aid = cloned_amp.register(&mut resources)?;
404 if aid.0 != *name {
405 return Err(LadduError::ParameterConflict {
406 name: aid.0,
407 reason: "amplitude renamed during rebuild".to_string(),
408 });
409 }
410 amplitudes.push(cloned_amp);
411 amplitude_names.push(name.clone());
412 amplitude_ids.push(*amp_id);
413 }
414 Ok(Self {
415 amplitudes,
416 amplitude_names,
417 amplitude_ids,
418 resources,
419 })
420 }
421}
422
/// A node of an expression tree over amplitude values.
#[allow(missing_docs)]
#[derive(Clone, Serialize, Deserialize, Default, Debug)]
pub enum ExpressionNode {
    /// The constant zero (the default node).
    #[default]
    Zero,
    /// The constant one.
    One,
    /// The value of the amplitude at the given index.
    Amp(usize),
    /// Sum of two sub-expressions.
    Add(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Difference of two sub-expressions (first minus second).
    Sub(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Product of two sub-expressions.
    Mul(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Quotient of two sub-expressions (first divided by second).
    Div(Box<ExpressionNode>, Box<ExpressionNode>),
    /// Negation of a sub-expression.
    Neg(Box<ExpressionNode>),
    /// Real part of a sub-expression.
    Real(Box<ExpressionNode>),
    /// Imaginary part of a sub-expression (as a real number).
    Imag(Box<ExpressionNode>),
    /// Complex conjugate of a sub-expression.
    Conj(Box<ExpressionNode>),
    /// Squared norm of a sub-expression.
    NormSqr(Box<ExpressionNode>),
}
453
454impl ExpressionNode {
455 fn remap(&self, mapping: &[usize]) -> Self {
456 match self {
457 Self::Amp(idx) => Self::Amp(mapping[*idx]),
458 Self::Add(a, b) => Self::Add(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
459 Self::Sub(a, b) => Self::Sub(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
460 Self::Mul(a, b) => Self::Mul(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
461 Self::Div(a, b) => Self::Div(Box::new(a.remap(mapping)), Box::new(b.remap(mapping))),
462 Self::Neg(a) => Self::Neg(Box::new(a.remap(mapping))),
463 Self::Real(a) => Self::Real(Box::new(a.remap(mapping))),
464 Self::Imag(a) => Self::Imag(Box::new(a.remap(mapping))),
465 Self::Conj(a) => Self::Conj(Box::new(a.remap(mapping))),
466 Self::NormSqr(a) => Self::NormSqr(Box::new(a.remap(mapping))),
467 Self::Zero => Self::Zero,
468 Self::One => Self::One,
469 }
470 }
471
472 pub fn evaluate(&self, amplitude_values: &[Complex64]) -> Complex64 {
477 match self {
478 ExpressionNode::Amp(idx) => amplitude_values[*idx],
479 ExpressionNode::Add(a, b) => {
480 a.evaluate(amplitude_values) + b.evaluate(amplitude_values)
481 }
482 ExpressionNode::Sub(a, b) => {
483 a.evaluate(amplitude_values) - b.evaluate(amplitude_values)
484 }
485 ExpressionNode::Mul(a, b) => {
486 a.evaluate(amplitude_values) * b.evaluate(amplitude_values)
487 }
488 ExpressionNode::Div(a, b) => {
489 a.evaluate(amplitude_values) / b.evaluate(amplitude_values)
490 }
491 ExpressionNode::Neg(a) => -a.evaluate(amplitude_values),
492 ExpressionNode::Real(a) => Complex64::new(a.evaluate(amplitude_values).re, 0.0),
493 ExpressionNode::Imag(a) => Complex64::new(a.evaluate(amplitude_values).im, 0.0),
494 ExpressionNode::Conj(a) => a.evaluate(amplitude_values).conj(),
495 ExpressionNode::NormSqr(a) => {
496 let value = a.evaluate(amplitude_values);
497 Complex64::new(value.norm_sqr(), 0.0)
498 }
499 ExpressionNode::Zero => Complex64::ZERO,
500 ExpressionNode::One => Complex64::ONE,
501 }
502 }
503
504 pub fn evaluate_gradient(
509 &self,
510 amplitude_values: &[Complex64],
511 gradient_values: &[DVector<Complex64>],
512 ) -> DVector<Complex64> {
513 match self {
514 ExpressionNode::Amp(idx) => gradient_values[*idx].clone(),
515 ExpressionNode::Add(a, b) => {
516 a.evaluate_gradient(amplitude_values, gradient_values)
517 + b.evaluate_gradient(amplitude_values, gradient_values)
518 }
519 ExpressionNode::Sub(a, b) => {
520 a.evaluate_gradient(amplitude_values, gradient_values)
521 - b.evaluate_gradient(amplitude_values, gradient_values)
522 }
523 ExpressionNode::Mul(a, b) => {
524 let f_a = a.evaluate(amplitude_values);
525 let f_b = b.evaluate(amplitude_values);
526 b.evaluate_gradient(amplitude_values, gradient_values)
527 .map(|g| g * f_a)
528 + a.evaluate_gradient(amplitude_values, gradient_values)
529 .map(|g| g * f_b)
530 }
531 ExpressionNode::Div(a, b) => {
532 let f_a = a.evaluate(amplitude_values);
533 let f_b = b.evaluate(amplitude_values);
534 (a.evaluate_gradient(amplitude_values, gradient_values)
535 .map(|g| g * f_b)
536 - b.evaluate_gradient(amplitude_values, gradient_values)
537 .map(|g| g * f_a))
538 / (f_b * f_b)
539 }
540 ExpressionNode::Neg(a) => -a.evaluate_gradient(amplitude_values, gradient_values),
541 ExpressionNode::Real(a) => a
542 .evaluate_gradient(amplitude_values, gradient_values)
543 .map(|g| Complex64::new(g.re, 0.0)),
544 ExpressionNode::Imag(a) => a
545 .evaluate_gradient(amplitude_values, gradient_values)
546 .map(|g| Complex64::new(g.im, 0.0)),
547 ExpressionNode::Conj(a) => a
548 .evaluate_gradient(amplitude_values, gradient_values)
549 .map(|g| g.conj()),
550 ExpressionNode::NormSqr(a) => {
551 let conj_f_a = a.evaluate(amplitude_values).conjugate();
552 a.evaluate_gradient(amplitude_values, gradient_values)
553 .map(|g| Complex64::new(2.0 * (g * conj_f_a).re, 0.0))
554 }
555 ExpressionNode::Zero | ExpressionNode::One => {
556 let max_dim = gradient_values.first().map(|g| g.len()).unwrap_or(0);
557 DVector::zeros(max_dim)
558 }
559 }
560 }
561}
562
563impl Expression {
564 pub fn from_amplitude(amplitude: Box<dyn Amplitude>) -> LadduResult<Self> {
566 let registry = ExpressionRegistry::singleton(amplitude)?;
567 Ok(Self {
568 tree: ExpressionNode::Amp(0),
569 registry,
570 })
571 }
572
573 pub fn zero() -> Self {
575 Self {
576 registry: ExpressionRegistry::default(),
577 tree: ExpressionNode::Zero,
578 }
579 }
580
581 pub fn one() -> Self {
583 Self {
584 registry: ExpressionRegistry::default(),
585 tree: ExpressionNode::One,
586 }
587 }
588
589 fn binary_op(
590 a: &Expression,
591 b: &Expression,
592 build: impl Fn(Box<ExpressionNode>, Box<ExpressionNode>) -> ExpressionNode,
593 ) -> Expression {
594 let (registry, left_map, right_map) = a
595 .registry
596 .merge(&b.registry)
597 .expect("merging expression registries should not fail");
598 let left_tree = a.tree.remap(&left_map);
599 let right_tree = b.tree.remap(&right_map);
600 Expression {
601 registry,
602 tree: build(Box::new(left_tree), Box::new(right_tree)),
603 }
604 }
605
606 fn unary_op(a: &Expression, build: impl Fn(Box<ExpressionNode>) -> ExpressionNode) -> Self {
607 Expression {
608 registry: a.registry.clone(),
609 tree: build(Box::new(a.tree.clone())),
610 }
611 }
612
613 pub fn parameters(&self) -> Vec<String> {
615 self.registry.resources.parameter_names()
616 }
617
618 pub fn free_parameters(&self) -> Vec<String> {
620 self.registry.resources.free_parameter_names()
621 }
622
623 pub fn fixed_parameters(&self) -> Vec<String> {
625 self.registry.resources.fixed_parameter_names()
626 }
627
628 pub fn n_free(&self) -> usize {
630 self.registry.resources.n_free_parameters()
631 }
632
633 pub fn n_fixed(&self) -> usize {
635 self.registry.resources.n_fixed_parameters()
636 }
637
638 pub fn n_parameters(&self) -> usize {
640 self.registry.resources.n_parameters()
641 }
642
643 fn with_transform(&self, transform: ParameterTransform) -> LadduResult<Self> {
644 let merged = self
645 .registry
646 .resources
647 .parameter_overrides
648 .merged(&transform);
649 let registry = self.registry.rebuild_with_transform(merged)?;
650 Ok(Self {
651 registry,
652 tree: self.tree.clone(),
653 })
654 }
655
656 fn assert_parameter_exists(&self, name: &str) -> LadduResult<()> {
657 if self.parameters().iter().any(|p| p == name) {
658 Ok(())
659 } else {
660 Err(LadduError::UnregisteredParameter {
661 name: name.to_string(),
662 reason: "parameter not found".to_string(),
663 })
664 }
665 }
666
667 pub fn fix(&self, name: &str, value: f64) -> LadduResult<Self> {
669 self.assert_parameter_exists(name)?;
670 let mut transform = ParameterTransform::default();
671 transform.fixed.insert(name.to_string(), value);
672 self.with_transform(transform)
673 }
674
675 pub fn free(&self, name: &str) -> LadduResult<Self> {
677 self.assert_parameter_exists(name)?;
678 let mut transform = ParameterTransform::default();
679 transform.freed.insert(name.to_string());
680 self.with_transform(transform)
681 }
682
683 pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<Self> {
685 self.assert_parameter_exists(old)?;
686 if old == new {
687 return Ok(self.clone());
688 }
689 if self.parameters().iter().any(|p| p == new) {
690 return Err(LadduError::ParameterConflict {
691 name: new.to_string(),
692 reason: "rename target already exists".to_string(),
693 });
694 }
695 let mut transform = ParameterTransform::default();
696 transform.renames.insert(old.to_string(), new.to_string());
697 self.with_transform(transform)
698 }
699
700 pub fn rename_parameters(
702 &self,
703 mapping: &std::collections::HashMap<String, String>,
704 ) -> LadduResult<Self> {
705 for old in mapping.keys() {
706 self.assert_parameter_exists(old)?;
707 }
708 let mut final_names: std::collections::HashSet<String> =
709 self.parameters().into_iter().collect();
710 for (old, new) in mapping {
711 if old == new {
712 continue;
713 }
714 final_names.remove(old);
715 if final_names.contains(new) {
716 return Err(LadduError::ParameterConflict {
717 name: new.clone(),
718 reason: "rename target already exists".to_string(),
719 });
720 }
721 final_names.insert(new.clone());
722 }
723 let mut transform = ParameterTransform::default();
724 for (old, new) in mapping {
725 transform.renames.insert(old.clone(), new.clone());
726 }
727 self.with_transform(transform)
728 }
729
730 pub fn load(&self, dataset: &Arc<Dataset>) -> LadduResult<Evaluator> {
732 let mut resources = self.registry.resources.clone();
733 let metadata = dataset.metadata();
734 resources.reserve_cache(dataset.n_events());
735 let mut amplitudes: Vec<Box<dyn Amplitude>> = self
736 .registry
737 .amplitudes
738 .iter()
739 .map(|amp| dyn_clone::clone_box(&**amp))
740 .collect();
741 {
742 for amplitude in amplitudes.iter_mut() {
743 amplitude.bind(metadata)?;
744 amplitude.precompute_all(dataset, &mut resources);
745 }
746 }
747 Ok(Evaluator {
748 amplitudes,
749 resources: Arc::new(RwLock::new(resources)),
750 dataset: dataset.clone(),
751 expression: self.tree.clone(),
752 registry: self.registry.clone(),
753 })
754 }
755
756 pub fn real(&self) -> Self {
758 Self::unary_op(self, ExpressionNode::Real)
759 }
760 pub fn imag(&self) -> Self {
762 Self::unary_op(self, ExpressionNode::Imag)
763 }
764 pub fn conj(&self) -> Self {
766 Self::unary_op(self, ExpressionNode::Conj)
767 }
768 pub fn norm_sqr(&self) -> Self {
770 Self::unary_op(self, ExpressionNode::NormSqr)
771 }
772
773 fn write_tree(
775 &self,
776 t: &ExpressionNode,
777 f: &mut std::fmt::Formatter<'_>,
778 parent_prefix: &str,
779 immediate_prefix: &str,
780 parent_suffix: &str,
781 ) -> std::fmt::Result {
782 let display_string = match t {
783 ExpressionNode::Amp(idx) => {
784 let name = self
785 .registry
786 .amplitude_names
787 .get(*idx)
788 .cloned()
789 .unwrap_or_else(|| "<unregistered>".to_string());
790 format!("{name}(id={idx})")
791 }
792 ExpressionNode::Add(_, _) => "+".to_string(),
793 ExpressionNode::Sub(_, _) => "-".to_string(),
794 ExpressionNode::Mul(_, _) => "×".to_string(),
795 ExpressionNode::Div(_, _) => "÷".to_string(),
796 ExpressionNode::Neg(_) => "-".to_string(),
797 ExpressionNode::Real(_) => "Re".to_string(),
798 ExpressionNode::Imag(_) => "Im".to_string(),
799 ExpressionNode::Conj(_) => "*".to_string(),
800 ExpressionNode::NormSqr(_) => "NormSqr".to_string(),
801 ExpressionNode::Zero => "0".to_string(),
802 ExpressionNode::One => "1".to_string(),
803 };
804 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
805 match t {
806 ExpressionNode::Amp(_) | ExpressionNode::Zero | ExpressionNode::One => {}
807 ExpressionNode::Add(a, b)
808 | ExpressionNode::Sub(a, b)
809 | ExpressionNode::Mul(a, b)
810 | ExpressionNode::Div(a, b) => {
811 let terms = [a, b];
812 let mut it = terms.iter().peekable();
813 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
814 while let Some(child) = it.next() {
815 match it.peek() {
816 Some(_) => self.write_tree(child, f, &child_prefix, "├─ ", "│ "),
817 None => self.write_tree(child, f, &child_prefix, "└─ ", " "),
818 }?;
819 }
820 }
821 ExpressionNode::Neg(a)
822 | ExpressionNode::Real(a)
823 | ExpressionNode::Imag(a)
824 | ExpressionNode::Conj(a)
825 | ExpressionNode::NormSqr(a) => {
826 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
827 self.write_tree(a, f, &child_prefix, "└─ ", " ")?;
828 }
829 }
830 Ok(())
831 }
832}
833
834impl Debug for Expression {
835 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 self.write_tree(&self.tree, f, "", "", "")
837 }
838}
839
840impl Display for Expression {
841 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
842 self.write_tree(&self.tree, f, "", "", "")
843 }
844}
845
// Arithmetic operators for `Expression`, generated with `auto_ops::impl_op_ex!`.
//
// The binary operators go through `Expression::binary_op`, which panics when the two
// registries cannot be merged.

// `a + b`
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Add)
});
// `a - b`
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Sub)
});
// `a * b`
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Mul)
});
// `a / b`
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
    Expression::binary_op(a, b, ExpressionNode::Div)
});
// `-a`
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression| -> Expression {
    Expression::unary_op(a, ExpressionNode::Neg)
});
866
/// An expression bound to a dataset, produced by `Expression::load`.
#[allow(missing_docs)]
#[derive(Clone)]
pub struct Evaluator {
    /// The amplitudes, cloned from the expression's registry and bound to the dataset.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Shared, lockable resources (constants, caches, activation flags, parameter names).
    pub resources: Arc<RwLock<Resources>>,
    /// The dataset the expression is evaluated over.
    pub dataset: Arc<Dataset>,
    /// The expression tree that combines the amplitude values.
    pub expression: ExpressionNode,
    // Registry of the originating expression; used to rebuild an `Expression` when
    // parameters are fixed, freed or renamed.
    registry: ExpressionRegistry,
}
877
878#[allow(missing_docs)]
879impl Evaluator {
880 pub fn parameters(&self) -> Vec<String> {
883 self.resources.read().parameter_names()
884 }
885
886 pub fn free_parameters(&self) -> Vec<String> {
888 self.resources.read().free_parameter_names()
889 }
890
891 pub fn fixed_parameters(&self) -> Vec<String> {
893 self.resources.read().fixed_parameter_names()
894 }
895
896 pub fn n_free(&self) -> usize {
898 self.resources.read().n_free_parameters()
899 }
900
901 pub fn n_fixed(&self) -> usize {
903 self.resources.read().n_fixed_parameters()
904 }
905
906 pub fn n_parameters(&self) -> usize {
908 self.resources.read().n_parameters()
909 }
910
911 fn as_expression(&self) -> Expression {
912 Expression {
913 registry: self.registry.clone(),
914 tree: self.expression.clone(),
915 }
916 }
917
918 pub fn fix(&self, name: &str, value: f64) -> LadduResult<Self> {
920 self.as_expression().fix(name, value)?.load(&self.dataset)
921 }
922
923 pub fn free(&self, name: &str) -> LadduResult<Self> {
925 self.as_expression().free(name)?.load(&self.dataset)
926 }
927
928 pub fn rename_parameter(&self, old: &str, new: &str) -> LadduResult<Self> {
930 self.as_expression()
931 .rename_parameter(old, new)?
932 .load(&self.dataset)
933 }
934
935 pub fn rename_parameters(
937 &self,
938 mapping: &std::collections::HashMap<String, String>,
939 ) -> LadduResult<Self> {
940 self.as_expression()
941 .rename_parameters(mapping)?
942 .load(&self.dataset)
943 }
944
945 pub fn activate<T: AsRef<str>>(&self, name: T) {
947 self.resources.write().activate(name);
948 }
949 pub fn activate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
951 self.resources.write().activate_strict(name)
952 }
953
954 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) {
956 self.resources.write().activate_many(names);
957 }
958 pub fn activate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
960 self.resources.write().activate_many_strict(names)
961 }
962
963 pub fn activate_all(&self) {
965 self.resources.write().activate_all();
966 }
967
968 pub fn deactivate<T: AsRef<str>>(&self, name: T) {
970 self.resources.write().deactivate(name);
971 }
972
973 pub fn deactivate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
975 self.resources.write().deactivate_strict(name)
976 }
977
978 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) {
980 self.resources.write().deactivate_many(names);
981 }
982 pub fn deactivate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
984 self.resources.write().deactivate_many_strict(names)
985 }
986
987 pub fn deactivate_all(&self) {
989 self.resources.write().deactivate_all();
990 }
991
992 pub fn isolate<T: AsRef<str>>(&self, name: T) {
994 self.resources.write().isolate(name);
995 }
996
997 pub fn isolate_strict<T: AsRef<str>>(&self, name: T) -> LadduResult<()> {
999 self.resources.write().isolate_strict(name)
1000 }
1001
1002 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) {
1004 self.resources.write().isolate_many(names);
1005 }
1006
1007 pub fn isolate_many_strict<T: AsRef<str>>(&self, names: &[T]) -> LadduResult<()> {
1009 self.resources.write().isolate_many_strict(names)
1010 }
1011
1012 pub fn evaluate_local(&self, parameters: &[f64]) -> Vec<Complex64> {
1020 let resources = self.resources.read();
1021 let parameters = Parameters::new(parameters, &resources.constants);
1022 #[cfg(feature = "rayon")]
1023 {
1024 self.dataset
1025 .events
1026 .par_iter()
1027 .zip(resources.caches.par_iter())
1028 .map(|(event, cache)| {
1029 let amplitude_values: Vec<Complex64> = self
1030 .amplitudes
1031 .iter()
1032 .zip(resources.active.iter())
1033 .map(|(amp, active)| {
1034 if *active {
1035 amp.compute(¶meters, event, cache)
1036 } else {
1037 Complex64::ZERO
1038 }
1039 })
1040 .collect();
1041 self.expression.evaluate(&litude_values)
1042 })
1043 .collect()
1044 }
1045 #[cfg(not(feature = "rayon"))]
1046 {
1047 self.dataset
1048 .events
1049 .iter()
1050 .zip(resources.caches.iter())
1051 .map(|(event, cache)| {
1052 let amplitude_values: Vec<Complex64> = self
1053 .amplitudes
1054 .iter()
1055 .zip(resources.active.iter())
1056 .map(|(amp, active)| {
1057 if *active {
1058 amp.compute(¶meters, event, cache)
1059 } else {
1060 Complex64::ZERO
1061 }
1062 })
1063 .collect();
1064 self.expression.evaluate(&litude_values)
1065 })
1066 .collect()
1067 }
1068 }
1069
1070 #[cfg(feature = "mpi")]
1078 fn evaluate_mpi(&self, parameters: &[f64], world: &SimpleCommunicator) -> Vec<Complex64> {
1079 let local_evaluation = self.evaluate_local(parameters);
1080 let n_events = self.dataset.n_events();
1081 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events];
1082 let (counts, displs) = world.get_counts_displs(n_events);
1083 {
1084 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1085 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
1086 }
1087 buffer
1088 }
1089
1090 pub fn evaluate(&self, parameters: &[f64]) -> Vec<Complex64> {
1093 #[cfg(feature = "mpi")]
1094 {
1095 if let Some(world) = crate::mpi::get_world() {
1096 return self.evaluate_mpi(parameters, &world);
1097 }
1098 }
1099 self.evaluate_local(parameters)
1100 }
1101
1102 pub fn evaluate_batch_local(&self, parameters: &[f64], indices: &[usize]) -> Vec<Complex64> {
1105 let resources = self.resources.read();
1106 let parameters = Parameters::new(parameters, &resources.constants);
1107 #[cfg(feature = "rayon")]
1108 {
1109 indices
1110 .par_iter()
1111 .map(|&idx| {
1112 let event = &self.dataset.events[idx];
1113 let cache = &resources.caches[idx];
1114 let amplitude_values: Vec<Complex64> = self
1115 .amplitudes
1116 .iter()
1117 .zip(resources.active.iter())
1118 .map(|(amp, active)| {
1119 if *active {
1120 amp.compute(¶meters, event, cache)
1121 } else {
1122 Complex64::ZERO
1123 }
1124 })
1125 .collect();
1126 self.expression.evaluate(&litude_values)
1127 })
1128 .collect()
1129 }
1130 #[cfg(not(feature = "rayon"))]
1131 {
1132 indices
1133 .iter()
1134 .map(|&idx| {
1135 let event = &self.dataset.events[idx];
1136 let cache = &resources.caches[idx];
1137 let amplitude_values: Vec<Complex64> = self
1138 .amplitudes
1139 .iter()
1140 .zip(resources.active.iter())
1141 .map(|(amp, active)| {
1142 if *active {
1143 amp.compute(¶meters, event, cache)
1144 } else {
1145 Complex64::ZERO
1146 }
1147 })
1148 .collect();
1149 self.expression.evaluate(&litude_values)
1150 })
1151 .collect()
1152 }
1153 }
1154
1155 #[cfg(feature = "mpi")]
1158 fn evaluate_batch_mpi(
1159 &self,
1160 parameters: &[f64],
1161 indices: &[usize],
1162 world: &SimpleCommunicator,
1163 ) -> Vec<Complex64> {
1164 let total = self.dataset.n_events();
1165 let locals = world.locals_from_globals(indices, total);
1166 let local_evaluation = self.evaluate_batch_local(parameters, &locals);
1167 world.all_gather_batched_partitioned(&local_evaluation, indices, total, None)
1168 }
1169
1170 pub fn evaluate_batch(&self, parameters: &[f64], indices: &[usize]) -> Vec<Complex64> {
1173 #[cfg(feature = "mpi")]
1174 {
1175 if let Some(world) = crate::mpi::get_world() {
1176 return self.evaluate_batch_mpi(parameters, indices, &world);
1177 }
1178 }
1179 self.evaluate_batch_local(parameters, indices)
1180 }
1181
1182 pub fn evaluate_gradient_local(&self, parameters: &[f64]) -> Vec<DVector<Complex64>> {
1190 let resources = self.resources.read();
1191 let parameters = Parameters::new(parameters, &resources.constants);
1192 #[cfg(feature = "rayon")]
1193 {
1194 self.dataset
1195 .events
1196 .par_iter()
1197 .zip(resources.caches.par_iter())
1198 .map(|(event, cache)| {
1199 let mut gradient_values =
1200 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1201 self.amplitudes
1202 .iter()
1203 .zip(resources.active.iter())
1204 .zip(gradient_values.iter_mut())
1205 .for_each(|((amp, active), grad)| {
1206 if *active {
1207 amp.compute_gradient(¶meters, event, cache, grad)
1208 }
1209 });
1210 let amplitude_values: Vec<Complex64> = self
1211 .amplitudes
1212 .iter()
1213 .zip(resources.active.iter())
1214 .map(|(amp, active)| {
1215 if *active {
1216 amp.compute(¶meters, event, cache)
1217 } else {
1218 Complex64::ZERO
1219 }
1220 })
1221 .collect();
1222 self.expression
1223 .evaluate_gradient(&litude_values, &gradient_values)
1224 })
1225 .collect()
1226 }
1227 #[cfg(not(feature = "rayon"))]
1228 {
1229 self.dataset
1230 .events
1231 .iter()
1232 .zip(resources.caches.iter())
1233 .map(|(event, cache)| {
1234 let mut gradient_values =
1235 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1236 self.amplitudes
1237 .iter()
1238 .zip(resources.active.iter())
1239 .zip(gradient_values.iter_mut())
1240 .for_each(|((amp, active), grad)| {
1241 if *active {
1242 amp.compute_gradient(¶meters, event, cache, grad)
1243 }
1244 });
1245 let amplitude_values: Vec<Complex64> = self
1246 .amplitudes
1247 .iter()
1248 .zip(resources.active.iter())
1249 .map(|(amp, active)| {
1250 if *active {
1251 amp.compute(¶meters, event, cache)
1252 } else {
1253 Complex64::ZERO
1254 }
1255 })
1256 .collect();
1257
1258 self.expression
1259 .evaluate_gradient(&litude_values, &gradient_values)
1260 })
1261 .collect()
1262 }
1263 }
1264
1265 #[cfg(feature = "mpi")]
1273 fn evaluate_gradient_mpi(
1274 &self,
1275 parameters: &[f64],
1276 world: &SimpleCommunicator,
1277 ) -> Vec<DVector<Complex64>> {
1278 let local_evaluation = self.evaluate_gradient_local(parameters);
1279 let n_events = self.dataset.n_events();
1280 let mut buffer: Vec<Complex64> = vec![Complex64::ZERO; n_events * parameters.len()];
1281 let (counts, displs) = world.get_counts_displs(n_events);
1282 {
1283 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
1284 world.all_gather_varcount_into(
1285 &local_evaluation
1286 .iter()
1287 .flat_map(|v| v.data.as_vec())
1288 .copied()
1289 .collect::<Vec<_>>(),
1290 &mut partitioned_buffer,
1291 );
1292 }
1293 buffer
1294 .chunks(parameters.len())
1295 .map(|chunk| DVector::from_row_slice(chunk))
1296 .collect()
1297 }
1298
1299 pub fn evaluate_gradient(&self, parameters: &[f64]) -> Vec<DVector<Complex64>> {
1302 #[cfg(feature = "mpi")]
1303 {
1304 if let Some(world) = crate::mpi::get_world() {
1305 return self.evaluate_gradient_mpi(parameters, &world);
1306 }
1307 }
1308 self.evaluate_gradient_local(parameters)
1309 }
1310
1311 pub fn evaluate_gradient_batch_local(
1314 &self,
1315 parameters: &[f64],
1316 indices: &[usize],
1317 ) -> Vec<DVector<Complex64>> {
1318 let resources = self.resources.read();
1319 let parameters = Parameters::new(parameters, &resources.constants);
1320 #[cfg(feature = "rayon")]
1321 {
1322 indices
1323 .par_iter()
1324 .map(|&idx| {
1325 let event = &self.dataset.events[idx];
1326 let cache = &resources.caches[idx];
1327 let mut gradient_values =
1328 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1329 self.amplitudes
1330 .iter()
1331 .zip(resources.active.iter())
1332 .zip(gradient_values.iter_mut())
1333 .for_each(|((amp, active), grad)| {
1334 if *active {
1335 amp.compute_gradient(¶meters, event, cache, grad)
1336 }
1337 });
1338 let amplitude_values: Vec<Complex64> = self
1339 .amplitudes
1340 .iter()
1341 .zip(resources.active.iter())
1342 .map(|(amp, active)| {
1343 if *active {
1344 amp.compute(¶meters, event, cache)
1345 } else {
1346 Complex64::ZERO
1347 }
1348 })
1349 .collect();
1350 self.expression
1351 .evaluate_gradient(&litude_values, &gradient_values)
1352 })
1353 .collect()
1354 }
1355 #[cfg(not(feature = "rayon"))]
1356 {
1357 indices
1358 .iter()
1359 .map(|&idx| {
1360 let event = &self.dataset.events[idx];
1361 let cache = &resources.caches[idx];
1362 let mut gradient_values =
1363 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1364 self.amplitudes
1365 .iter()
1366 .zip(resources.active.iter())
1367 .zip(gradient_values.iter_mut())
1368 .for_each(|((amp, active), grad)| {
1369 if *active {
1370 amp.compute_gradient(¶meters, event, cache, grad)
1371 }
1372 });
1373 let amplitude_values: Vec<Complex64> = self
1374 .amplitudes
1375 .iter()
1376 .zip(resources.active.iter())
1377 .map(|(amp, active)| {
1378 if *active {
1379 amp.compute(¶meters, event, cache)
1380 } else {
1381 Complex64::ZERO
1382 }
1383 })
1384 .collect();
1385
1386 self.expression
1387 .evaluate_gradient(&litude_values, &gradient_values)
1388 })
1389 .collect()
1390 }
1391 }
1392
1393 #[cfg(feature = "mpi")]
1396 fn evaluate_gradient_batch_mpi(
1397 &self,
1398 parameters: &[f64],
1399 indices: &[usize],
1400 world: &SimpleCommunicator,
1401 ) -> Vec<DVector<Complex64>> {
1402 let total = self.dataset.n_events();
1403 let locals = world.locals_from_globals(indices, total);
1404 let flattened_local_evaluation = self
1405 .evaluate_gradient_batch_local(parameters, &locals)
1406 .iter()
1407 .flat_map(|g| g.data.as_vec().to_vec())
1408 .collect::<Vec<Complex64>>();
1409 world
1410 .all_gather_batched_partitioned(
1411 &flattened_local_evaluation,
1412 indices,
1413 total,
1414 Some(parameters.len()),
1415 )
1416 .chunks(parameters.len())
1417 .map(DVector::from_row_slice)
1418 .collect()
1419 }
1420
1421 pub fn evaluate_gradient_batch(
1425 &self,
1426 parameters: &[f64],
1427 indices: &[usize],
1428 ) -> Vec<DVector<Complex64>> {
1429 #[cfg(feature = "mpi")]
1430 {
1431 if let Some(world) = crate::mpi::get_world() {
1432 return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
1433 }
1434 }
1435 self.evaluate_gradient_batch_local(parameters, indices)
1436 }
1437}
1438
/// A simple amplitude intended for tests: a complex coefficient `re + i * im` multiplied by
/// the value `e()` of the event's first four-momentum (`p4s[0]`).
#[derive(Clone, Serialize, Deserialize)]
pub struct TestAmplitude {
    /// Name under which the amplitude is registered.
    name: String,
    /// Parameter supplying the real part of the coefficient.
    re: ParameterLike,
    /// ID assigned to `re` during registration (default-initialized until then).
    pid_re: ParameterID,
    /// Parameter supplying the imaginary part of the coefficient.
    im: ParameterLike,
    /// ID assigned to `im` during registration (default-initialized until then).
    pid_im: ParameterID,
}
1448
1449impl TestAmplitude {
1450 #[allow(clippy::new_ret_no_self)]
1452 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> LadduResult<Expression> {
1453 Self {
1454 name: name.to_string(),
1455 re,
1456 pid_re: Default::default(),
1457 im,
1458 pid_im: Default::default(),
1459 }
1460 .into_expression()
1461 }
1462}
1463
1464#[typetag::serde]
1465impl Amplitude for TestAmplitude {
1466 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
1467 self.pid_re = resources.register_parameter(&self.re)?;
1468 self.pid_im = resources.register_parameter(&self.im)?;
1469 resources.register_amplitude(&self.name)
1470 }
1471
1472 fn compute(&self, parameters: &Parameters, event: &EventData, _cache: &Cache) -> Complex64 {
1473 Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im)) * event.p4s[0].e()
1474 }
1475
1476 fn compute_gradient(
1477 &self,
1478 _parameters: &Parameters,
1479 event: &EventData,
1480 _cache: &Cache,
1481 gradient: &mut DVector<Complex64>,
1482 ) {
1483 if let ParameterID::Parameter(ind) = self.pid_re {
1484 gradient[ind] = Complex64::ONE * event.p4s[0].e();
1485 }
1486 if let ParameterID::Parameter(ind) = self.pid_im {
1487 gradient[ind] = Complex64::I * event.p4s[0].e();
1488 }
1489 }
1490}
1491
1492#[cfg(test)]
1493mod tests {
1494 use crate::data::{test_dataset, test_event, DatasetMetadata};
1495
1496 use super::*;
1497 use crate::{
1498 data::EventData,
1499 resources::{Cache, ParameterID, Parameters, Resources},
1500 };
1501 use approx::assert_relative_eq;
1502 use serde::{Deserialize, Serialize};
1503
    /// A test-only amplitude whose value is the complex number `re + i * im`, independent of
    /// the event.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        /// Name under which the amplitude is registered.
        name: String,
        /// Parameter supplying the real part.
        re: ParameterLike,
        /// ID assigned to `re` during registration (default-initialized until then).
        pid_re: ParameterID,
        /// Parameter supplying the imaginary part.
        im: ParameterLike,
        /// ID assigned to `im` during registration (default-initialized until then).
        pid_im: ParameterID,
    }
1512
1513 impl ComplexScalar {
1514 #[allow(clippy::new_ret_no_self)]
1515 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> LadduResult<Expression> {
1516 Self {
1517 name: name.to_string(),
1518 re,
1519 pid_re: Default::default(),
1520 im,
1521 pid_im: Default::default(),
1522 }
1523 .into_expression()
1524 }
1525 }
1526
1527 #[typetag::serde]
1528 impl Amplitude for ComplexScalar {
1529 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
1530 self.pid_re = resources.register_parameter(&self.re)?;
1531 self.pid_im = resources.register_parameter(&self.im)?;
1532 resources.register_amplitude(&self.name)
1533 }
1534
1535 fn compute(
1536 &self,
1537 parameters: &Parameters,
1538 _event: &EventData,
1539 _cache: &Cache,
1540 ) -> Complex64 {
1541 Complex64::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
1542 }
1543
1544 fn compute_gradient(
1545 &self,
1546 _parameters: &Parameters,
1547 _event: &EventData,
1548 _cache: &Cache,
1549 gradient: &mut DVector<Complex64>,
1550 ) {
1551 if let ParameterID::Parameter(ind) = self.pid_re {
1552 gradient[ind] = Complex64::ONE;
1553 }
1554 if let ParameterID::Parameter(ind) = self.pid_im {
1555 gradient[ind] = Complex64::I;
1556 }
1557 }
1558 }
1559
1560 #[test]
1561 fn test_batch_evaluation() {
1562 let expr = TestAmplitude::new("test", parameter("real"), parameter("imag")).unwrap();
1563 let mut event1 = test_event();
1564 event1.p4s[0].t = 10.0;
1565 let mut event2 = test_event();
1566 event2.p4s[0].t = 11.0;
1567 let mut event3 = test_event();
1568 event3.p4s[0].t = 12.0;
1569 let dataset = Arc::new(Dataset::new_with_metadata(
1570 vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
1571 Arc::new(DatasetMetadata::default()),
1572 ));
1573 let evaluator = expr.load(&dataset).unwrap();
1574 let result = evaluator.evaluate_batch(&[1.1, 2.2], &[0, 2]);
1575 assert_eq!(result.len(), 2);
1576 assert_eq!(result[0], Complex64::new(1.1, 2.2) * 10.0);
1577 assert_eq!(result[1], Complex64::new(1.1, 2.2) * 12.0);
1578 let result_grad = evaluator.evaluate_gradient_batch(&[1.1, 2.2], &[0, 2]);
1579 assert_eq!(result_grad.len(), 2);
1580 assert_eq!(result_grad[0][0], Complex64::new(10.0, 0.0));
1581 assert_eq!(result_grad[0][1], Complex64::new(0.0, 10.0));
1582 assert_eq!(result_grad[1][0], Complex64::new(12.0, 0.0));
1583 assert_eq!(result_grad[1][1], Complex64::new(0.0, 12.0));
1584 }
1585
1586 #[test]
1587 fn test_constant_amplitude() {
1588 let expr = ComplexScalar::new(
1589 "constant",
1590 constant("const_re", 2.0),
1591 constant("const_im", 3.0),
1592 )
1593 .unwrap();
1594 let dataset = Arc::new(Dataset::new_with_metadata(
1595 vec![Arc::new(test_event())],
1596 Arc::new(DatasetMetadata::default()),
1597 ));
1598 let evaluator = expr.load(&dataset).unwrap();
1599 let result = evaluator.evaluate(&[]);
1600 assert_eq!(result[0], Complex64::new(2.0, 3.0));
1601 }
1602
1603 #[test]
1604 fn test_parametric_amplitude() {
1605 let expr = ComplexScalar::new(
1606 "parametric",
1607 parameter("test_param_re"),
1608 parameter("test_param_im"),
1609 )
1610 .unwrap();
1611 let dataset = Arc::new(test_dataset());
1612 let evaluator = expr.load(&dataset).unwrap();
1613 let result = evaluator.evaluate(&[2.0, 3.0]);
1614 assert_eq!(result[0], Complex64::new(2.0, 3.0));
1615 }
1616
1617 #[test]
1618 fn test_expression_operations() {
1619 let expr1 = ComplexScalar::new(
1620 "const1",
1621 constant("const1_re", 2.0),
1622 constant("const1_im", 0.0),
1623 )
1624 .unwrap();
1625 let expr2 = ComplexScalar::new(
1626 "const2",
1627 constant("const2_re", 0.0),
1628 constant("const2_im", 1.0),
1629 )
1630 .unwrap();
1631 let expr3 = ComplexScalar::new(
1632 "const3",
1633 constant("const3_re", 3.0),
1634 constant("const3_im", 4.0),
1635 )
1636 .unwrap();
1637
1638 let dataset = Arc::new(test_dataset());
1639
1640 let expr_add = &expr1 + &expr2;
1642 let result_add = expr_add.load(&dataset).unwrap().evaluate(&[]);
1643 assert_eq!(result_add[0], Complex64::new(2.0, 1.0));
1644
1645 let expr_sub = &expr1 - &expr2;
1647 let result_sub = expr_sub.load(&dataset).unwrap().evaluate(&[]);
1648 assert_eq!(result_sub[0], Complex64::new(2.0, -1.0));
1649
1650 let expr_mul = &expr1 * &expr2;
1652 let result_mul = expr_mul.load(&dataset).unwrap().evaluate(&[]);
1653 assert_eq!(result_mul[0], Complex64::new(0.0, 2.0));
1654
1655 let expr_div = &expr1 / &expr3;
1657 let result_div = expr_div.load(&dataset).unwrap().evaluate(&[]);
1658 assert_eq!(result_div[0], Complex64::new(6.0 / 25.0, -8.0 / 25.0));
1659
1660 let expr_neg = -&expr3;
1662 let result_neg = expr_neg.load(&dataset).unwrap().evaluate(&[]);
1663 assert_eq!(result_neg[0], Complex64::new(-3.0, -4.0));
1664
1665 let expr_add2 = &expr_add + &expr_mul;
1667 let result_add2 = expr_add2.load(&dataset).unwrap().evaluate(&[]);
1668 assert_eq!(result_add2[0], Complex64::new(2.0, 3.0));
1669
1670 let expr_sub2 = &expr_add - &expr_mul;
1672 let result_sub2 = expr_sub2.load(&dataset).unwrap().evaluate(&[]);
1673 assert_eq!(result_sub2[0], Complex64::new(2.0, -1.0));
1674
1675 let expr_mul2 = &expr_add * &expr_mul;
1677 let result_mul2 = expr_mul2.load(&dataset).unwrap().evaluate(&[]);
1678 assert_eq!(result_mul2[0], Complex64::new(-2.0, 4.0));
1679
1680 let expr_div2 = &expr_add / &expr_add2;
1682 let result_div2 = expr_div2.load(&dataset).unwrap().evaluate(&[]);
1683 assert_eq!(result_div2[0], Complex64::new(7.0 / 13.0, -4.0 / 13.0));
1684
1685 let expr_neg2 = -&expr_mul2;
1687 let result_neg2 = expr_neg2.load(&dataset).unwrap().evaluate(&[]);
1688 assert_eq!(result_neg2[0], Complex64::new(2.0, -4.0));
1689
1690 let expr_real = expr3.real();
1692 let result_real = expr_real.load(&dataset).unwrap().evaluate(&[]);
1693 assert_eq!(result_real[0], Complex64::new(3.0, 0.0));
1694
1695 let expr_mul2_real = expr_mul2.real();
1697 let result_mul2_real = expr_mul2_real.load(&dataset).unwrap().evaluate(&[]);
1698 assert_eq!(result_mul2_real[0], Complex64::new(-2.0, 0.0));
1699
1700 let expr_imag = expr3.imag();
1702 let result_imag = expr_imag.load(&dataset).unwrap().evaluate(&[]);
1703 assert_eq!(result_imag[0], Complex64::new(4.0, 0.0));
1704
1705 let expr_mul2_imag = expr_mul2.imag();
1707 let result_mul2_imag = expr_mul2_imag.load(&dataset).unwrap().evaluate(&[]);
1708 assert_eq!(result_mul2_imag[0], Complex64::new(4.0, 0.0));
1709
1710 let expr_conj = expr3.conj();
1712 let result_conj = expr_conj.load(&dataset).unwrap().evaluate(&[]);
1713 assert_eq!(result_conj[0], Complex64::new(3.0, -4.0));
1714
1715 let expr_mul2_conj = expr_mul2.conj();
1717 let result_mul2_conj = expr_mul2_conj.load(&dataset).unwrap().evaluate(&[]);
1718 assert_eq!(result_mul2_conj[0], Complex64::new(-2.0, -4.0));
1719
1720 let expr_norm = expr1.norm_sqr();
1722 let result_norm = expr_norm.load(&dataset).unwrap().evaluate(&[]);
1723 assert_eq!(result_norm[0], Complex64::new(4.0, 0.0));
1724
1725 let expr_mul2_norm = expr_mul2.norm_sqr();
1727 let result_mul2_norm = expr_mul2_norm.load(&dataset).unwrap().evaluate(&[]);
1728 assert_eq!(result_mul2_norm[0], Complex64::new(20.0, 0.0));
1729 }
1730
1731 #[test]
1732 fn test_amplitude_activation() {
1733 let expr1 = ComplexScalar::new(
1734 "const1",
1735 constant("const1_re_act", 1.0),
1736 constant("const1_im_act", 0.0),
1737 )
1738 .unwrap();
1739 let expr2 = ComplexScalar::new(
1740 "const2",
1741 constant("const2_re_act", 2.0),
1742 constant("const2_im_act", 0.0),
1743 )
1744 .unwrap();
1745
1746 let dataset = Arc::new(test_dataset());
1747 let expr = &expr1 + &expr2;
1748 let evaluator = expr.load(&dataset).unwrap();
1749
1750 let result = evaluator.evaluate(&[]);
1752 assert_eq!(result[0], Complex64::new(3.0, 0.0));
1753
1754 evaluator.deactivate_strict("const1").unwrap();
1756 let result = evaluator.evaluate(&[]);
1757 assert_eq!(result[0], Complex64::new(2.0, 0.0));
1758
1759 evaluator.isolate_strict("const1").unwrap();
1761 let result = evaluator.evaluate(&[]);
1762 assert_eq!(result[0], Complex64::new(1.0, 0.0));
1763
1764 evaluator.activate_all();
1766 let result = evaluator.evaluate(&[]);
1767 assert_eq!(result[0], Complex64::new(3.0, 0.0));
1768 }
1769
1770 #[test]
1771 fn test_gradient() {
1772 let expr1 = ComplexScalar::new(
1773 "parametric_1",
1774 parameter("test_param_re_1"),
1775 parameter("test_param_im_1"),
1776 )
1777 .unwrap();
1778 let expr2 = ComplexScalar::new(
1779 "parametric_2",
1780 parameter("test_param_re_2"),
1781 parameter("test_param_im_2"),
1782 )
1783 .unwrap();
1784
1785 let dataset = Arc::new(test_dataset());
1786 let params = vec![2.0, 3.0, 4.0, 5.0];
1787
1788 let expr = &expr1 + &expr2;
1789 let evaluator = expr.load(&dataset).unwrap();
1790
1791 let gradient = evaluator.evaluate_gradient(¶ms);
1792
1793 assert_relative_eq!(gradient[0][0].re, 1.0);
1794 assert_relative_eq!(gradient[0][0].im, 0.0);
1795 assert_relative_eq!(gradient[0][1].re, 0.0);
1796 assert_relative_eq!(gradient[0][1].im, 1.0);
1797 assert_relative_eq!(gradient[0][2].re, 1.0);
1798 assert_relative_eq!(gradient[0][2].im, 0.0);
1799 assert_relative_eq!(gradient[0][3].re, 0.0);
1800 assert_relative_eq!(gradient[0][3].im, 1.0);
1801
1802 let expr = &expr1 - &expr2;
1803 let evaluator = expr.load(&dataset).unwrap();
1804
1805 let gradient = evaluator.evaluate_gradient(¶ms);
1806
1807 assert_relative_eq!(gradient[0][0].re, 1.0);
1808 assert_relative_eq!(gradient[0][0].im, 0.0);
1809 assert_relative_eq!(gradient[0][1].re, 0.0);
1810 assert_relative_eq!(gradient[0][1].im, 1.0);
1811 assert_relative_eq!(gradient[0][2].re, -1.0);
1812 assert_relative_eq!(gradient[0][2].im, 0.0);
1813 assert_relative_eq!(gradient[0][3].re, 0.0);
1814 assert_relative_eq!(gradient[0][3].im, -1.0);
1815
1816 let expr = &expr1 * &expr2;
1817 let evaluator = expr.load(&dataset).unwrap();
1818
1819 let gradient = evaluator.evaluate_gradient(¶ms);
1820
1821 assert_relative_eq!(gradient[0][0].re, 4.0);
1822 assert_relative_eq!(gradient[0][0].im, 5.0);
1823 assert_relative_eq!(gradient[0][1].re, -5.0);
1824 assert_relative_eq!(gradient[0][1].im, 4.0);
1825 assert_relative_eq!(gradient[0][2].re, 2.0);
1826 assert_relative_eq!(gradient[0][2].im, 3.0);
1827 assert_relative_eq!(gradient[0][3].re, -3.0);
1828 assert_relative_eq!(gradient[0][3].im, 2.0);
1829
1830 let expr = &expr1 / &expr2;
1831 let evaluator = expr.load(&dataset).unwrap();
1832
1833 let gradient = evaluator.evaluate_gradient(¶ms);
1834
1835 assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
1836 assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
1837 assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
1838 assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
1839 assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
1840 assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
1841 assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
1842 assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
1843
1844 let expr = -(&expr1 * &expr2);
1845 let evaluator = expr.load(&dataset).unwrap();
1846
1847 let gradient = evaluator.evaluate_gradient(¶ms);
1848
1849 assert_relative_eq!(gradient[0][0].re, -4.0);
1850 assert_relative_eq!(gradient[0][0].im, -5.0);
1851 assert_relative_eq!(gradient[0][1].re, 5.0);
1852 assert_relative_eq!(gradient[0][1].im, -4.0);
1853 assert_relative_eq!(gradient[0][2].re, -2.0);
1854 assert_relative_eq!(gradient[0][2].im, -3.0);
1855 assert_relative_eq!(gradient[0][3].re, 3.0);
1856 assert_relative_eq!(gradient[0][3].im, -2.0);
1857
1858 let expr = (&expr1 * &expr2).real();
1859 let evaluator = expr.load(&dataset).unwrap();
1860
1861 let gradient = evaluator.evaluate_gradient(¶ms);
1862
1863 assert_relative_eq!(gradient[0][0].re, 4.0);
1864 assert_relative_eq!(gradient[0][0].im, 0.0);
1865 assert_relative_eq!(gradient[0][1].re, -5.0);
1866 assert_relative_eq!(gradient[0][1].im, 0.0);
1867 assert_relative_eq!(gradient[0][2].re, 2.0);
1868 assert_relative_eq!(gradient[0][2].im, 0.0);
1869 assert_relative_eq!(gradient[0][3].re, -3.0);
1870 assert_relative_eq!(gradient[0][3].im, 0.0);
1871
1872 let expr = (&expr1 * &expr2).imag();
1873 let evaluator = expr.load(&dataset).unwrap();
1874
1875 let gradient = evaluator.evaluate_gradient(¶ms);
1876
1877 assert_relative_eq!(gradient[0][0].re, 5.0);
1878 assert_relative_eq!(gradient[0][0].im, 0.0);
1879 assert_relative_eq!(gradient[0][1].re, 4.0);
1880 assert_relative_eq!(gradient[0][1].im, 0.0);
1881 assert_relative_eq!(gradient[0][2].re, 3.0);
1882 assert_relative_eq!(gradient[0][2].im, 0.0);
1883 assert_relative_eq!(gradient[0][3].re, 2.0);
1884 assert_relative_eq!(gradient[0][3].im, 0.0);
1885
1886 let expr = (&expr1 * &expr2).conj();
1887 let evaluator = expr.load(&dataset).unwrap();
1888
1889 let gradient = evaluator.evaluate_gradient(¶ms);
1890
1891 assert_relative_eq!(gradient[0][0].re, 4.0);
1892 assert_relative_eq!(gradient[0][0].im, -5.0);
1893 assert_relative_eq!(gradient[0][1].re, -5.0);
1894 assert_relative_eq!(gradient[0][1].im, -4.0);
1895 assert_relative_eq!(gradient[0][2].re, 2.0);
1896 assert_relative_eq!(gradient[0][2].im, -3.0);
1897 assert_relative_eq!(gradient[0][3].re, -3.0);
1898 assert_relative_eq!(gradient[0][3].im, -2.0);
1899
1900 let expr = (&expr1 * &expr2).norm_sqr();
1901 let evaluator = expr.load(&dataset).unwrap();
1902
1903 let gradient = evaluator.evaluate_gradient(¶ms);
1904
1905 assert_relative_eq!(gradient[0][0].re, 164.0);
1906 assert_relative_eq!(gradient[0][0].im, 0.0);
1907 assert_relative_eq!(gradient[0][1].re, 246.0);
1908 assert_relative_eq!(gradient[0][1].im, 0.0);
1909 assert_relative_eq!(gradient[0][2].re, 104.0);
1910 assert_relative_eq!(gradient[0][2].im, 0.0);
1911 assert_relative_eq!(gradient[0][3].re, 130.0);
1912 assert_relative_eq!(gradient[0][3].im, 0.0);
1913 }
1914
1915 #[test]
1916 fn test_zeros_and_ones() {
1917 let amp = ComplexScalar::new(
1918 "parametric",
1919 parameter("test_param_re"),
1920 constant("fixed_two", 2.0),
1921 )
1922 .unwrap();
1923 let dataset = Arc::new(test_dataset());
1924 let expr = (amp * Expression::one() + Expression::zero()).norm_sqr();
1925 let evaluator = expr.load(&dataset).unwrap();
1926
1927 let params = vec![2.0];
1928 let value = evaluator.evaluate(¶ms);
1929 let gradient = evaluator.evaluate_gradient(¶ms);
1930
1931 assert_relative_eq!(value[0].re, 8.0);
1933 assert_relative_eq!(value[0].im, 0.0);
1934
1935 assert_relative_eq!(gradient[0][0].re, 4.0);
1937 assert_relative_eq!(gradient[0][0].im, 0.0);
1938 }
1939
1940 #[test]
1941 fn test_parameter_registration() {
1942 let expr = ComplexScalar::new(
1943 "parametric",
1944 parameter("test_param_re"),
1945 constant("fixed_two", 2.0),
1946 )
1947 .unwrap();
1948 let parameters = expr.free_parameters();
1949 assert_eq!(parameters.len(), 1);
1950 assert_eq!(parameters[0], "test_param_re");
1951 }
1952
1953 #[test]
1954 #[should_panic(expected = "refers to different underlying amplitudes")]
1955 fn test_duplicate_amplitude_registration() {
1956 let amp1 = ComplexScalar::new(
1957 "same_name",
1958 constant("dup_re1", 1.0),
1959 constant("dup_im1", 0.0),
1960 )
1961 .unwrap();
1962 let amp2 = ComplexScalar::new(
1963 "same_name",
1964 constant("dup_re2", 2.0),
1965 constant("dup_im2", 0.0),
1966 )
1967 .unwrap();
1968 let _expr = amp1 + amp2;
1969 }
1970
1971 #[test]
1972 fn test_tree_printing() {
1973 let amp1 = ComplexScalar::new(
1974 "parametric_1",
1975 parameter("test_param_re_1"),
1976 parameter("test_param_im_1"),
1977 )
1978 .unwrap();
1979 let amp2 = ComplexScalar::new(
1980 "parametric_2",
1981 parameter("test_param_re_2"),
1982 parameter("test_param_im_2"),
1983 )
1984 .unwrap();
1985 let expr = &1.real() + &2.conj().imag() + Expression::one() * -Expression::zero()
1986 - Expression::zero() / Expression::one()
1987 + (&1 * &2).norm_sqr();
1988 assert_eq!(
1989 expr.to_string(),
1990 "+
1991├─ -
1992│ ├─ +
1993│ │ ├─ +
1994│ │ │ ├─ Re
1995│ │ │ │ └─ parametric_1(id=0)
1996│ │ │ └─ Im
1997│ │ │ └─ *
1998│ │ │ └─ parametric_2(id=1)
1999│ │ └─ ×
2000│ │ ├─ 1
2001│ │ └─ -
2002│ │ └─ 0
2003│ └─ ÷
2004│ ├─ 0
2005│ └─ 1
2006└─ NormSqr
2007 └─ ×
2008 ├─ parametric_1(id=0)
2009 └─ parametric_2(id=1)
2010"
2011 );
2012 }
2013}