1use std::{
2 fmt::{Debug, Display},
3 sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17 data::{Dataset, Event},
18 resources::{Cache, Parameters, Resources},
19 Float, LadduError, ParameterID, ReadWrite,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
/// Something an [`Amplitude`] can depend on: either a named free parameter or a fixed
/// constant.
///
/// Amplitudes keep these as fields and pass them to `Resources::register_parameter` while
/// registering themselves.
#[derive(Clone, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A free parameter, identified by its name.
    Parameter(String),
    /// A fixed numeric value.
    Constant(Float),
    /// Placeholder used as the [`Default`] value.
    /// NOTE(review): how `register_parameter` treats this variant is not visible here.
    #[default]
    Uninit,
}
40
41pub fn parameter(name: &str) -> ParameterLike {
43 ParameterLike::Parameter(name.to_string())
44}
45
46pub fn constant(value: Float) -> ParameterLike {
48 ParameterLike::Constant(value)
49}
50
51#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65 #[allow(unused_variables)]
70 fn precompute(&self, event: &Event, cache: &mut Cache) {}
71 #[cfg(feature = "rayon")]
73 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74 dataset
75 .events
76 .par_iter()
77 .zip(resources.caches.par_iter_mut())
78 .for_each(|(event, cache)| {
79 self.precompute(event, cache);
80 })
81 }
82 #[cfg(not(feature = "rayon"))]
84 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85 dataset
86 .events
87 .iter()
88 .zip(resources.caches.iter_mut())
89 .for_each(|(event, cache)| self.precompute(event, cache))
90 }
91 fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101 fn compute_gradient(
117 &self,
118 parameters: &Parameters,
119 event: &Event,
120 cache: &Cache,
121 gradient: &mut DVector<Complex<Float>>,
122 ) {
123 self.central_difference_with_indices(
124 &Vec::from_iter(0..parameters.len()),
125 parameters,
126 event,
127 cache,
128 gradient,
129 )
130 }
131
132 fn central_difference_with_indices(
138 &self,
139 indices: &[usize],
140 parameters: &Parameters,
141 event: &Event,
142 cache: &Cache,
143 gradient: &mut DVector<Complex<Float>>,
144 ) {
145 let x = parameters.parameters.to_owned();
146 let constants = parameters.constants.to_owned();
147 let h: DVector<Float> = x
148 .iter()
149 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150 .collect::<Vec<_>>()
151 .into();
152 for i in indices {
153 let mut x_plus = x.clone();
154 let mut x_minus = x.clone();
155 x_plus[*i] += h[*i];
156 x_minus[*i] -= h[*i];
157 let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158 let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159 gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160 }
161 }
162}
163
164pub fn central_difference<F: Fn(&[Float]) -> Float>(
166 parameters: &[Float],
167 func: F,
168) -> DVector<Float> {
169 let mut gradient = DVector::zeros(parameters.len());
170 let x = parameters.to_owned();
171 let h: DVector<Float> = x
172 .iter()
173 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174 .collect::<Vec<_>>()
175 .into();
176 for i in 0..parameters.len() {
177 let mut x_plus = x.clone();
178 let mut x_minus = x.clone();
179 x_plus[i] += h[i];
180 x_minus[i] -= h[i];
181 let f_plus = func(&x_plus);
182 let f_minus = func(&x_minus);
183 gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184 }
185 gradient
186}
187
// Implements `Clone` for `Box<dyn Amplitude>` via `dyn_clone`. This is what lets types holding
// boxed amplitudes, such as `Manager` and `Evaluator`, derive `Clone`.
dyn_clone::clone_trait_object!(Amplitude);
189
/// The value of every registered amplitude for a single event.
///
/// Entries are indexed by the position stored in each [`AmplitudeID`] (see
/// [`Expression::evaluate`]).
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
/// Per-amplitude gradients for a single event.
///
/// - Field `0` is the number of free parameters. It is the length of each gradient vector and
///   is used to build zero gradients for constant sub-expressions.
/// - Field `1` holds one gradient per amplitude, indexed by the position stored in each
///   [`AmplitudeID`].
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex<Float>>>);
197
/// Handle to a registered amplitude.
///
/// Field `0` is a name, shown by `Display`. Field `1` is the index into the per-event
/// [`AmplitudeValues`] / [`GradientValues`] storage. Handles can be combined with arithmetic
/// operators into [`Expression`]s.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(f, "{}(id={})", self.0, self.1)
206 }
207}
208
209impl From<AmplitudeID> for Expression {
210 fn from(value: AmplitudeID) -> Self {
211 Self::Amp(value)
212 }
213}
214
/// A symbolic combination of amplitudes.
///
/// It is evaluated per event from [`AmplitudeValues`] and differentiated from
/// [`GradientValues`].
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    /// The constant zero (also the [`Default`] expression).
    #[default]
    Zero,
    /// The constant one.
    One,
    /// A single registered amplitude.
    Amp(AmplitudeID),
    /// Sum of the two sub-expressions.
    Add(Box<Expression>, Box<Expression>),
    /// First sub-expression minus the second.
    Sub(Box<Expression>, Box<Expression>),
    /// Product of the two sub-expressions.
    Mul(Box<Expression>, Box<Expression>),
    /// First sub-expression divided by the second.
    Div(Box<Expression>, Box<Expression>),
    /// Negation.
    Neg(Box<Expression>),
    /// Real part, returned as a complex number with zero imaginary part.
    Real(Box<Expression>),
    /// Imaginary part, returned as a complex number with zero imaginary part.
    Imag(Box<Expression>),
    /// Complex conjugate.
    Conj(Box<Expression>),
    /// Squared modulus, returned as a complex number with zero imaginary part.
    NormSqr(Box<Expression>),
}
244
245impl Debug for Expression {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 self.write_tree(f, "", "", "")
248 }
249}
250
251impl Display for Expression {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 write!(f, "{:?}", self)
254 }
255}
256
257#[rustfmt::skip]
258impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
259 Expression::Add(Box::new(a.clone()), Box::new(b.clone()))
260});
261#[rustfmt::skip]
262impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
263 Expression::Sub(Box::new(a.clone()), Box::new(b.clone()))
264});
265#[rustfmt::skip]
266impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
267 Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
268});
269#[rustfmt::skip]
270impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
271 Expression::Div(Box::new(a.clone()), Box::new(b.clone()))
272});
273#[rustfmt::skip]
274impl_op_ex!(- |a: &Expression| -> Expression {
275 Expression::Neg(Box::new(a.clone()))
276});
277
278#[rustfmt::skip]
279impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression {
280 Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
281});
282#[rustfmt::skip]
283impl_op_ex_commutative!(- |a: &AmplitudeID, b: &Expression| -> Expression {
284 Expression::Sub(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
285});
286#[rustfmt::skip]
287impl_op_ex_commutative!(* |a: &AmplitudeID, b: &Expression| -> Expression {
288 Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
289});
290#[rustfmt::skip]
291impl_op_ex_commutative!(/ |a: &AmplitudeID, b: &Expression| -> Expression {
292 Expression::Div(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
293});
294
295#[rustfmt::skip]
296impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
297 Expression::Add(
298 Box::new(Expression::Amp(a.clone())),
299 Box::new(Expression::Amp(b.clone()))
300 )
301});
302#[rustfmt::skip]
303impl_op_ex!(- |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
304 Expression::Sub(
305 Box::new(Expression::Amp(a.clone())),
306 Box::new(Expression::Amp(b.clone()))
307 )
308});
309#[rustfmt::skip]
310impl_op_ex!(* |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
311 Expression::Mul(
312 Box::new(Expression::Amp(a.clone())),
313 Box::new(Expression::Amp(b.clone())),
314 )
315});
316#[rustfmt::skip]
317impl_op_ex!(/ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
318 Expression::Div(
319 Box::new(Expression::Amp(a.clone())),
320 Box::new(Expression::Amp(b.clone())),
321 )
322});
323#[rustfmt::skip]
324impl_op_ex!(- |a: &AmplitudeID| -> Expression {
325 Expression::Neg(
326 Box::new(Expression::Amp(a.clone())),
327 )
328});
329
330impl AmplitudeID {
331 pub fn real(&self) -> Expression {
333 Expression::Real(Box::new(Expression::Amp(self.clone())))
334 }
335 pub fn imag(&self) -> Expression {
337 Expression::Imag(Box::new(Expression::Amp(self.clone())))
338 }
339 pub fn conj(&self) -> Expression {
341 Expression::Conj(Box::new(Expression::Amp(self.clone())))
342 }
343 pub fn norm_sqr(&self) -> Expression {
345 Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
346 }
347}
348
349impl Expression {
350 pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
355 match self {
356 Expression::Amp(aid) => amplitude_values.0[aid.1],
357 Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
358 Expression::Sub(a, b) => a.evaluate(amplitude_values) - b.evaluate(amplitude_values),
359 Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
360 Expression::Div(a, b) => a.evaluate(amplitude_values) / b.evaluate(amplitude_values),
361 Expression::Neg(a) => -a.evaluate(amplitude_values),
362 Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
363 Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
364 Expression::Conj(a) => a.evaluate(amplitude_values).conj(),
365 Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
366 Expression::Zero => Complex::ZERO,
367 Expression::One => Complex::ONE,
368 }
369 }
370 pub fn evaluate_gradient(
375 &self,
376 amplitude_values: &AmplitudeValues,
377 gradient_values: &GradientValues,
378 ) -> DVector<Complex<Float>> {
379 match self {
380 Expression::Amp(aid) => gradient_values.1[aid.1].clone(),
381 Expression::Add(a, b) => {
382 a.evaluate_gradient(amplitude_values, gradient_values)
383 + b.evaluate_gradient(amplitude_values, gradient_values)
384 }
385 Expression::Sub(a, b) => {
386 a.evaluate_gradient(amplitude_values, gradient_values)
387 - b.evaluate_gradient(amplitude_values, gradient_values)
388 }
389 Expression::Mul(a, b) => {
390 let f_a = a.evaluate(amplitude_values);
391 let f_b = b.evaluate(amplitude_values);
392 b.evaluate_gradient(amplitude_values, gradient_values)
393 .map(|g| g * f_a)
394 + a.evaluate_gradient(amplitude_values, gradient_values)
395 .map(|g| g * f_b)
396 }
397 Expression::Div(a, b) => {
398 let f_a = a.evaluate(amplitude_values);
399 let f_b = b.evaluate(amplitude_values);
400 (a.evaluate_gradient(amplitude_values, gradient_values)
401 .map(|g| g * f_b)
402 - b.evaluate_gradient(amplitude_values, gradient_values)
403 .map(|g| g * f_a))
404 / (f_b * f_b)
405 }
406 Expression::Neg(a) => -a.evaluate_gradient(amplitude_values, gradient_values),
407 Expression::Real(a) => a
408 .evaluate_gradient(amplitude_values, gradient_values)
409 .map(|g| Complex::new(g.re, 0.0)),
410 Expression::Imag(a) => a
411 .evaluate_gradient(amplitude_values, gradient_values)
412 .map(|g| Complex::new(g.im, 0.0)),
413 Expression::Conj(a) => a
414 .evaluate_gradient(amplitude_values, gradient_values)
415 .map(|g| g.conj()),
416 Expression::NormSqr(a) => {
417 let conj_f_a = a.evaluate(amplitude_values).conjugate();
418 a.evaluate_gradient(amplitude_values, gradient_values)
419 .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
420 }
421 Expression::Zero | Expression::One => DVector::zeros(gradient_values.0),
422 }
423 }
424 pub fn real(&self) -> Self {
426 Self::Real(Box::new(self.clone()))
427 }
428 pub fn imag(&self) -> Self {
430 Self::Imag(Box::new(self.clone()))
431 }
432 pub fn conj(&self) -> Self {
434 Self::Conj(Box::new(self.clone()))
435 }
436 pub fn norm_sqr(&self) -> Self {
438 Self::NormSqr(Box::new(self.clone()))
439 }
440
441 fn write_tree(
443 &self,
444 f: &mut std::fmt::Formatter<'_>,
445 parent_prefix: &str,
446 immediate_prefix: &str,
447 parent_suffix: &str,
448 ) -> std::fmt::Result {
449 let display_string = match self {
450 Self::Amp(aid) => aid.to_string(),
451 Self::Add(_, _) => "+".to_string(),
452 Self::Sub(_, _) => "-".to_string(),
453 Self::Mul(_, _) => "×".to_string(),
454 Self::Div(_, _) => "÷".to_string(),
455 Self::Neg(_) => "-".to_string(),
456 Self::Real(_) => "Re".to_string(),
457 Self::Imag(_) => "Im".to_string(),
458 Self::Conj(_) => "*".to_string(),
459 Self::NormSqr(_) => "NormSqr".to_string(),
460 Self::Zero => "0".to_string(),
461 Self::One => "1".to_string(),
462 };
463 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
464 match self {
465 Self::Amp(_) | Self::Zero | Self::One => {}
466 Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) | Self::Div(a, b) => {
467 let terms = [a, b];
468 let mut it = terms.iter().peekable();
469 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
470 while let Some(child) = it.next() {
471 match it.peek() {
472 Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│ "),
473 None => child.write_tree(f, &child_prefix, "└─ ", " "),
474 }?;
475 }
476 }
477 Self::Neg(a) | Self::Real(a) | Self::Imag(a) | Self::Conj(a) | Self::NormSqr(a) => {
478 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
479 a.write_tree(f, &child_prefix, "└─ ", " ")?;
480 }
481 }
482 Ok(())
483 }
484}
485
/// Registry of amplitudes and the [`Resources`] they register into.
///
/// The resources hold parameters, constants, caches and activation flags. Models are built
/// from a manager with [`Manager::model`].
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    // Registered amplitudes, in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Shared registry that every amplitude registers its parameters into.
    resources: Resources,
}
493
494impl Manager {
495 pub fn parameters(&self) -> Vec<String> {
497 self.resources.parameters.iter().cloned().collect()
498 }
499 pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
507 let mut amp = amplitude.clone();
508 let aid = amp.register(&mut self.resources)?;
509 self.amplitudes.push(amp);
510 Ok(aid)
511 }
512 pub fn model(&self, expression: &Expression) -> Model {
514 Model {
515 manager: self.clone(),
516 expression: expression.clone(),
517 }
518 }
519}
520
/// A [`Manager`] snapshot together with the [`Expression`] to evaluate.
///
/// Bind it to a dataset with [`Model::load`] to get an [`Evaluator`].
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    pub(crate) manager: Manager,
    pub(crate) expression: Expression,
}
531
532impl ReadWrite for Model {
533 fn create_null() -> Self {
534 Model {
535 manager: Manager::default(),
536 expression: Expression::default(),
537 }
538 }
539}
540impl Model {
541 pub fn parameters(&self) -> Vec<String> {
543 self.manager.parameters()
544 }
545 pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
549 let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
550 loaded_resources.write().reserve_cache(dataset.n_events());
551 for amplitude in &self.manager.amplitudes {
552 amplitude.precompute_all(dataset, &mut loaded_resources.write());
553 }
554 Evaluator {
555 amplitudes: self.manager.amplitudes.clone(),
556 resources: loaded_resources.clone(),
557 dataset: dataset.clone(),
558 expression: self.expression.clone(),
559 }
560 }
561}
562
/// A loaded model bound to a dataset.
///
/// The resources sit behind an `Arc<RwLock<_>>`, so clones of an evaluator share caches and
/// activation state.
#[derive(Clone)]
pub struct Evaluator {
    /// The registered amplitudes.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Parameters, constants, per-event caches and amplitude activation flags.
    pub resources: Arc<RwLock<Resources>>,
    /// The events being evaluated.
    pub dataset: Arc<Dataset>,
    /// The expression evaluated for every event.
    pub expression: Expression,
}
579
580impl Evaluator {
581 pub fn parameters(&self) -> Vec<String> {
584 self.resources.read().parameters.iter().cloned().collect()
585 }
586 pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
588 self.resources.write().activate(name)
589 }
590 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
592 self.resources.write().activate_many(names)
593 }
594 pub fn activate_all(&self) {
596 self.resources.write().activate_all();
597 }
598 pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
600 self.resources.write().deactivate(name)
601 }
602 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
604 self.resources.write().deactivate_many(names)
605 }
606 pub fn deactivate_all(&self) {
608 self.resources.write().deactivate_all();
609 }
610 pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
612 self.resources.write().isolate(name)
613 }
614 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
616 self.resources.write().isolate_many(names)
617 }
618
619 pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
627 let resources = self.resources.read();
628 let parameters = Parameters::new(parameters, &resources.constants);
629 #[cfg(feature = "rayon")]
630 {
631 let amplitude_values_vec: Vec<AmplitudeValues> = self
632 .dataset
633 .events
634 .par_iter()
635 .zip(resources.caches.par_iter())
636 .map(|(event, cache)| {
637 AmplitudeValues(
638 self.amplitudes
639 .iter()
640 .zip(resources.active.iter())
641 .map(|(amp, active)| {
642 if *active {
643 amp.compute(¶meters, event, cache)
644 } else {
645 Complex::new(0.0, 0.0)
646 }
647 })
648 .collect(),
649 )
650 })
651 .collect();
652 amplitude_values_vec
653 .par_iter()
654 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
655 .collect()
656 }
657 #[cfg(not(feature = "rayon"))]
658 {
659 let amplitude_values_vec: Vec<AmplitudeValues> = self
660 .dataset
661 .events
662 .iter()
663 .zip(resources.caches.iter())
664 .map(|(event, cache)| {
665 AmplitudeValues(
666 self.amplitudes
667 .iter()
668 .zip(resources.active.iter())
669 .map(|(amp, active)| {
670 if *active {
671 amp.compute(¶meters, event, cache)
672 } else {
673 Complex::new(0.0, 0.0)
674 }
675 })
676 .collect(),
677 )
678 })
679 .collect();
680 amplitude_values_vec
681 .iter()
682 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
683 .collect()
684 }
685 }
686
687 #[cfg(feature = "mpi")]
695 fn evaluate_mpi(
696 &self,
697 parameters: &[Float],
698 world: &SimpleCommunicator,
699 ) -> Vec<Complex<Float>> {
700 let local_evaluation = self.evaluate_local(parameters);
701 let n_events = self.dataset.n_events();
702 let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
703 let (counts, displs) = world.get_counts_displs(n_events);
704 {
705 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
706 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
707 }
708 buffer
709 }
710
711 pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
714 #[cfg(feature = "mpi")]
715 {
716 if let Some(world) = crate::mpi::get_world() {
717 return self.evaluate_mpi(parameters, &world);
718 }
719 }
720 self.evaluate_local(parameters)
721 }
722
723 pub fn evaluate_batch_local(
726 &self,
727 parameters: &[Float],
728 indices: &[usize],
729 ) -> Vec<Complex<Float>> {
730 let resources = self.resources.read();
731 let parameters = Parameters::new(parameters, &resources.constants);
732 #[cfg(feature = "rayon")]
733 {
734 let amplitude_values_vec: Vec<AmplitudeValues> = self
735 .dataset
736 .events
737 .par_iter()
738 .zip(resources.caches.par_iter())
739 .enumerate()
740 .filter_map(|(i, (event, cache))| {
741 if indices.contains(&i) {
742 Some((event, cache))
743 } else {
744 None
745 }
746 })
747 .map(|(event, cache)| {
748 AmplitudeValues(
749 self.amplitudes
750 .iter()
751 .zip(resources.active.iter())
752 .map(|(amp, active)| {
753 if *active {
754 amp.compute(¶meters, event, cache)
755 } else {
756 Complex::new(0.0, 0.0)
757 }
758 })
759 .collect(),
760 )
761 })
762 .collect();
763 amplitude_values_vec
764 .par_iter()
765 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
766 .collect()
767 }
768 #[cfg(not(feature = "rayon"))]
769 {
770 let amplitude_values_vec: Vec<AmplitudeValues> = self
771 .dataset
772 .events
773 .iter()
774 .zip(resources.caches.iter())
775 .enumerate()
776 .filter_map(|(i, (event, cache))| {
777 if indices.contains(&i) {
778 Some((event, cache))
779 } else {
780 None
781 }
782 })
783 .map(|(event, cache)| {
784 AmplitudeValues(
785 self.amplitudes
786 .iter()
787 .zip(resources.active.iter())
788 .map(|(amp, active)| {
789 if *active {
790 amp.compute(¶meters, event, cache)
791 } else {
792 Complex::new(0.0, 0.0)
793 }
794 })
795 .collect(),
796 )
797 })
798 .collect();
799 amplitude_values_vec
800 .iter()
801 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
802 .collect()
803 }
804 }
805
806 #[cfg(feature = "mpi")]
809 fn evaluate_batch_mpi(
810 &self,
811 parameters: &[Float],
812 indices: &[usize],
813 world: &SimpleCommunicator,
814 ) -> Vec<Complex<Float>> {
815 let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; indices.len()];
816 let (counts, displs, locals) = self
817 .dataset
818 .get_counts_displs_locals_from_indices(indices, world);
819 let local_evaluation = self.evaluate_batch_local(parameters, &locals);
820 {
821 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
822 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
823 }
824 buffer
825 }
826
827 pub fn evaluate_batch(&self, parameters: &[Float], indices: &[usize]) -> Vec<Complex<Float>> {
830 #[cfg(feature = "mpi")]
831 {
832 if let Some(world) = crate::mpi::get_world() {
833 return self.evaluate_batch_mpi(parameters, indices, &world);
834 }
835 }
836 self.evaluate_batch_local(parameters, indices)
837 }
838
839 pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
847 let resources = self.resources.read();
848 let parameters = Parameters::new(parameters, &resources.constants);
849 #[cfg(feature = "rayon")]
850 {
851 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
852 .dataset
853 .events
854 .par_iter()
855 .zip(resources.caches.par_iter())
856 .map(|(event, cache)| {
857 let mut gradient_values =
858 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
859 self.amplitudes
860 .iter()
861 .zip(resources.active.iter())
862 .zip(gradient_values.iter_mut())
863 .for_each(|((amp, active), grad)| {
864 if *active {
865 amp.compute_gradient(¶meters, event, cache, grad)
866 }
867 });
868 (
869 AmplitudeValues(
870 self.amplitudes
871 .iter()
872 .zip(resources.active.iter())
873 .map(|(amp, active)| {
874 if *active {
875 amp.compute(¶meters, event, cache)
876 } else {
877 Complex::new(0.0, 0.0)
878 }
879 })
880 .collect(),
881 ),
882 GradientValues(parameters.len(), gradient_values),
883 )
884 })
885 .collect();
886 amplitude_values_and_gradient_vec
887 .par_iter()
888 .map(|(amplitude_values, gradient_values)| {
889 self.expression
890 .evaluate_gradient(amplitude_values, gradient_values)
891 })
892 .collect()
893 }
894 #[cfg(not(feature = "rayon"))]
895 {
896 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
897 .dataset
898 .events
899 .iter()
900 .zip(resources.caches.iter())
901 .map(|(event, cache)| {
902 let mut gradient_values =
903 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
904 self.amplitudes
905 .iter()
906 .zip(resources.active.iter())
907 .zip(gradient_values.iter_mut())
908 .for_each(|((amp, active), grad)| {
909 if *active {
910 amp.compute_gradient(¶meters, event, cache, grad)
911 }
912 });
913 (
914 AmplitudeValues(
915 self.amplitudes
916 .iter()
917 .zip(resources.active.iter())
918 .map(|(amp, active)| {
919 if *active {
920 amp.compute(¶meters, event, cache)
921 } else {
922 Complex::new(0.0, 0.0)
923 }
924 })
925 .collect(),
926 ),
927 GradientValues(parameters.len(), gradient_values),
928 )
929 })
930 .collect();
931
932 amplitude_values_and_gradient_vec
933 .iter()
934 .map(|(amplitude_values, gradient_values)| {
935 self.expression
936 .evaluate_gradient(amplitude_values, gradient_values)
937 })
938 .collect()
939 }
940 }
941
942 #[cfg(feature = "mpi")]
950 fn evaluate_gradient_mpi(
951 &self,
952 parameters: &[Float],
953 world: &SimpleCommunicator,
954 ) -> Vec<DVector<Complex<Float>>> {
955 let flattened_local_evaluation = self
956 .evaluate_gradient_local(parameters)
957 .iter()
958 .flat_map(|g| g.data.as_vec().to_vec())
959 .collect::<Vec<Complex<Float>>>();
960 let n_events = self.dataset.n_events();
961 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
962 let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
963 let mut partitioned_flattened_result_buffer =
964 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
965 world.all_gather_varcount_into(
966 &flattened_local_evaluation,
967 &mut partitioned_flattened_result_buffer,
968 );
969 flattened_result_buffer
970 .chunks(parameters.len())
971 .map(DVector::from_row_slice)
972 .collect()
973 }
974
975 pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
978 #[cfg(feature = "mpi")]
979 {
980 if let Some(world) = crate::mpi::get_world() {
981 return self.evaluate_gradient_mpi(parameters, &world);
982 }
983 }
984 self.evaluate_gradient_local(parameters)
985 }
986
987 pub fn evaluate_gradient_batch_local(
990 &self,
991 parameters: &[Float],
992 indices: &[usize],
993 ) -> Vec<DVector<Complex<Float>>> {
994 let resources = self.resources.read();
995 let parameters = Parameters::new(parameters, &resources.constants);
996 #[cfg(feature = "rayon")]
997 {
998 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
999 .dataset
1000 .events
1001 .par_iter()
1002 .zip(resources.caches.par_iter())
1003 .enumerate()
1004 .filter_map(|(i, (event, cache))| {
1005 if indices.contains(&i) {
1006 Some((event, cache))
1007 } else {
1008 None
1009 }
1010 })
1011 .map(|(event, cache)| {
1012 let mut gradient_values =
1013 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1014 self.amplitudes
1015 .iter()
1016 .zip(resources.active.iter())
1017 .zip(gradient_values.iter_mut())
1018 .for_each(|((amp, active), grad)| {
1019 if *active {
1020 amp.compute_gradient(¶meters, event, cache, grad)
1021 }
1022 });
1023 (
1024 AmplitudeValues(
1025 self.amplitudes
1026 .iter()
1027 .zip(resources.active.iter())
1028 .map(|(amp, active)| {
1029 if *active {
1030 amp.compute(¶meters, event, cache)
1031 } else {
1032 Complex::new(0.0, 0.0)
1033 }
1034 })
1035 .collect(),
1036 ),
1037 GradientValues(parameters.len(), gradient_values),
1038 )
1039 })
1040 .collect();
1041 amplitude_values_and_gradient_vec
1042 .par_iter()
1043 .map(|(amplitude_values, gradient_values)| {
1044 self.expression
1045 .evaluate_gradient(amplitude_values, gradient_values)
1046 })
1047 .collect()
1048 }
1049 #[cfg(not(feature = "rayon"))]
1050 {
1051 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
1052 .dataset
1053 .events
1054 .iter()
1055 .zip(resources.caches.iter())
1056 .enumerate()
1057 .filter_map(|(i, (event, cache))| {
1058 if indices.contains(&i) {
1059 Some((event, cache))
1060 } else {
1061 None
1062 }
1063 })
1064 .map(|(event, cache)| {
1065 let mut gradient_values =
1066 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
1067 self.amplitudes
1068 .iter()
1069 .zip(resources.active.iter())
1070 .zip(gradient_values.iter_mut())
1071 .for_each(|((amp, active), grad)| {
1072 if *active {
1073 amp.compute_gradient(¶meters, event, cache, grad)
1074 }
1075 });
1076 (
1077 AmplitudeValues(
1078 self.amplitudes
1079 .iter()
1080 .zip(resources.active.iter())
1081 .map(|(amp, active)| {
1082 if *active {
1083 amp.compute(¶meters, event, cache)
1084 } else {
1085 Complex::new(0.0, 0.0)
1086 }
1087 })
1088 .collect(),
1089 ),
1090 GradientValues(parameters.len(), gradient_values),
1091 )
1092 })
1093 .collect();
1094
1095 amplitude_values_and_gradient_vec
1096 .iter()
1097 .map(|(amplitude_values, gradient_values)| {
1098 self.expression
1099 .evaluate_gradient(amplitude_values, gradient_values)
1100 })
1101 .collect()
1102 }
1103 }
1104
1105 #[cfg(feature = "mpi")]
1108 fn evaluate_gradient_batch_mpi(
1109 &self,
1110 parameters: &[Float],
1111 indices: &[usize],
1112 world: &SimpleCommunicator,
1113 ) -> Vec<DVector<Complex<Float>>> {
1114 let (counts, displs, locals) = self
1115 .dataset
1116 .get_flattened_counts_displs_locals_from_indices(indices, parameters.len(), world);
1117 let flattened_local_evaluation = self
1118 .evaluate_gradient_batch_local(parameters, &locals)
1119 .iter()
1120 .flat_map(|g| g.data.as_vec().to_vec())
1121 .collect::<Vec<Complex<Float>>>();
1122 let mut flattened_result_buffer = vec![Complex::ZERO; indices.len() * parameters.len()];
1123 let mut partitioned_flattened_result_buffer =
1124 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
1125 world.all_gather_varcount_into(
1126 &flattened_local_evaluation,
1127 &mut partitioned_flattened_result_buffer,
1128 );
1129 flattened_result_buffer
1130 .chunks(parameters.len())
1131 .map(DVector::from_row_slice)
1132 .collect()
1133 }
1134
1135 pub fn evaluate_gradient_batch(
1139 &self,
1140 parameters: &[Float],
1141 indices: &[usize],
1142 ) -> Vec<DVector<Complex<Float>>> {
1143 #[cfg(feature = "mpi")]
1144 {
1145 if let Some(world) = crate::mpi::get_world() {
1146 return self.evaluate_gradient_batch_mpi(parameters, indices, &world);
1147 }
1148 }
1149 self.evaluate_gradient_batch_local(parameters, indices)
1150 }
1151}
1152
/// A minimal amplitude, used for example by this file's tests.
///
/// It computes `(re + i·im)` multiplied by the `e()` component of the event's first
/// four-momentum.
#[derive(Clone, Serialize, Deserialize)]
pub struct TestAmplitude {
    name: String,
    re: ParameterLike,
    // Filled in by `register`.
    pid_re: ParameterID,
    im: ParameterLike,
    // Filled in by `register`.
    pid_im: ParameterID,
}
1162
1163impl TestAmplitude {
1164 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
1166 Self {
1167 name: name.to_string(),
1168 re,
1169 pid_re: Default::default(),
1170 im,
1171 pid_im: Default::default(),
1172 }
1173 .into()
1174 }
1175}
1176
1177#[typetag::serde]
1178impl Amplitude for TestAmplitude {
1179 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
1180 self.pid_re = resources.register_parameter(&self.re);
1181 self.pid_im = resources.register_parameter(&self.im);
1182 resources.register_amplitude(&self.name)
1183 }
1184
1185 fn compute(&self, parameters: &Parameters, event: &Event, _cache: &Cache) -> Complex<Float> {
1186 Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im)) * event.p4s[0].e()
1187 }
1188
1189 fn compute_gradient(
1190 &self,
1191 _parameters: &Parameters,
1192 event: &Event,
1193 _cache: &Cache,
1194 gradient: &mut DVector<Complex<Float>>,
1195 ) {
1196 if let ParameterID::Parameter(ind) = self.pid_re {
1197 gradient[ind] = Complex::ONE * event.p4s[0].e();
1198 }
1199 if let ParameterID::Parameter(ind) = self.pid_im {
1200 gradient[ind] = Complex::I * event.p4s[0].e();
1201 }
1202 }
1203}
1204
1205#[cfg(test)]
1206mod tests {
1207 use crate::data::{test_dataset, test_event};
1208
1209 use super::*;
1210 use crate::{
1211 data::Event,
1212 resources::{Cache, ParameterID, Parameters, Resources},
1213 Float, LadduError,
1214 };
1215 use approx::assert_relative_eq;
1216 use serde::{Deserialize, Serialize};
1217
1218 #[derive(Clone, Serialize, Deserialize)]
1219 pub struct ComplexScalar {
1220 name: String,
1221 re: ParameterLike,
1222 pid_re: ParameterID,
1223 im: ParameterLike,
1224 pid_im: ParameterID,
1225 }
1226
1227 impl ComplexScalar {
1228 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
1229 Self {
1230 name: name.to_string(),
1231 re,
1232 pid_re: Default::default(),
1233 im,
1234 pid_im: Default::default(),
1235 }
1236 .into()
1237 }
1238 }
1239
1240 #[typetag::serde]
1241 impl Amplitude for ComplexScalar {
1242 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
1243 self.pid_re = resources.register_parameter(&self.re);
1244 self.pid_im = resources.register_parameter(&self.im);
1245 resources.register_amplitude(&self.name)
1246 }
1247
1248 fn compute(
1249 &self,
1250 parameters: &Parameters,
1251 _event: &Event,
1252 _cache: &Cache,
1253 ) -> Complex<Float> {
1254 Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
1255 }
1256
1257 fn compute_gradient(
1258 &self,
1259 _parameters: &Parameters,
1260 _event: &Event,
1261 _cache: &Cache,
1262 gradient: &mut DVector<Complex<Float>>,
1263 ) {
1264 if let ParameterID::Parameter(ind) = self.pid_re {
1265 gradient[ind] = Complex::ONE;
1266 }
1267 if let ParameterID::Parameter(ind) = self.pid_im {
1268 gradient[ind] = Complex::I;
1269 }
1270 }
1271 }
1272
1273 #[test]
1274 fn test_batch_evaluation() {
1275 let mut manager = Manager::default();
1276 let amp = TestAmplitude::new("test", parameter("real"), parameter("imag"));
1277 let aid = manager.register(amp).unwrap();
1278 let mut event1 = test_event();
1279 event1.p4s[0].t = 10.0;
1280 let mut event2 = test_event();
1281 event2.p4s[0].t = 11.0;
1282 let mut event3 = test_event();
1283 event3.p4s[0].t = 12.0;
1284 let dataset = Arc::new(Dataset {
1285 events: vec![Arc::new(event1), Arc::new(event2), Arc::new(event3)],
1286 });
1287 let expr = Expression::Amp(aid);
1288 let model = manager.model(&expr);
1289 let evaluator = model.load(&dataset);
1290 let result = evaluator.evaluate_batch(&[1.1, 2.2], &[0, 2]);
1291 assert_eq!(result.len(), 2);
1292 assert_eq!(result[0], Complex::new(1.1, 2.2) * 10.0);
1293 assert_eq!(result[1], Complex::new(1.1, 2.2) * 12.0);
1294 let result_grad = evaluator.evaluate_gradient_batch(&[1.1, 2.2], &[0, 2]);
1295 assert_eq!(result_grad.len(), 2);
1296 assert_eq!(result_grad[0][0], Complex::new(10.0, 0.0));
1297 assert_eq!(result_grad[0][1], Complex::new(0.0, 10.0));
1298 assert_eq!(result_grad[1][0], Complex::new(12.0, 0.0));
1299 assert_eq!(result_grad[1][1], Complex::new(0.0, 12.0));
1300 }
1301
1302 #[test]
1303 fn test_constant_amplitude() {
1304 let mut manager = Manager::default();
1305 let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
1306 let aid = manager.register(amp).unwrap();
1307 let dataset = Arc::new(Dataset {
1308 events: vec![Arc::new(test_event())],
1309 });
1310 let expr = Expression::Amp(aid);
1311 let model = manager.model(&expr);
1312 let evaluator = model.load(&dataset);
1313 let result = evaluator.evaluate(&[]);
1314 assert_eq!(result[0], Complex::new(2.0, 3.0));
1315 }
1316
1317 #[test]
1318 fn test_parametric_amplitude() {
1319 let mut manager = Manager::default();
1320 let amp = ComplexScalar::new(
1321 "parametric",
1322 parameter("test_param_re"),
1323 parameter("test_param_im"),
1324 );
1325 let aid = manager.register(amp).unwrap();
1326 let dataset = Arc::new(test_dataset());
1327 let expr = Expression::Amp(aid);
1328 let model = manager.model(&expr);
1329 let evaluator = model.load(&dataset);
1330 let result = evaluator.evaluate(&[2.0, 3.0]);
1331 assert_eq!(result[0], Complex::new(2.0, 3.0));
1332 }
1333
1334 #[test]
1335 fn test_expression_operations() {
1336 let mut manager = Manager::default();
1337 let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
1338 let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
1339 let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));
1340
1341 let aid1 = manager.register(amp1).unwrap();
1342 let aid2 = manager.register(amp2).unwrap();
1343 let aid3 = manager.register(amp3).unwrap();
1344
1345 let dataset = Arc::new(test_dataset());
1346
1347 let expr_add = &aid1 + &aid2;
1349 let model_add = manager.model(&expr_add);
1350 let eval_add = model_add.load(&dataset);
1351 let result_add = eval_add.evaluate(&[]);
1352 assert_eq!(result_add[0], Complex::new(2.0, 1.0));
1353
1354 let expr_sub = &aid1 - &aid2;
1356 let model_sub = manager.model(&expr_sub);
1357 let eval_sub = model_sub.load(&dataset);
1358 let result_sub = eval_sub.evaluate(&[]);
1359 assert_eq!(result_sub[0], Complex::new(2.0, -1.0));
1360
1361 let expr_mul = &aid1 * &aid2;
1363 let model_mul = manager.model(&expr_mul);
1364 let eval_mul = model_mul.load(&dataset);
1365 let result_mul = eval_mul.evaluate(&[]);
1366 assert_eq!(result_mul[0], Complex::new(0.0, 2.0));
1367
1368 let expr_div = &aid1 / &aid3;
1370 let model_div = manager.model(&expr_div);
1371 let eval_div = model_div.load(&dataset);
1372 let result_div = eval_div.evaluate(&[]);
1373 assert_eq!(result_div[0], Complex::new(6.0 / 25.0, -8.0 / 25.0));
1374
1375 let expr_neg = -&aid3;
1377 let model_neg = manager.model(&expr_neg);
1378 let eval_neg = model_neg.load(&dataset);
1379 let result_neg = eval_neg.evaluate(&[]);
1380 assert_eq!(result_neg[0], Complex::new(-3.0, -4.0));
1381
1382 let expr_add2 = &expr_add + &expr_mul;
1384 let model_add2 = manager.model(&expr_add2);
1385 let eval_add2 = model_add2.load(&dataset);
1386 let result_add2 = eval_add2.evaluate(&[]);
1387 assert_eq!(result_add2[0], Complex::new(2.0, 3.0));
1388
1389 let expr_sub2 = &expr_add - &expr_mul;
1391 let model_sub2 = manager.model(&expr_sub2);
1392 let eval_sub2 = model_sub2.load(&dataset);
1393 let result_sub2 = eval_sub2.evaluate(&[]);
1394 assert_eq!(result_sub2[0], Complex::new(2.0, -1.0));
1395
1396 let expr_mul2 = &expr_add * &expr_mul;
1398 let model_mul2 = manager.model(&expr_mul2);
1399 let eval_mul2 = model_mul2.load(&dataset);
1400 let result_mul2 = eval_mul2.evaluate(&[]);
1401 assert_eq!(result_mul2[0], Complex::new(-2.0, 4.0));
1402
1403 let expr_div2 = &expr_add / &expr_add2;
1405 let model_div2 = manager.model(&expr_div2);
1406 let eval_div2 = model_div2.load(&dataset);
1407 let result_div2 = eval_div2.evaluate(&[]);
1408 assert_eq!(result_div2[0], Complex::new(7.0 / 13.0, -4.0 / 13.0));
1409
1410 let expr_neg2 = -&expr_mul2;
1412 let model_neg2 = manager.model(&expr_neg2);
1413 let eval_neg2 = model_neg2.load(&dataset);
1414 let result_neg2 = eval_neg2.evaluate(&[]);
1415 assert_eq!(result_neg2[0], Complex::new(2.0, -4.0));
1416
1417 let expr_real = aid3.real();
1419 let model_real = manager.model(&expr_real);
1420 let eval_real = model_real.load(&dataset);
1421 let result_real = eval_real.evaluate(&[]);
1422 assert_eq!(result_real[0], Complex::new(3.0, 0.0));
1423
1424 let expr_mul2_real = expr_mul2.real();
1426 let model_mul2_real = manager.model(&expr_mul2_real);
1427 let eval_mul2_real = model_mul2_real.load(&dataset);
1428 let result_mul2_real = eval_mul2_real.evaluate(&[]);
1429 assert_eq!(result_mul2_real[0], Complex::new(-2.0, 0.0));
1430
1431 let expr_imag = aid3.imag();
1433 let model_imag = manager.model(&expr_imag);
1434 let eval_imag = model_imag.load(&dataset);
1435 let result_imag = eval_imag.evaluate(&[]);
1436 assert_eq!(result_imag[0], Complex::new(4.0, 0.0));
1437
1438 let expr_mul2_imag = expr_mul2.imag();
1440 let model_mul2_imag = manager.model(&expr_mul2_imag);
1441 let eval_mul2_imag = model_mul2_imag.load(&dataset);
1442 let result_mul2_imag = eval_mul2_imag.evaluate(&[]);
1443 assert_eq!(result_mul2_imag[0], Complex::new(4.0, 0.0));
1444
1445 let expr_conj = aid3.conj();
1447 let model_conj = manager.model(&expr_conj);
1448 let eval_conj = model_conj.load(&dataset);
1449 let result_conj = eval_conj.evaluate(&[]);
1450 assert_eq!(result_conj[0], Complex::new(3.0, -4.0));
1451
1452 let expr_mul2_conj = expr_mul2.conj();
1454 let model_mul2_conj = manager.model(&expr_mul2_conj);
1455 let eval_mul2_conj = model_mul2_conj.load(&dataset);
1456 let result_mul2_conj = eval_mul2_conj.evaluate(&[]);
1457 assert_eq!(result_mul2_conj[0], Complex::new(-2.0, -4.0));
1458
1459 let expr_norm = aid1.norm_sqr();
1461 let model_norm = manager.model(&expr_norm);
1462 let eval_norm = model_norm.load(&dataset);
1463 let result_norm = eval_norm.evaluate(&[]);
1464 assert_eq!(result_norm[0], Complex::new(4.0, 0.0));
1465
1466 let expr_mul2_norm = expr_mul2.norm_sqr();
1468 let model_mul2_norm = manager.model(&expr_mul2_norm);
1469 let eval_mul2_norm = model_mul2_norm.load(&dataset);
1470 let result_mul2_norm = eval_mul2_norm.evaluate(&[]);
1471 assert_eq!(result_mul2_norm[0], Complex::new(20.0, 0.0));
1472 }
1473
1474 #[test]
1475 fn test_amplitude_activation() {
1476 let mut manager = Manager::default();
1477 let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
1478 let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));
1479
1480 let aid1 = manager.register(amp1).unwrap();
1481 let aid2 = manager.register(amp2).unwrap();
1482
1483 let dataset = Arc::new(test_dataset());
1484 let expr = &aid1 + &aid2;
1485 let model = manager.model(&expr);
1486 let evaluator = model.load(&dataset);
1487
1488 let result = evaluator.evaluate(&[]);
1490 assert_eq!(result[0], Complex::new(3.0, 0.0));
1491
1492 evaluator.deactivate("const1").unwrap();
1494 let result = evaluator.evaluate(&[]);
1495 assert_eq!(result[0], Complex::new(2.0, 0.0));
1496
1497 evaluator.isolate("const1").unwrap();
1499 let result = evaluator.evaluate(&[]);
1500 assert_eq!(result[0], Complex::new(1.0, 0.0));
1501
1502 evaluator.activate_all();
1504 let result = evaluator.evaluate(&[]);
1505 assert_eq!(result[0], Complex::new(3.0, 0.0));
1506 }
1507
1508 #[test]
1509 fn test_gradient() {
1510 let mut manager = Manager::default();
1511 let amp1 = ComplexScalar::new(
1512 "parametric_1",
1513 parameter("test_param_re_1"),
1514 parameter("test_param_im_1"),
1515 );
1516 let amp2 = ComplexScalar::new(
1517 "parametric_2",
1518 parameter("test_param_re_2"),
1519 parameter("test_param_im_2"),
1520 );
1521
1522 let aid1 = manager.register(amp1).unwrap();
1523 let aid2 = manager.register(amp2).unwrap();
1524 let dataset = Arc::new(test_dataset());
1525 let params = vec![2.0, 3.0, 4.0, 5.0];
1526
1527 let expr = &aid1 + &aid2;
1528 let model = manager.model(&expr);
1529 let evaluator = model.load(&dataset);
1530
1531 let gradient = evaluator.evaluate_gradient(¶ms);
1532
1533 assert_relative_eq!(gradient[0][0].re, 1.0);
1534 assert_relative_eq!(gradient[0][0].im, 0.0);
1535 assert_relative_eq!(gradient[0][1].re, 0.0);
1536 assert_relative_eq!(gradient[0][1].im, 1.0);
1537 assert_relative_eq!(gradient[0][2].re, 1.0);
1538 assert_relative_eq!(gradient[0][2].im, 0.0);
1539 assert_relative_eq!(gradient[0][3].re, 0.0);
1540 assert_relative_eq!(gradient[0][3].im, 1.0);
1541
1542 let expr = &aid1 - &aid2;
1543 let model = manager.model(&expr);
1544 let evaluator = model.load(&dataset);
1545
1546 let gradient = evaluator.evaluate_gradient(¶ms);
1547
1548 assert_relative_eq!(gradient[0][0].re, 1.0);
1549 assert_relative_eq!(gradient[0][0].im, 0.0);
1550 assert_relative_eq!(gradient[0][1].re, 0.0);
1551 assert_relative_eq!(gradient[0][1].im, 1.0);
1552 assert_relative_eq!(gradient[0][2].re, -1.0);
1553 assert_relative_eq!(gradient[0][2].im, 0.0);
1554 assert_relative_eq!(gradient[0][3].re, 0.0);
1555 assert_relative_eq!(gradient[0][3].im, -1.0);
1556
1557 let expr = &aid1 * &aid2;
1558 let model = manager.model(&expr);
1559 let evaluator = model.load(&dataset);
1560
1561 let gradient = evaluator.evaluate_gradient(¶ms);
1562
1563 assert_relative_eq!(gradient[0][0].re, 4.0);
1564 assert_relative_eq!(gradient[0][0].im, 5.0);
1565 assert_relative_eq!(gradient[0][1].re, -5.0);
1566 assert_relative_eq!(gradient[0][1].im, 4.0);
1567 assert_relative_eq!(gradient[0][2].re, 2.0);
1568 assert_relative_eq!(gradient[0][2].im, 3.0);
1569 assert_relative_eq!(gradient[0][3].re, -3.0);
1570 assert_relative_eq!(gradient[0][3].im, 2.0);
1571
1572 let expr = &aid1 / &aid2;
1573 let model = manager.model(&expr);
1574 let evaluator = model.load(&dataset);
1575
1576 let gradient = evaluator.evaluate_gradient(¶ms);
1577
1578 assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
1579 assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
1580 assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
1581 assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
1582 assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
1583 assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
1584 assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
1585 assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
1586
1587 let expr = -(&aid1 * &aid2);
1588 let model = manager.model(&expr);
1589 let evaluator = model.load(&dataset);
1590
1591 let gradient = evaluator.evaluate_gradient(¶ms);
1592
1593 assert_relative_eq!(gradient[0][0].re, -4.0);
1594 assert_relative_eq!(gradient[0][0].im, -5.0);
1595 assert_relative_eq!(gradient[0][1].re, 5.0);
1596 assert_relative_eq!(gradient[0][1].im, -4.0);
1597 assert_relative_eq!(gradient[0][2].re, -2.0);
1598 assert_relative_eq!(gradient[0][2].im, -3.0);
1599 assert_relative_eq!(gradient[0][3].re, 3.0);
1600 assert_relative_eq!(gradient[0][3].im, -2.0);
1601
1602 let expr = (&aid1 * &aid2).real();
1603 let model = manager.model(&expr);
1604 let evaluator = model.load(&dataset);
1605
1606 let gradient = evaluator.evaluate_gradient(¶ms);
1607
1608 assert_relative_eq!(gradient[0][0].re, 4.0);
1609 assert_relative_eq!(gradient[0][0].im, 0.0);
1610 assert_relative_eq!(gradient[0][1].re, -5.0);
1611 assert_relative_eq!(gradient[0][1].im, 0.0);
1612 assert_relative_eq!(gradient[0][2].re, 2.0);
1613 assert_relative_eq!(gradient[0][2].im, 0.0);
1614 assert_relative_eq!(gradient[0][3].re, -3.0);
1615 assert_relative_eq!(gradient[0][3].im, 0.0);
1616
1617 let expr = (&aid1 * &aid2).imag();
1618 let model = manager.model(&expr);
1619 let evaluator = model.load(&dataset);
1620
1621 let gradient = evaluator.evaluate_gradient(¶ms);
1622
1623 assert_relative_eq!(gradient[0][0].re, 5.0);
1624 assert_relative_eq!(gradient[0][0].im, 0.0);
1625 assert_relative_eq!(gradient[0][1].re, 4.0);
1626 assert_relative_eq!(gradient[0][1].im, 0.0);
1627 assert_relative_eq!(gradient[0][2].re, 3.0);
1628 assert_relative_eq!(gradient[0][2].im, 0.0);
1629 assert_relative_eq!(gradient[0][3].re, 2.0);
1630 assert_relative_eq!(gradient[0][3].im, 0.0);
1631
1632 let expr = (&aid1 * &aid2).conj();
1633 let model = manager.model(&expr);
1634 let evaluator = model.load(&dataset);
1635
1636 let gradient = evaluator.evaluate_gradient(¶ms);
1637
1638 assert_relative_eq!(gradient[0][0].re, 4.0);
1639 assert_relative_eq!(gradient[0][0].im, -5.0);
1640 assert_relative_eq!(gradient[0][1].re, -5.0);
1641 assert_relative_eq!(gradient[0][1].im, -4.0);
1642 assert_relative_eq!(gradient[0][2].re, 2.0);
1643 assert_relative_eq!(gradient[0][2].im, -3.0);
1644 assert_relative_eq!(gradient[0][3].re, -3.0);
1645 assert_relative_eq!(gradient[0][3].im, -2.0);
1646
1647 let expr = (&aid1 * &aid2).norm_sqr();
1648 let model = manager.model(&expr);
1649 let evaluator = model.load(&dataset);
1650
1651 let gradient = evaluator.evaluate_gradient(¶ms);
1652
1653 assert_relative_eq!(gradient[0][0].re, 164.0);
1654 assert_relative_eq!(gradient[0][0].im, 0.0);
1655 assert_relative_eq!(gradient[0][1].re, 246.0);
1656 assert_relative_eq!(gradient[0][1].im, 0.0);
1657 assert_relative_eq!(gradient[0][2].re, 104.0);
1658 assert_relative_eq!(gradient[0][2].im, 0.0);
1659 assert_relative_eq!(gradient[0][3].re, 130.0);
1660 assert_relative_eq!(gradient[0][3].im, 0.0);
1661 }
1662
1663 #[test]
1664 fn test_zeros_and_ones() {
1665 let mut manager = Manager::default();
1666 let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1667 let aid = manager.register(amp).unwrap();
1668 let dataset = Arc::new(test_dataset());
1669 let expr = (aid * Expression::One + Expression::Zero).norm_sqr();
1670 let model = manager.model(&expr);
1671 let evaluator = model.load(&dataset);
1672
1673 let params = vec![2.0];
1674 let value = evaluator.evaluate(¶ms);
1675 let gradient = evaluator.evaluate_gradient(¶ms);
1676
1677 assert_relative_eq!(value[0].re, 8.0);
1679 assert_relative_eq!(value[0].im, 0.0);
1680
1681 assert_relative_eq!(gradient[0][0].re, 4.0);
1683 assert_relative_eq!(gradient[0][0].im, 0.0);
1684 }
1685
1686 #[test]
1687 fn test_parameter_registration() {
1688 let mut manager = Manager::default();
1689 let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1690
1691 let aid = manager.register(amp).unwrap();
1692 let parameters = manager.parameters();
1693 let model = manager.model(&aid.into());
1694 let model_parameters = model.parameters();
1695 assert_eq!(parameters.len(), 1);
1696 assert_eq!(parameters[0], "test_param_re");
1697 assert_eq!(model_parameters.len(), 1);
1698 assert_eq!(model_parameters[0], "test_param_re");
1699 }
1700
1701 #[test]
1702 fn test_duplicate_amplitude_registration() {
1703 let mut manager = Manager::default();
1704 let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
1705 let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
1706 manager.register(amp1).unwrap();
1707 assert!(manager.register(amp2).is_err());
1708 }
1709
    /// Checks the box-drawing tree produced by `to_string()` for an expression that mixes binary
    /// operators (`+`, `-`, `×`, `÷`), unary nodes (`Re`, `Im`, `*`, `NormSqr`, `-`), the `0`/`1`
    /// constants, and amplitude leaves.
    #[test]
    fn test_tree_printing() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new(
            "parametric_1",
            parameter("test_param_re_1"),
            parameter("test_param_im_1"),
        );
        let amp2 = ComplexScalar::new(
            "parametric_2",
            parameter("test_param_re_2"),
            parameter("test_param_im_2"),
        );
        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        // `+` and `-` are left-associative, so the root is the final `+` and the earlier terms
        // nest down its first branch, as the expected tree below shows.
        let expr = &aid1.real() + &aid2.conj().imag() + Expression::One * -Expression::Zero
            - Expression::Zero / Expression::One
            + (&aid1 * &aid2).norm_sqr();
        assert_eq!(
            expr.to_string(),
            "+
├─ -
│  ├─ +
│  │  ├─ +
│  │  │  ├─ Re
│  │  │  │  └─ parametric_1(id=0)
│  │  │  └─ Im
│  │  │     └─ *
│  │  │        └─ parametric_2(id=1)
│  │  └─ ×
│  │     ├─ 1
│  │     └─ -
│  │        └─ 0
│  └─ ÷
│     ├─ 0
│     └─ 1
└─ NormSqr
   └─ ×
      ├─ parametric_1(id=0)
      └─ parametric_2(id=1)
"
        );
    }
1753}