1use std::{
2 fmt::{Debug, Display},
3 sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17 data::{Dataset, Event},
18 resources::{Cache, Parameters, Resources},
19 Float, LadduError,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
/// A value an amplitude can depend on: either a named free parameter or a fixed constant.
///
/// Usually built with [`parameter`] or [`constant`]. The `Default` value is
/// [`ParameterLike::Uninit`].
#[derive(Clone, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A free parameter, identified by its name.
    Parameter(String),
    /// A fixed value that is not varied.
    Constant(Float),
    /// Default placeholder — presumably "not yet set"; NOTE(review): confirm how
    /// `Resources::register_parameter` treats this variant.
    #[default]
    Uninit,
}
40
41pub fn parameter(name: &str) -> ParameterLike {
43 ParameterLike::Parameter(name.to_string())
44}
45
46pub fn constant(value: Float) -> ParameterLike {
48 ParameterLike::Constant(value)
49}
50
51#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65 #[allow(unused_variables)]
70 fn precompute(&self, event: &Event, cache: &mut Cache) {}
71 #[cfg(feature = "rayon")]
73 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74 dataset
75 .events
76 .par_iter()
77 .zip(resources.caches.par_iter_mut())
78 .for_each(|(event, cache)| {
79 self.precompute(event, cache);
80 })
81 }
82 #[cfg(not(feature = "rayon"))]
84 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85 dataset
86 .events
87 .iter()
88 .zip(resources.caches.iter_mut())
89 .for_each(|(event, cache)| self.precompute(event, cache))
90 }
91 fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101 fn compute_gradient(
117 &self,
118 parameters: &Parameters,
119 event: &Event,
120 cache: &Cache,
121 gradient: &mut DVector<Complex<Float>>,
122 ) {
123 self.central_difference_with_indices(
124 &Vec::from_iter(0..parameters.len()),
125 parameters,
126 event,
127 cache,
128 gradient,
129 )
130 }
131
132 fn central_difference_with_indices(
138 &self,
139 indices: &[usize],
140 parameters: &Parameters,
141 event: &Event,
142 cache: &Cache,
143 gradient: &mut DVector<Complex<Float>>,
144 ) {
145 let x = parameters.parameters.to_owned();
146 let constants = parameters.constants.to_owned();
147 let h: DVector<Float> = x
148 .iter()
149 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150 .collect::<Vec<_>>()
151 .into();
152 for i in indices {
153 let mut x_plus = x.clone();
154 let mut x_minus = x.clone();
155 x_plus[*i] += h[*i];
156 x_minus[*i] -= h[*i];
157 let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158 let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159 gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160 }
161 }
162}
163
164pub fn central_difference<F: Fn(&[Float]) -> Float>(
166 parameters: &[Float],
167 func: F,
168) -> DVector<Float> {
169 let mut gradient = DVector::zeros(parameters.len());
170 let x = parameters.to_owned();
171 let h: DVector<Float> = x
172 .iter()
173 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174 .collect::<Vec<_>>()
175 .into();
176 for i in 0..parameters.len() {
177 let mut x_plus = x.clone();
178 let mut x_minus = x.clone();
179 x_plus[i] += h[i];
180 x_minus[i] -= h[i];
181 let f_plus = func(&x_plus);
182 let f_minus = func(&x_minus);
183 gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184 }
185 gradient
186}
187
// Implement `Clone` for `Box<dyn Amplitude>` through the `DynClone` supertrait, so that
// containers of boxed amplitudes (`Manager`, `Evaluator`) can derive `Clone`.
dyn_clone::clone_trait_object!(Amplitude);
189
/// The values of every registered amplitude for a single event.
///
/// [`Expression::evaluate`] reads entry `aid.1` for each [`AmplitudeID`] `aid` it references.
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
/// The gradients of every registered amplitude for a single event.
///
/// Field `0` is the number of free parameters: the length of each gradient, also used for the
/// zero gradient of constant nodes. Field `1` holds one gradient per amplitude and is indexed
/// by `AmplitudeID.1`, like [`AmplitudeValues`].
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex<Float>>>);
197
/// Handle to a registered amplitude.
///
/// Field `0` is the amplitude's name (shown by `Display`). Field `1` is the index into the
/// per-event [`AmplitudeValues`] / [`GradientValues`] vectors — presumably the registration
/// order assigned by `Resources::register_amplitude`; NOTE(review): confirm.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(f, "{}(id={})", self.0, self.1)
206 }
207}
208
209impl From<AmplitudeID> for Expression {
210 fn from(value: AmplitudeID) -> Self {
211 Self::Amp(value)
212 }
213}
214
/// A symbolic combination of registered amplitudes, evaluated event by event.
///
/// The `Default` value is [`Expression::Zero`].
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    /// The constant `0`.
    #[default]
    Zero,
    /// The constant `1`.
    One,
    /// The value of a registered amplitude.
    Amp(AmplitudeID),
    /// Sum of two expressions.
    Add(Box<Expression>, Box<Expression>),
    /// Difference: first minus second.
    Sub(Box<Expression>, Box<Expression>),
    /// Product of two expressions.
    Mul(Box<Expression>, Box<Expression>),
    /// Quotient: first divided by second.
    Div(Box<Expression>, Box<Expression>),
    /// Negation.
    Neg(Box<Expression>),
    /// Real part, as a complex number with zero imaginary part.
    Real(Box<Expression>),
    /// Imaginary part, returned as the real component of a complex number.
    Imag(Box<Expression>),
    /// Complex conjugate.
    Conj(Box<Expression>),
    /// Squared modulus `|z|²`, as a complex number with zero imaginary part.
    NormSqr(Box<Expression>),
}
244
245impl Debug for Expression {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 self.write_tree(f, "", "", "")
248 }
249}
250
251impl Display for Expression {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 write!(f, "{:?}", self)
254 }
255}
256
// Arithmetic between two `Expression`s builds a new tree node with the operands in source
// order. The `auto_ops` `_ex` macros generate each operator for owned and borrowed operands;
// every operand is cloned into the new node.
#[rustfmt::skip]
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression {
    Expression::Add(Box::new(a.clone()), Box::new(b.clone()))
});
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression, b: &Expression| -> Expression {
    Expression::Sub(Box::new(a.clone()), Box::new(b.clone()))
});
#[rustfmt::skip]
impl_op_ex!(* |a: &Expression, b: &Expression| -> Expression {
    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
});
#[rustfmt::skip]
impl_op_ex!(/ |a: &Expression, b: &Expression| -> Expression {
    Expression::Div(Box::new(a.clone()), Box::new(b.clone()))
});
// Unary negation of an expression.
#[rustfmt::skip]
impl_op_ex!(- |a: &Expression| -> Expression {
    Expression::Neg(Box::new(a.clone()))
});
277
278#[rustfmt::skip]
279impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression {
280 Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
281});
282#[rustfmt::skip]
283impl_op_ex_commutative!(- |a: &AmplitudeID, b: &Expression| -> Expression {
284 Expression::Sub(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
285});
286#[rustfmt::skip]
287impl_op_ex_commutative!(* |a: &AmplitudeID, b: &Expression| -> Expression {
288 Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
289});
290#[rustfmt::skip]
291impl_op_ex_commutative!(/ |a: &AmplitudeID, b: &Expression| -> Expression {
292 Expression::Div(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
293});
294
// Arithmetic between two `AmplitudeID`s: each operand is wrapped in an `Expression::Amp` leaf
// and combined into a new node with the operands in source order (owned and borrowed operand
// forms generated by the `auto_ops` `_ex` macros).
#[rustfmt::skip]
impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Add(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone()))
    )
});
#[rustfmt::skip]
impl_op_ex!(- |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Sub(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone()))
    )
});
#[rustfmt::skip]
impl_op_ex!(* |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Mul(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone())),
    )
});
#[rustfmt::skip]
impl_op_ex!(/ |a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Div(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone())),
    )
});
// Unary negation of a single amplitude.
#[rustfmt::skip]
impl_op_ex!(- |a: &AmplitudeID| -> Expression {
    Expression::Neg(
        Box::new(Expression::Amp(a.clone())),
    )
});
329
330impl AmplitudeID {
331 pub fn real(&self) -> Expression {
333 Expression::Real(Box::new(Expression::Amp(self.clone())))
334 }
335 pub fn imag(&self) -> Expression {
337 Expression::Imag(Box::new(Expression::Amp(self.clone())))
338 }
339 pub fn conj(&self) -> Expression {
341 Expression::Conj(Box::new(Expression::Amp(self.clone())))
342 }
343 pub fn norm_sqr(&self) -> Expression {
345 Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
346 }
347}
348
349impl Expression {
350 pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
355 match self {
356 Expression::Amp(aid) => amplitude_values.0[aid.1],
357 Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
358 Expression::Sub(a, b) => a.evaluate(amplitude_values) - b.evaluate(amplitude_values),
359 Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
360 Expression::Div(a, b) => a.evaluate(amplitude_values) / b.evaluate(amplitude_values),
361 Expression::Neg(a) => -a.evaluate(amplitude_values),
362 Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
363 Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
364 Expression::Conj(a) => a.evaluate(amplitude_values).conj(),
365 Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
366 Expression::Zero => Complex::ZERO,
367 Expression::One => Complex::ONE,
368 }
369 }
370 pub fn evaluate_gradient(
375 &self,
376 amplitude_values: &AmplitudeValues,
377 gradient_values: &GradientValues,
378 ) -> DVector<Complex<Float>> {
379 match self {
380 Expression::Amp(aid) => gradient_values.1[aid.1].clone(),
381 Expression::Add(a, b) => {
382 a.evaluate_gradient(amplitude_values, gradient_values)
383 + b.evaluate_gradient(amplitude_values, gradient_values)
384 }
385 Expression::Sub(a, b) => {
386 a.evaluate_gradient(amplitude_values, gradient_values)
387 - b.evaluate_gradient(amplitude_values, gradient_values)
388 }
389 Expression::Mul(a, b) => {
390 let f_a = a.evaluate(amplitude_values);
391 let f_b = b.evaluate(amplitude_values);
392 b.evaluate_gradient(amplitude_values, gradient_values)
393 .map(|g| g * f_a)
394 + a.evaluate_gradient(amplitude_values, gradient_values)
395 .map(|g| g * f_b)
396 }
397 Expression::Div(a, b) => {
398 let f_a = a.evaluate(amplitude_values);
399 let f_b = b.evaluate(amplitude_values);
400 (a.evaluate_gradient(amplitude_values, gradient_values)
401 .map(|g| g * f_b)
402 - b.evaluate_gradient(amplitude_values, gradient_values)
403 .map(|g| g * f_a))
404 / (f_b * f_b)
405 }
406 Expression::Neg(a) => -a.evaluate_gradient(amplitude_values, gradient_values),
407 Expression::Real(a) => a
408 .evaluate_gradient(amplitude_values, gradient_values)
409 .map(|g| Complex::new(g.re, 0.0)),
410 Expression::Imag(a) => a
411 .evaluate_gradient(amplitude_values, gradient_values)
412 .map(|g| Complex::new(g.im, 0.0)),
413 Expression::Conj(a) => a
414 .evaluate_gradient(amplitude_values, gradient_values)
415 .map(|g| g.conj()),
416 Expression::NormSqr(a) => {
417 let conj_f_a = a.evaluate(amplitude_values).conjugate();
418 a.evaluate_gradient(amplitude_values, gradient_values)
419 .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
420 }
421 Expression::Zero | Expression::One => DVector::zeros(gradient_values.0),
422 }
423 }
424 pub fn real(&self) -> Self {
426 Self::Real(Box::new(self.clone()))
427 }
428 pub fn imag(&self) -> Self {
430 Self::Imag(Box::new(self.clone()))
431 }
432 pub fn conj(&self) -> Self {
434 Self::Conj(Box::new(self.clone()))
435 }
436 pub fn norm_sqr(&self) -> Self {
438 Self::NormSqr(Box::new(self.clone()))
439 }
440
441 fn write_tree(
443 &self,
444 f: &mut std::fmt::Formatter<'_>,
445 parent_prefix: &str,
446 immediate_prefix: &str,
447 parent_suffix: &str,
448 ) -> std::fmt::Result {
449 let display_string = match self {
450 Self::Amp(aid) => aid.to_string(),
451 Self::Add(_, _) => "+".to_string(),
452 Self::Sub(_, _) => "-".to_string(),
453 Self::Mul(_, _) => "×".to_string(),
454 Self::Div(_, _) => "÷".to_string(),
455 Self::Neg(_) => "-".to_string(),
456 Self::Real(_) => "Re".to_string(),
457 Self::Imag(_) => "Im".to_string(),
458 Self::Conj(_) => "*".to_string(),
459 Self::NormSqr(_) => "NormSqr".to_string(),
460 Self::Zero => "0".to_string(),
461 Self::One => "1".to_string(),
462 };
463 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
464 match self {
465 Self::Amp(_) | Self::Zero | Self::One => {}
466 Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) | Self::Div(a, b) => {
467 let terms = [a, b];
468 let mut it = terms.iter().peekable();
469 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
470 while let Some(child) = it.next() {
471 match it.peek() {
472 Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│ "),
473 None => child.write_tree(f, &child_prefix, "└─ ", " "),
474 }?;
475 }
476 }
477 Self::Neg(a) | Self::Real(a) | Self::Imag(a) | Self::Conj(a) | Self::NormSqr(a) => {
478 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
479 a.write_tree(f, &child_prefix, "└─ ", " ")?;
480 }
481 }
482 Ok(())
483 }
484}
485
/// Collects registered amplitudes together with the [`Resources`] they registered into
/// (parameter names, constants, and amplitude registrations).
///
/// Build one with `Manager::default()`, add amplitudes with [`Manager::register`], then call
/// [`Manager::model`] to pair the result with an [`Expression`].
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    // Registered amplitudes, in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Shared registry the amplitudes register their parameters and names into.
    resources: Resources,
}
493
494impl Manager {
495 pub fn parameters(&self) -> Vec<String> {
497 self.resources.parameters.iter().cloned().collect()
498 }
499 pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
507 let mut amp = amplitude.clone();
508 let aid = amp.register(&mut self.resources)?;
509 self.amplitudes.push(amp);
510 Ok(aid)
511 }
512 pub fn model(&self, expression: &Expression) -> Model {
514 Model {
515 manager: self.clone(),
516 expression: expression.clone(),
517 }
518 }
519}
520
/// A [`Manager`] snapshot paired with the [`Expression`] to evaluate.
///
/// [`Model::load`] binds it to a [`Dataset`], producing an [`Evaluator`].
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    pub(crate) manager: Manager,
    pub(crate) expression: Expression,
}
531
532impl Model {
533 pub fn parameters(&self) -> Vec<String> {
535 self.manager.parameters()
536 }
537 pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
541 let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
542 loaded_resources.write().reserve_cache(dataset.n_events());
543 for amplitude in &self.manager.amplitudes {
544 amplitude.precompute_all(dataset, &mut loaded_resources.write());
545 }
546 Evaluator {
547 amplitudes: self.manager.amplitudes.clone(),
548 resources: loaded_resources.clone(),
549 dataset: dataset.clone(),
550 expression: self.expression.clone(),
551 }
552 }
553}
554
/// A [`Model`] bound to a specific [`Dataset`], ready to evaluate the expression (and its
/// gradient) for every event.
#[derive(Clone)]
pub struct Evaluator {
    /// The registered amplitudes, in the order the expression indexes their values.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Parameter names, constants, activation flags, and per-event caches, behind a lock so
    /// the activation methods can change them through `&self`. Cloning the evaluator shares
    /// this state (the `Arc` is cloned) rather than copying it.
    pub resources: Arc<RwLock<Resources>>,
    /// The events to evaluate.
    pub dataset: Arc<Dataset>,
    /// The expression combining the amplitude values.
    pub expression: Expression,
}
571
572impl Evaluator {
573 pub fn parameters(&self) -> Vec<String> {
576 self.resources.read().parameters.iter().cloned().collect()
577 }
578 pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
580 self.resources.write().activate(name)
581 }
582 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
584 self.resources.write().activate_many(names)
585 }
586 pub fn activate_all(&self) {
588 self.resources.write().activate_all();
589 }
590 pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
592 self.resources.write().deactivate(name)
593 }
594 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
596 self.resources.write().deactivate_many(names)
597 }
598 pub fn deactivate_all(&self) {
600 self.resources.write().deactivate_all();
601 }
602 pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
604 self.resources.write().isolate(name)
605 }
606 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
608 self.resources.write().isolate_many(names)
609 }
610
611 pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
619 let resources = self.resources.read();
620 let parameters = Parameters::new(parameters, &resources.constants);
621 #[cfg(feature = "rayon")]
622 {
623 let amplitude_values_vec: Vec<AmplitudeValues> = self
624 .dataset
625 .events
626 .par_iter()
627 .zip(resources.caches.par_iter())
628 .map(|(event, cache)| {
629 AmplitudeValues(
630 self.amplitudes
631 .iter()
632 .zip(resources.active.iter())
633 .map(|(amp, active)| {
634 if *active {
635 amp.compute(¶meters, event, cache)
636 } else {
637 Complex::new(0.0, 0.0)
638 }
639 })
640 .collect(),
641 )
642 })
643 .collect();
644 amplitude_values_vec
645 .par_iter()
646 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
647 .collect()
648 }
649 #[cfg(not(feature = "rayon"))]
650 {
651 let amplitude_values_vec: Vec<AmplitudeValues> = self
652 .dataset
653 .events
654 .iter()
655 .zip(resources.caches.iter())
656 .map(|(event, cache)| {
657 AmplitudeValues(
658 self.amplitudes
659 .iter()
660 .zip(resources.active.iter())
661 .map(|(amp, active)| {
662 if *active {
663 amp.compute(¶meters, event, cache)
664 } else {
665 Complex::new(0.0, 0.0)
666 }
667 })
668 .collect(),
669 )
670 })
671 .collect();
672 amplitude_values_vec
673 .iter()
674 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
675 .collect()
676 }
677 }
678
679 #[cfg(feature = "mpi")]
687 fn evaluate_mpi(
688 &self,
689 parameters: &[Float],
690 world: &SimpleCommunicator,
691 ) -> Vec<Complex<Float>> {
692 let local_evaluation = self.evaluate_local(parameters);
693 let n_events = self.dataset.n_events();
694 let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
695 let (counts, displs) = world.get_counts_displs(n_events);
696 {
697 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
698 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
699 }
700 buffer
701 }
702
703 pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
706 #[cfg(feature = "mpi")]
707 {
708 if let Some(world) = crate::mpi::get_world() {
709 return self.evaluate_mpi(parameters, &world);
710 }
711 }
712 self.evaluate_local(parameters)
713 }
714
715 pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
723 let resources = self.resources.read();
724 let parameters = Parameters::new(parameters, &resources.constants);
725 #[cfg(feature = "rayon")]
726 {
727 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
728 .dataset
729 .events
730 .par_iter()
731 .zip(resources.caches.par_iter())
732 .map(|(event, cache)| {
733 let mut gradient_values =
734 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
735 self.amplitudes
736 .iter()
737 .zip(resources.active.iter())
738 .zip(gradient_values.iter_mut())
739 .for_each(|((amp, active), grad)| {
740 if *active {
741 amp.compute_gradient(¶meters, event, cache, grad)
742 }
743 });
744 (
745 AmplitudeValues(
746 self.amplitudes
747 .iter()
748 .zip(resources.active.iter())
749 .map(|(amp, active)| {
750 if *active {
751 amp.compute(¶meters, event, cache)
752 } else {
753 Complex::new(0.0, 0.0)
754 }
755 })
756 .collect(),
757 ),
758 GradientValues(parameters.len(), gradient_values),
759 )
760 })
761 .collect();
762 amplitude_values_and_gradient_vec
763 .par_iter()
764 .map(|(amplitude_values, gradient_values)| {
765 self.expression
766 .evaluate_gradient(amplitude_values, gradient_values)
767 })
768 .collect()
769 }
770 #[cfg(not(feature = "rayon"))]
771 {
772 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
773 .dataset
774 .events
775 .iter()
776 .zip(resources.caches.iter())
777 .map(|(event, cache)| {
778 let mut gradient_values =
779 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
780 self.amplitudes
781 .iter()
782 .zip(resources.active.iter())
783 .zip(gradient_values.iter_mut())
784 .for_each(|((amp, active), grad)| {
785 if *active {
786 amp.compute_gradient(¶meters, event, cache, grad)
787 }
788 });
789 (
790 AmplitudeValues(
791 self.amplitudes
792 .iter()
793 .zip(resources.active.iter())
794 .map(|(amp, active)| {
795 if *active {
796 amp.compute(¶meters, event, cache)
797 } else {
798 Complex::new(0.0, 0.0)
799 }
800 })
801 .collect(),
802 ),
803 GradientValues(parameters.len(), gradient_values),
804 )
805 })
806 .collect();
807
808 amplitude_values_and_gradient_vec
809 .iter()
810 .map(|(amplitude_values, gradient_values)| {
811 self.expression
812 .evaluate_gradient(amplitude_values, gradient_values)
813 })
814 .collect()
815 }
816 }
817
818 #[cfg(feature = "mpi")]
826 fn evaluate_gradient_mpi(
827 &self,
828 parameters: &[Float],
829 world: &SimpleCommunicator,
830 ) -> Vec<DVector<Complex<Float>>> {
831 let flattened_local_evaluation = self
832 .evaluate_gradient_local(parameters)
833 .iter()
834 .flat_map(|g| g.data.as_vec().to_vec())
835 .collect::<Vec<Complex<Float>>>();
836 let n_events = self.dataset.n_events();
837 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
838 let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
839 let mut partitioned_flattened_result_buffer =
840 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
841 world.all_gather_varcount_into(
842 &flattened_local_evaluation,
843 &mut partitioned_flattened_result_buffer,
844 );
845 flattened_result_buffer
846 .chunks(parameters.len())
847 .map(DVector::from_row_slice)
848 .collect()
849 }
850
851 pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
854 #[cfg(feature = "mpi")]
855 {
856 if let Some(world) = crate::mpi::get_world() {
857 return self.evaluate_gradient_mpi(parameters, &world);
858 }
859 }
860 self.evaluate_gradient_local(parameters)
861 }
862}
863
864#[cfg(test)]
865mod tests {
866 use crate::data::{test_dataset, test_event};
867
868 use super::*;
869 use crate::{
870 data::Event,
871 resources::{Cache, ParameterID, Parameters, Resources},
872 Complex, DVector, Float, LadduError,
873 };
874 use approx::assert_relative_eq;
875 use serde::{Deserialize, Serialize};
876
    /// Test amplitude whose value is `re + i·im`, where each part is a free parameter or a
    /// constant.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        // Set by `register`; identifies `re` within `Parameters`.
        pid_re: ParameterID,
        im: ParameterLike,
        // Set by `register`; identifies `im` within `Parameters`.
        pid_im: ParameterID,
    }
885
886 impl ComplexScalar {
887 pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
888 Self {
889 name: name.to_string(),
890 re,
891 pid_re: Default::default(),
892 im,
893 pid_im: Default::default(),
894 }
895 .into()
896 }
897 }
898
899 #[typetag::serde]
900 impl Amplitude for ComplexScalar {
901 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
902 self.pid_re = resources.register_parameter(&self.re);
903 self.pid_im = resources.register_parameter(&self.im);
904 resources.register_amplitude(&self.name)
905 }
906
907 fn compute(
908 &self,
909 parameters: &Parameters,
910 _event: &Event,
911 _cache: &Cache,
912 ) -> Complex<Float> {
913 Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
914 }
915
916 fn compute_gradient(
917 &self,
918 _parameters: &Parameters,
919 _event: &Event,
920 _cache: &Cache,
921 gradient: &mut DVector<Complex<Float>>,
922 ) {
923 if let ParameterID::Parameter(ind) = self.pid_re {
924 gradient[ind] = Complex::ONE;
925 }
926 if let ParameterID::Parameter(ind) = self.pid_im {
927 gradient[ind] = Complex::I;
928 }
929 }
930 }
931
932 #[test]
933 fn test_constant_amplitude() {
934 let mut manager = Manager::default();
935 let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
936 let aid = manager.register(amp).unwrap();
937 let dataset = Arc::new(Dataset {
938 events: vec![Arc::new(test_event())],
939 });
940 let expr = Expression::Amp(aid);
941 let model = manager.model(&expr);
942 let evaluator = model.load(&dataset);
943 let result = evaluator.evaluate(&[]);
944 assert_eq!(result[0], Complex::new(2.0, 3.0));
945 }
946
947 #[test]
948 fn test_parametric_amplitude() {
949 let mut manager = Manager::default();
950 let amp = ComplexScalar::new(
951 "parametric",
952 parameter("test_param_re"),
953 parameter("test_param_im"),
954 );
955 let aid = manager.register(amp).unwrap();
956 let dataset = Arc::new(test_dataset());
957 let expr = Expression::Amp(aid);
958 let model = manager.model(&expr);
959 let evaluator = model.load(&dataset);
960 let result = evaluator.evaluate(&[2.0, 3.0]);
961 assert_eq!(result[0], Complex::new(2.0, 3.0));
962 }
963
964 #[test]
965 fn test_expression_operations() {
966 let mut manager = Manager::default();
967 let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
968 let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
969 let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));
970
971 let aid1 = manager.register(amp1).unwrap();
972 let aid2 = manager.register(amp2).unwrap();
973 let aid3 = manager.register(amp3).unwrap();
974
975 let dataset = Arc::new(test_dataset());
976
977 let expr_add = &aid1 + &aid2;
979 let model_add = manager.model(&expr_add);
980 let eval_add = model_add.load(&dataset);
981 let result_add = eval_add.evaluate(&[]);
982 assert_eq!(result_add[0], Complex::new(2.0, 1.0));
983
984 let expr_sub = &aid1 - &aid2;
986 let model_sub = manager.model(&expr_sub);
987 let eval_sub = model_sub.load(&dataset);
988 let result_sub = eval_sub.evaluate(&[]);
989 assert_eq!(result_sub[0], Complex::new(2.0, -1.0));
990
991 let expr_mul = &aid1 * &aid2;
993 let model_mul = manager.model(&expr_mul);
994 let eval_mul = model_mul.load(&dataset);
995 let result_mul = eval_mul.evaluate(&[]);
996 assert_eq!(result_mul[0], Complex::new(0.0, 2.0));
997
998 let expr_div = &aid1 / &aid3;
1000 let model_div = manager.model(&expr_div);
1001 let eval_div = model_div.load(&dataset);
1002 let result_div = eval_div.evaluate(&[]);
1003 assert_eq!(result_div[0], Complex::new(6.0 / 25.0, -8.0 / 25.0));
1004
1005 let expr_neg = -&aid3;
1007 let model_neg = manager.model(&expr_neg);
1008 let eval_neg = model_neg.load(&dataset);
1009 let result_neg = eval_neg.evaluate(&[]);
1010 assert_eq!(result_neg[0], Complex::new(-3.0, -4.0));
1011
1012 let expr_add2 = &expr_add + &expr_mul;
1014 let model_add2 = manager.model(&expr_add2);
1015 let eval_add2 = model_add2.load(&dataset);
1016 let result_add2 = eval_add2.evaluate(&[]);
1017 assert_eq!(result_add2[0], Complex::new(2.0, 3.0));
1018
1019 let expr_sub2 = &expr_add - &expr_mul;
1021 let model_sub2 = manager.model(&expr_sub2);
1022 let eval_sub2 = model_sub2.load(&dataset);
1023 let result_sub2 = eval_sub2.evaluate(&[]);
1024 assert_eq!(result_sub2[0], Complex::new(2.0, -1.0));
1025
1026 let expr_mul2 = &expr_add * &expr_mul;
1028 let model_mul2 = manager.model(&expr_mul2);
1029 let eval_mul2 = model_mul2.load(&dataset);
1030 let result_mul2 = eval_mul2.evaluate(&[]);
1031 assert_eq!(result_mul2[0], Complex::new(-2.0, 4.0));
1032
1033 let expr_div2 = &expr_add / &expr_add2;
1035 let model_div2 = manager.model(&expr_div2);
1036 let eval_div2 = model_div2.load(&dataset);
1037 let result_div2 = eval_div2.evaluate(&[]);
1038 assert_eq!(result_div2[0], Complex::new(7.0 / 13.0, -4.0 / 13.0));
1039
1040 let expr_neg2 = -&expr_mul2;
1042 let model_neg2 = manager.model(&expr_neg2);
1043 let eval_neg2 = model_neg2.load(&dataset);
1044 let result_neg2 = eval_neg2.evaluate(&[]);
1045 assert_eq!(result_neg2[0], Complex::new(2.0, -4.0));
1046
1047 let expr_real = aid3.real();
1049 let model_real = manager.model(&expr_real);
1050 let eval_real = model_real.load(&dataset);
1051 let result_real = eval_real.evaluate(&[]);
1052 assert_eq!(result_real[0], Complex::new(3.0, 0.0));
1053
1054 let expr_mul2_real = expr_mul2.real();
1056 let model_mul2_real = manager.model(&expr_mul2_real);
1057 let eval_mul2_real = model_mul2_real.load(&dataset);
1058 let result_mul2_real = eval_mul2_real.evaluate(&[]);
1059 assert_eq!(result_mul2_real[0], Complex::new(-2.0, 0.0));
1060
1061 let expr_imag = aid3.imag();
1063 let model_imag = manager.model(&expr_imag);
1064 let eval_imag = model_imag.load(&dataset);
1065 let result_imag = eval_imag.evaluate(&[]);
1066 assert_eq!(result_imag[0], Complex::new(4.0, 0.0));
1067
1068 let expr_mul2_imag = expr_mul2.imag();
1070 let model_mul2_imag = manager.model(&expr_mul2_imag);
1071 let eval_mul2_imag = model_mul2_imag.load(&dataset);
1072 let result_mul2_imag = eval_mul2_imag.evaluate(&[]);
1073 assert_eq!(result_mul2_imag[0], Complex::new(4.0, 0.0));
1074
1075 let expr_conj = aid3.conj();
1077 let model_conj = manager.model(&expr_conj);
1078 let eval_conj = model_conj.load(&dataset);
1079 let result_conj = eval_conj.evaluate(&[]);
1080 assert_eq!(result_conj[0], Complex::new(3.0, -4.0));
1081
1082 let expr_mul2_conj = expr_mul2.conj();
1084 let model_mul2_conj = manager.model(&expr_mul2_conj);
1085 let eval_mul2_conj = model_mul2_conj.load(&dataset);
1086 let result_mul2_conj = eval_mul2_conj.evaluate(&[]);
1087 assert_eq!(result_mul2_conj[0], Complex::new(-2.0, -4.0));
1088
1089 let expr_norm = aid1.norm_sqr();
1091 let model_norm = manager.model(&expr_norm);
1092 let eval_norm = model_norm.load(&dataset);
1093 let result_norm = eval_norm.evaluate(&[]);
1094 assert_eq!(result_norm[0], Complex::new(4.0, 0.0));
1095
1096 let expr_mul2_norm = expr_mul2.norm_sqr();
1098 let model_mul2_norm = manager.model(&expr_mul2_norm);
1099 let eval_mul2_norm = model_mul2_norm.load(&dataset);
1100 let result_mul2_norm = eval_mul2_norm.evaluate(&[]);
1101 assert_eq!(result_mul2_norm[0], Complex::new(20.0, 0.0));
1102 }
1103
1104 #[test]
1105 fn test_amplitude_activation() {
1106 let mut manager = Manager::default();
1107 let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
1108 let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));
1109
1110 let aid1 = manager.register(amp1).unwrap();
1111 let aid2 = manager.register(amp2).unwrap();
1112
1113 let dataset = Arc::new(test_dataset());
1114 let expr = &aid1 + &aid2;
1115 let model = manager.model(&expr);
1116 let evaluator = model.load(&dataset);
1117
1118 let result = evaluator.evaluate(&[]);
1120 assert_eq!(result[0], Complex::new(3.0, 0.0));
1121
1122 evaluator.deactivate("const1").unwrap();
1124 let result = evaluator.evaluate(&[]);
1125 assert_eq!(result[0], Complex::new(2.0, 0.0));
1126
1127 evaluator.isolate("const1").unwrap();
1129 let result = evaluator.evaluate(&[]);
1130 assert_eq!(result[0], Complex::new(1.0, 0.0));
1131
1132 evaluator.activate_all();
1134 let result = evaluator.evaluate(&[]);
1135 assert_eq!(result[0], Complex::new(3.0, 0.0));
1136 }
1137
1138 #[test]
1139 fn test_gradient() {
1140 let mut manager = Manager::default();
1141 let amp1 = ComplexScalar::new(
1142 "parametric_1",
1143 parameter("test_param_re_1"),
1144 parameter("test_param_im_1"),
1145 );
1146 let amp2 = ComplexScalar::new(
1147 "parametric_2",
1148 parameter("test_param_re_2"),
1149 parameter("test_param_im_2"),
1150 );
1151
1152 let aid1 = manager.register(amp1).unwrap();
1153 let aid2 = manager.register(amp2).unwrap();
1154 let dataset = Arc::new(test_dataset());
1155 let params = vec![2.0, 3.0, 4.0, 5.0];
1156
1157 let expr = &aid1 + &aid2;
1158 let model = manager.model(&expr);
1159 let evaluator = model.load(&dataset);
1160
1161 let gradient = evaluator.evaluate_gradient(¶ms);
1162
1163 assert_relative_eq!(gradient[0][0].re, 1.0);
1164 assert_relative_eq!(gradient[0][0].im, 0.0);
1165 assert_relative_eq!(gradient[0][1].re, 0.0);
1166 assert_relative_eq!(gradient[0][1].im, 1.0);
1167 assert_relative_eq!(gradient[0][2].re, 1.0);
1168 assert_relative_eq!(gradient[0][2].im, 0.0);
1169 assert_relative_eq!(gradient[0][3].re, 0.0);
1170 assert_relative_eq!(gradient[0][3].im, 1.0);
1171
1172 let expr = &aid1 - &aid2;
1173 let model = manager.model(&expr);
1174 let evaluator = model.load(&dataset);
1175
1176 let gradient = evaluator.evaluate_gradient(¶ms);
1177
1178 assert_relative_eq!(gradient[0][0].re, 1.0);
1179 assert_relative_eq!(gradient[0][0].im, 0.0);
1180 assert_relative_eq!(gradient[0][1].re, 0.0);
1181 assert_relative_eq!(gradient[0][1].im, 1.0);
1182 assert_relative_eq!(gradient[0][2].re, -1.0);
1183 assert_relative_eq!(gradient[0][2].im, 0.0);
1184 assert_relative_eq!(gradient[0][3].re, 0.0);
1185 assert_relative_eq!(gradient[0][3].im, -1.0);
1186
1187 let expr = &aid1 * &aid2;
1188 let model = manager.model(&expr);
1189 let evaluator = model.load(&dataset);
1190
1191 let gradient = evaluator.evaluate_gradient(¶ms);
1192
1193 assert_relative_eq!(gradient[0][0].re, 4.0);
1194 assert_relative_eq!(gradient[0][0].im, 5.0);
1195 assert_relative_eq!(gradient[0][1].re, -5.0);
1196 assert_relative_eq!(gradient[0][1].im, 4.0);
1197 assert_relative_eq!(gradient[0][2].re, 2.0);
1198 assert_relative_eq!(gradient[0][2].im, 3.0);
1199 assert_relative_eq!(gradient[0][3].re, -3.0);
1200 assert_relative_eq!(gradient[0][3].im, 2.0);
1201
1202 let expr = &aid1 / &aid2;
1203 let model = manager.model(&expr);
1204 let evaluator = model.load(&dataset);
1205
1206 let gradient = evaluator.evaluate_gradient(¶ms);
1207
1208 assert_relative_eq!(gradient[0][0].re, 4.0 / 41.0);
1209 assert_relative_eq!(gradient[0][0].im, -5.0 / 41.0);
1210 assert_relative_eq!(gradient[0][1].re, 5.0 / 41.0);
1211 assert_relative_eq!(gradient[0][1].im, 4.0 / 41.0);
1212 assert_relative_eq!(gradient[0][2].re, -102.0 / 1681.0);
1213 assert_relative_eq!(gradient[0][2].im, 107.0 / 1681.0);
1214 assert_relative_eq!(gradient[0][3].re, -107.0 / 1681.0);
1215 assert_relative_eq!(gradient[0][3].im, -102.0 / 1681.0);
1216
1217 let expr = -(&aid1 * &aid2);
1218 let model = manager.model(&expr);
1219 let evaluator = model.load(&dataset);
1220
1221 let gradient = evaluator.evaluate_gradient(¶ms);
1222
1223 assert_relative_eq!(gradient[0][0].re, -4.0);
1224 assert_relative_eq!(gradient[0][0].im, -5.0);
1225 assert_relative_eq!(gradient[0][1].re, 5.0);
1226 assert_relative_eq!(gradient[0][1].im, -4.0);
1227 assert_relative_eq!(gradient[0][2].re, -2.0);
1228 assert_relative_eq!(gradient[0][2].im, -3.0);
1229 assert_relative_eq!(gradient[0][3].re, 3.0);
1230 assert_relative_eq!(gradient[0][3].im, -2.0);
1231
1232 let expr = (&aid1 * &aid2).real();
1233 let model = manager.model(&expr);
1234 let evaluator = model.load(&dataset);
1235
1236 let gradient = evaluator.evaluate_gradient(¶ms);
1237
1238 assert_relative_eq!(gradient[0][0].re, 4.0);
1239 assert_relative_eq!(gradient[0][0].im, 0.0);
1240 assert_relative_eq!(gradient[0][1].re, -5.0);
1241 assert_relative_eq!(gradient[0][1].im, 0.0);
1242 assert_relative_eq!(gradient[0][2].re, 2.0);
1243 assert_relative_eq!(gradient[0][2].im, 0.0);
1244 assert_relative_eq!(gradient[0][3].re, -3.0);
1245 assert_relative_eq!(gradient[0][3].im, 0.0);
1246
1247 let expr = (&aid1 * &aid2).imag();
1248 let model = manager.model(&expr);
1249 let evaluator = model.load(&dataset);
1250
1251 let gradient = evaluator.evaluate_gradient(¶ms);
1252
1253 assert_relative_eq!(gradient[0][0].re, 5.0);
1254 assert_relative_eq!(gradient[0][0].im, 0.0);
1255 assert_relative_eq!(gradient[0][1].re, 4.0);
1256 assert_relative_eq!(gradient[0][1].im, 0.0);
1257 assert_relative_eq!(gradient[0][2].re, 3.0);
1258 assert_relative_eq!(gradient[0][2].im, 0.0);
1259 assert_relative_eq!(gradient[0][3].re, 2.0);
1260 assert_relative_eq!(gradient[0][3].im, 0.0);
1261
1262 let expr = (&aid1 * &aid2).conj();
1263 let model = manager.model(&expr);
1264 let evaluator = model.load(&dataset);
1265
1266 let gradient = evaluator.evaluate_gradient(¶ms);
1267
1268 assert_relative_eq!(gradient[0][0].re, 4.0);
1269 assert_relative_eq!(gradient[0][0].im, -5.0);
1270 assert_relative_eq!(gradient[0][1].re, -5.0);
1271 assert_relative_eq!(gradient[0][1].im, -4.0);
1272 assert_relative_eq!(gradient[0][2].re, 2.0);
1273 assert_relative_eq!(gradient[0][2].im, -3.0);
1274 assert_relative_eq!(gradient[0][3].re, -3.0);
1275 assert_relative_eq!(gradient[0][3].im, -2.0);
1276
1277 let expr = (&aid1 * &aid2).norm_sqr();
1278 let model = manager.model(&expr);
1279 let evaluator = model.load(&dataset);
1280
1281 let gradient = evaluator.evaluate_gradient(¶ms);
1282
1283 assert_relative_eq!(gradient[0][0].re, 164.0);
1284 assert_relative_eq!(gradient[0][0].im, 0.0);
1285 assert_relative_eq!(gradient[0][1].re, 246.0);
1286 assert_relative_eq!(gradient[0][1].im, 0.0);
1287 assert_relative_eq!(gradient[0][2].re, 104.0);
1288 assert_relative_eq!(gradient[0][2].im, 0.0);
1289 assert_relative_eq!(gradient[0][3].re, 130.0);
1290 assert_relative_eq!(gradient[0][3].im, 0.0);
1291 }
1292
1293 #[test]
1294 fn test_zeros_and_ones() {
1295 let mut manager = Manager::default();
1296 let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1297 let aid = manager.register(amp).unwrap();
1298 let dataset = Arc::new(test_dataset());
1299 let expr = (aid * Expression::One + Expression::Zero).norm_sqr();
1300 let model = manager.model(&expr);
1301 let evaluator = model.load(&dataset);
1302
1303 let params = vec![2.0];
1304 let value = evaluator.evaluate(¶ms);
1305 let gradient = evaluator.evaluate_gradient(¶ms);
1306
1307 assert_relative_eq!(value[0].re, 8.0);
1309 assert_relative_eq!(value[0].im, 0.0);
1310
1311 assert_relative_eq!(gradient[0][0].re, 4.0);
1313 assert_relative_eq!(gradient[0][0].im, 0.0);
1314 }
1315
1316 #[test]
1317 fn test_parameter_registration() {
1318 let mut manager = Manager::default();
1319 let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
1320
1321 let aid = manager.register(amp).unwrap();
1322 let parameters = manager.parameters();
1323 let model = manager.model(&aid.into());
1324 let model_parameters = model.parameters();
1325 assert_eq!(parameters.len(), 1);
1326 assert_eq!(parameters[0], "test_param_re");
1327 assert_eq!(model_parameters.len(), 1);
1328 assert_eq!(model_parameters[0], "test_param_re");
1329 }
1330
1331 #[test]
1332 fn test_duplicate_amplitude_registration() {
1333 let mut manager = Manager::default();
1334 let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
1335 let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
1336 manager.register(amp1).unwrap();
1337 assert!(manager.register(amp2).is_err());
1338 }
1339
    /// Checks the `Display` rendering of an expression tree that combines amplitudes with
    /// every unary and binary operator and with the `One`/`Zero` constants.
    #[test]
    fn test_tree_printing() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new(
            "parametric_1",
            parameter("test_param_re_1"),
            parameter("test_param_im_1"),
        );
        let amp2 = ComplexScalar::new(
            "parametric_2",
            parameter("test_param_re_2"),
            parameter("test_param_im_2"),
        );
        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let expr = &aid1.real() + &aid2.conj().imag() + Expression::One * -Expression::Zero
            - Expression::Zero / Expression::One
            + (&aid1 * &aid2).norm_sqr();
        // In the expected rendering, binary operators print as `+`, `-`, `×`, `÷`; unary ones
        // as `Re`, `Im`, `*` (from `.conj()`), `-` (negation) and `NormSqr`; amplitudes as
        // `name(id=N)`. Leading whitespace inside the literal below is part of the expected
        // string, so its continuation lines must not be re-indented.
        assert_eq!(
            expr.to_string(),
            "+
├─ -
│ ├─ +
│ │ ├─ +
│ │ │ ├─ Re
│ │ │ │ └─ parametric_1(id=0)
│ │ │ └─ Im
│ │ │ └─ *
│ │ │ └─ parametric_2(id=1)
│ │ └─ ×
│ │ ├─ 1
│ │ └─ -
│ │ └─ 0
│ └─ ÷
│ ├─ 0
│ └─ 1
└─ NormSqr
 └─ ×
 ├─ parametric_1(id=0)
 └─ parametric_2(id=1)
"
        );
    }
1383}