1use std::{
2 fmt::{Debug, Display},
3 sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17 data::{Dataset, Event},
18 resources::{Cache, Parameters, Resources},
19 Float, LadduError,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
/// A value an amplitude can be configured with: a named parameter, a fixed constant, or an
/// uninitialized placeholder.
#[derive(Clone, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A parameter referred to by name.
    Parameter(String),
    /// A fixed numeric value.
    Constant(Float),
    /// Placeholder variant; this is the value produced by [`Default`].
    #[default]
    Uninit,
}
40
41pub fn parameter(name: &str) -> ParameterLike {
43 ParameterLike::Parameter(name.to_string())
44}
45
/// Builds a [`ParameterLike::Constant`] wrapping `value`.
pub fn constant(value: Float) -> ParameterLike {
    ParameterLike::Constant(value)
}
50
51#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65 #[allow(unused_variables)]
70 fn precompute(&self, event: &Event, cache: &mut Cache) {}
71 #[cfg(feature = "rayon")]
73 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74 dataset
75 .events
76 .par_iter()
77 .zip(resources.caches.par_iter_mut())
78 .for_each(|(event, cache)| {
79 self.precompute(event, cache);
80 })
81 }
82 #[cfg(not(feature = "rayon"))]
84 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85 dataset
86 .events
87 .iter()
88 .zip(resources.caches.iter_mut())
89 .for_each(|(event, cache)| self.precompute(event, cache))
90 }
91 fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101 fn compute_gradient(
117 &self,
118 parameters: &Parameters,
119 event: &Event,
120 cache: &Cache,
121 gradient: &mut DVector<Complex<Float>>,
122 ) {
123 self.central_difference_with_indices(
124 &Vec::from_iter(0..parameters.len()),
125 parameters,
126 event,
127 cache,
128 gradient,
129 )
130 }
131
132 fn central_difference_with_indices(
138 &self,
139 indices: &[usize],
140 parameters: &Parameters,
141 event: &Event,
142 cache: &Cache,
143 gradient: &mut DVector<Complex<Float>>,
144 ) {
145 let x = parameters.parameters.to_owned();
146 let constants = parameters.constants.to_owned();
147 let h: DVector<Float> = x
148 .iter()
149 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150 .collect::<Vec<_>>()
151 .into();
152 for i in indices {
153 let mut x_plus = x.clone();
154 let mut x_minus = x.clone();
155 x_plus[*i] += h[*i];
156 x_minus[*i] -= h[*i];
157 let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158 let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159 gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160 }
161 }
162}
163
164pub fn central_difference<F: Fn(&[Float]) -> Float>(
166 parameters: &[Float],
167 func: F,
168) -> DVector<Float> {
169 let mut gradient = DVector::zeros(parameters.len());
170 let x = parameters.to_owned();
171 let h: DVector<Float> = x
172 .iter()
173 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174 .collect::<Vec<_>>()
175 .into();
176 for i in 0..parameters.len() {
177 let mut x_plus = x.clone();
178 let mut x_minus = x.clone();
179 x_plus[i] += h[i];
180 x_minus[i] -= h[i];
181 let f_plus = func(&x_plus);
182 let f_minus = func(&x_minus);
183 gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184 }
185 gradient
186}
187
// Implements `Clone` for `Box<dyn Amplitude>`; `Manager` and `Evaluator` clone such boxes.
dyn_clone::clone_trait_object!(Amplitude);
189
/// Values of the amplitudes for a single event.
///
/// [`Expression::evaluate`] indexes the inner vector with the numeric index stored in an
/// [`AmplitudeID`].
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
/// Gradients of the amplitudes for a single event.
///
/// Field `0` is the length used for gradient vectors (the evaluator passes the number of
/// parameters); field `1` holds one gradient vector per amplitude, indexed by the numeric index
/// stored in an [`AmplitudeID`].
#[derive(Debug)]
pub struct GradientValues(pub usize, pub Vec<DVector<Complex<Float>>>);
197
/// Handle to a registered amplitude: a name (field `0`) and a numeric index (field `1`).
///
/// The index is used to look up the amplitude's entry in [`AmplitudeValues`] and
/// [`GradientValues`].
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(f, "{}(id={})", self.0, self.1)
206 }
207}
208
209impl From<AmplitudeID> for Expression {
210 fn from(value: AmplitudeID) -> Self {
211 Self::Amp(value)
212 }
213}
214
/// A tree of arithmetic operations over registered amplitudes.
///
/// The tree is evaluated per event by [`Expression::evaluate`] and differentiated by
/// [`Expression::evaluate_gradient`].
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    /// The constant zero; this is the value produced by [`Default`].
    #[default]
    Zero,
    /// The constant one.
    One,
    /// A leaf referring to a registered amplitude.
    Amp(AmplitudeID),
    /// Sum of two sub-expressions.
    Add(Box<Expression>, Box<Expression>),
    /// Product of two sub-expressions.
    Mul(Box<Expression>, Box<Expression>),
    /// Real part of a sub-expression (imaginary part set to zero).
    Real(Box<Expression>),
    /// Imaginary part of a sub-expression, returned as a real-valued complex number.
    Imag(Box<Expression>),
    /// Squared norm of a sub-expression, returned as a real-valued complex number.
    NormSqr(Box<Expression>),
}
236
237impl Debug for Expression {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 self.write_tree(f, "", "", "")
240 }
241}
242
243impl Display for Expression {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 write!(f, "{:?}", self)
246 }
247}
248
// `Expression + Expression` and `Expression * Expression`: both operands are cloned into a new
// `Add`/`Mul` node.
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression { Expression::Add(Box::new(a.clone()), Box::new(b.clone()))});
impl_op_ex!(*|a: &Expression, b: &Expression| -> Expression {
    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
});
// Mixed `AmplitudeID`/`Expression` operands (declared commutative): the amplitude is wrapped
// in an `Amp` leaf and always becomes the left child.
impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))});
impl_op_ex_commutative!(*|a: &AmplitudeID, b: &Expression| -> Expression {
    Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
});
// `AmplitudeID + AmplitudeID` and `AmplitudeID * AmplitudeID`: both sides are wrapped in `Amp`
// leaves.
impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(Expression::Amp(b.clone())))});
impl_op_ex!(*|a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Mul(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone())),
    )
});
264
265impl AmplitudeID {
266 pub fn real(&self) -> Expression {
268 Expression::Real(Box::new(Expression::Amp(self.clone())))
269 }
270 pub fn imag(&self) -> Expression {
272 Expression::Imag(Box::new(Expression::Amp(self.clone())))
273 }
274 pub fn norm_sqr(&self) -> Expression {
276 Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
277 }
278}
279
280impl Expression {
281 pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
286 match self {
287 Expression::Amp(aid) => amplitude_values.0[aid.1],
288 Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
289 Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
290 Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
291 Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
292 Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
293 Expression::Zero => Complex::ZERO,
294 Expression::One => Complex::ONE,
295 }
296 }
297 pub fn evaluate_gradient(
302 &self,
303 amplitude_values: &AmplitudeValues,
304 gradient_values: &GradientValues,
305 ) -> DVector<Complex<Float>> {
306 match self {
307 Expression::Amp(aid) => gradient_values.1[aid.1].clone(),
308 Expression::Add(a, b) => {
309 a.evaluate_gradient(amplitude_values, gradient_values)
310 + b.evaluate_gradient(amplitude_values, gradient_values)
311 }
312 Expression::Mul(a, b) => {
313 let f_a = a.evaluate(amplitude_values);
314 let f_b = b.evaluate(amplitude_values);
315 b.evaluate_gradient(amplitude_values, gradient_values)
316 .map(|g| g * f_a)
317 + a.evaluate_gradient(amplitude_values, gradient_values)
318 .map(|g| g * f_b)
319 }
320 Expression::Real(a) => a
321 .evaluate_gradient(amplitude_values, gradient_values)
322 .map(|g| Complex::new(g.re, 0.0)),
323 Expression::Imag(a) => a
324 .evaluate_gradient(amplitude_values, gradient_values)
325 .map(|g| Complex::new(g.im, 0.0)),
326 Expression::NormSqr(a) => {
327 let conj_f_a = a.evaluate(amplitude_values).conjugate();
328 a.evaluate_gradient(amplitude_values, gradient_values)
329 .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
330 }
331 Expression::Zero | Expression::One => DVector::zeros(gradient_values.0),
332 }
333 }
334 pub fn real(&self) -> Self {
336 Self::Real(Box::new(self.clone()))
337 }
338 pub fn imag(&self) -> Self {
340 Self::Imag(Box::new(self.clone()))
341 }
342 pub fn norm_sqr(&self) -> Self {
344 Self::NormSqr(Box::new(self.clone()))
345 }
346
347 fn write_tree(
349 &self,
350 f: &mut std::fmt::Formatter<'_>,
351 parent_prefix: &str,
352 immediate_prefix: &str,
353 parent_suffix: &str,
354 ) -> std::fmt::Result {
355 let display_string = match self {
356 Self::Amp(aid) => aid.to_string(),
357 Self::Add(_, _) => "+".to_string(),
358 Self::Mul(_, _) => "*".to_string(),
359 Self::Real(_) => "Re".to_string(),
360 Self::Imag(_) => "Im".to_string(),
361 Self::NormSqr(_) => "NormSqr".to_string(),
362 Self::Zero => "0".to_string(),
363 Self::One => "1".to_string(),
364 };
365 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
366 match self {
367 Self::Amp(_) | Self::Zero | Self::One => {}
368 Self::Add(a, b) | Self::Mul(a, b) => {
369 let terms = [a, b];
370 let mut it = terms.iter().peekable();
371 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
372 while let Some(child) = it.next() {
373 match it.peek() {
374 Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│ "),
375 None => child.write_tree(f, &child_prefix, "└─ ", " "),
376 }?;
377 }
378 }
379 Self::Real(a) | Self::Imag(a) | Self::NormSqr(a) => {
380 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
381 a.write_tree(f, &child_prefix, "└─ ", " ")?;
382 }
383 }
384 Ok(())
385 }
386}
387
/// Collects registered amplitudes together with the [`Resources`] they registered into.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    // Registered amplitudes, in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Shared bookkeeping that each amplitude's `register` call writes into.
    resources: Resources,
}
395
396impl Manager {
397 pub fn parameters(&self) -> Vec<String> {
399 self.resources.parameters.iter().cloned().collect()
400 }
401 pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
409 let mut amp = amplitude.clone();
410 let aid = amp.register(&mut self.resources)?;
411 self.amplitudes.push(amp);
412 Ok(aid)
413 }
414 pub fn model(&self, expression: &Expression) -> Model {
416 Model {
417 manager: self.clone(),
418 expression: expression.clone(),
419 }
420 }
421}
422
/// A [`Manager`] snapshot paired with the [`Expression`] to evaluate over its amplitudes.
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    // Clone of the manager taken when the model was created.
    pub(crate) manager: Manager,
    // Expression evaluated per event once the model is loaded.
    pub(crate) expression: Expression,
}
433
434impl Model {
435 pub fn parameters(&self) -> Vec<String> {
437 self.manager.parameters()
438 }
439 pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
443 let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
444 loaded_resources.write().reserve_cache(dataset.n_events());
445 for amplitude in &self.manager.amplitudes {
446 amplitude.precompute_all(dataset, &mut loaded_resources.write());
447 }
448 Evaluator {
449 amplitudes: self.manager.amplitudes.clone(),
450 resources: loaded_resources.clone(),
451 dataset: dataset.clone(),
452 expression: self.expression.clone(),
453 }
454 }
455}
456
/// A [`Model`] bound to a [`Dataset`], ready to evaluate its expression and gradient per event.
///
/// Cloning an evaluator clones the `Arc`s, so clones share the same resources and dataset.
#[derive(Clone)]
pub struct Evaluator {
    /// The amplitudes evaluated for each event.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Shared, lock-protected resources (constants, caches, activation flags, parameter names).
    pub resources: Arc<RwLock<Resources>>,
    /// The dataset whose events are evaluated.
    pub dataset: Arc<Dataset>,
    /// The expression combining the amplitude values.
    pub expression: Expression,
}
473
474impl Evaluator {
475 pub fn parameters(&self) -> Vec<String> {
478 self.resources.read().parameters.iter().cloned().collect()
479 }
480 pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
482 self.resources.write().activate(name)
483 }
484 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
486 self.resources.write().activate_many(names)
487 }
488 pub fn activate_all(&self) {
490 self.resources.write().activate_all();
491 }
492 pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
494 self.resources.write().deactivate(name)
495 }
496 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
498 self.resources.write().deactivate_many(names)
499 }
500 pub fn deactivate_all(&self) {
502 self.resources.write().deactivate_all();
503 }
504 pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
506 self.resources.write().isolate(name)
507 }
508 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
510 self.resources.write().isolate_many(names)
511 }
512
513 pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
521 let resources = self.resources.read();
522 let parameters = Parameters::new(parameters, &resources.constants);
523 #[cfg(feature = "rayon")]
524 {
525 let amplitude_values_vec: Vec<AmplitudeValues> = self
526 .dataset
527 .events
528 .par_iter()
529 .zip(resources.caches.par_iter())
530 .map(|(event, cache)| {
531 AmplitudeValues(
532 self.amplitudes
533 .iter()
534 .zip(resources.active.iter())
535 .map(|(amp, active)| {
536 if *active {
537 amp.compute(¶meters, event, cache)
538 } else {
539 Complex::new(0.0, 0.0)
540 }
541 })
542 .collect(),
543 )
544 })
545 .collect();
546 amplitude_values_vec
547 .par_iter()
548 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
549 .collect()
550 }
551 #[cfg(not(feature = "rayon"))]
552 {
553 let amplitude_values_vec: Vec<AmplitudeValues> = self
554 .dataset
555 .events
556 .iter()
557 .zip(resources.caches.iter())
558 .map(|(event, cache)| {
559 AmplitudeValues(
560 self.amplitudes
561 .iter()
562 .zip(resources.active.iter())
563 .map(|(amp, active)| {
564 if *active {
565 amp.compute(¶meters, event, cache)
566 } else {
567 Complex::new(0.0, 0.0)
568 }
569 })
570 .collect(),
571 )
572 })
573 .collect();
574 amplitude_values_vec
575 .iter()
576 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
577 .collect()
578 }
579 }
580
581 #[cfg(feature = "mpi")]
589 fn evaluate_mpi(
590 &self,
591 parameters: &[Float],
592 world: &SimpleCommunicator,
593 ) -> Vec<Complex<Float>> {
594 let local_evaluation = self.evaluate_local(parameters);
595 let n_events = self.dataset.n_events();
596 let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
597 let (counts, displs) = world.get_counts_displs(n_events);
598 {
599 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
600 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
601 }
602 buffer
603 }
604
605 pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
608 #[cfg(feature = "mpi")]
609 {
610 if let Some(world) = crate::mpi::get_world() {
611 return self.evaluate_mpi(parameters, &world);
612 }
613 }
614 self.evaluate_local(parameters)
615 }
616
617 pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
625 let resources = self.resources.read();
626 let parameters = Parameters::new(parameters, &resources.constants);
627 #[cfg(feature = "rayon")]
628 {
629 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
630 .dataset
631 .events
632 .par_iter()
633 .zip(resources.caches.par_iter())
634 .map(|(event, cache)| {
635 let mut gradient_values =
636 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
637 self.amplitudes
638 .iter()
639 .zip(resources.active.iter())
640 .zip(gradient_values.iter_mut())
641 .for_each(|((amp, active), grad)| {
642 if *active {
643 amp.compute_gradient(¶meters, event, cache, grad)
644 }
645 });
646 (
647 AmplitudeValues(
648 self.amplitudes
649 .iter()
650 .zip(resources.active.iter())
651 .map(|(amp, active)| {
652 if *active {
653 amp.compute(¶meters, event, cache)
654 } else {
655 Complex::new(0.0, 0.0)
656 }
657 })
658 .collect(),
659 ),
660 GradientValues(parameters.len(), gradient_values),
661 )
662 })
663 .collect();
664 amplitude_values_and_gradient_vec
665 .par_iter()
666 .map(|(amplitude_values, gradient_values)| {
667 self.expression
668 .evaluate_gradient(amplitude_values, gradient_values)
669 })
670 .collect()
671 }
672 #[cfg(not(feature = "rayon"))]
673 {
674 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
675 .dataset
676 .events
677 .iter()
678 .zip(resources.caches.iter())
679 .map(|(event, cache)| {
680 let mut gradient_values =
681 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
682 self.amplitudes
683 .iter()
684 .zip(resources.active.iter())
685 .zip(gradient_values.iter_mut())
686 .for_each(|((amp, active), grad)| {
687 if *active {
688 amp.compute_gradient(¶meters, event, cache, grad)
689 }
690 });
691 (
692 AmplitudeValues(
693 self.amplitudes
694 .iter()
695 .zip(resources.active.iter())
696 .map(|(amp, active)| {
697 if *active {
698 amp.compute(¶meters, event, cache)
699 } else {
700 Complex::new(0.0, 0.0)
701 }
702 })
703 .collect(),
704 ),
705 GradientValues(parameters.len(), gradient_values),
706 )
707 })
708 .collect();
709
710 amplitude_values_and_gradient_vec
711 .iter()
712 .map(|(amplitude_values, gradient_values)| {
713 self.expression
714 .evaluate_gradient(amplitude_values, gradient_values)
715 })
716 .collect()
717 }
718 }
719
720 #[cfg(feature = "mpi")]
728 fn evaluate_gradient_mpi(
729 &self,
730 parameters: &[Float],
731 world: &SimpleCommunicator,
732 ) -> Vec<DVector<Complex<Float>>> {
733 let flattened_local_evaluation = self
734 .evaluate_gradient_local(parameters)
735 .iter()
736 .flat_map(|g| g.data.as_vec().to_vec())
737 .collect::<Vec<Complex<Float>>>();
738 let n_events = self.dataset.n_events();
739 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
740 let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
741 let mut partitioned_flattened_result_buffer =
742 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
743 world.all_gather_varcount_into(
744 &flattened_local_evaluation,
745 &mut partitioned_flattened_result_buffer,
746 );
747 flattened_result_buffer
748 .chunks(parameters.len())
749 .map(DVector::from_row_slice)
750 .collect()
751 }
752
753 pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
756 #[cfg(feature = "mpi")]
757 {
758 if let Some(world) = crate::mpi::get_world() {
759 return self.evaluate_gradient_mpi(parameters, &world);
760 }
761 }
762 self.evaluate_gradient_local(parameters)
763 }
764}
765
#[cfg(test)]
mod tests {
    use crate::data::{test_dataset, test_event};

    use super::*;
    use crate::{
        data::Event,
        resources::{Cache, ParameterID, Parameters, Resources},
        Complex, DVector, Float, LadduError,
    };
    use approx::assert_relative_eq;
    use serde::{Deserialize, Serialize};

    /// Minimal test amplitude whose value is `re + i * im`, where each part is either a
    /// parameter or a constant.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        pid_re: ParameterID,
        im: ParameterLike,
        pid_im: ParameterID,
    }

    impl ComplexScalar {
        /// Creates a boxed scalar; the parameter ids stay at their defaults until `register`.
        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
            Self {
                name: name.to_string(),
                re,
                pid_re: Default::default(),
                im,
                pid_im: Default::default(),
            }
            .into()
        }
    }

    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
            self.pid_re = resources.register_parameter(&self.re);
            self.pid_im = resources.register_parameter(&self.im);
            resources.register_amplitude(&self.name)
        }

        fn compute(
            &self,
            parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
        ) -> Complex<Float> {
            Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
        }

        fn compute_gradient(
            &self,
            _parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
            gradient: &mut DVector<Complex<Float>>,
        ) {
            // d(re + i*im)/d(re) = 1 and d(re + i*im)/d(im) = i; constants have no entry.
            if let ParameterID::Parameter(ind) = self.pid_re {
                gradient[ind] = Complex::ONE;
            }
            if let ParameterID::Parameter(ind) = self.pid_im {
                gradient[ind] = Complex::I;
            }
        }
    }

    #[test]
    fn test_constant_amplitude() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(Dataset {
            events: vec![Arc::new(test_event())],
        });
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_parametric_amplitude() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new(
            "parametric",
            parameter("test_param_re"),
            parameter("test_param_im"),
        );
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[2.0, 3.0]);
        assert_eq!(result[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_expression_operations() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
        let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
        let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));

        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let aid3 = manager.register(amp3).unwrap();

        let dataset = Arc::new(test_dataset());

        // Addition: 2 + i
        let expr_add = &aid1 + &aid2;
        let model_add = manager.model(&expr_add);
        let eval_add = model_add.load(&dataset);
        let result_add = eval_add.evaluate(&[]);
        assert_eq!(result_add[0], Complex::new(2.0, 1.0));

        // Multiplication: 2 * i
        let expr_mul = &aid1 * &aid2;
        let model_mul = manager.model(&expr_mul);
        let eval_mul = model_mul.load(&dataset);
        let result_mul = eval_mul.evaluate(&[]);
        assert_eq!(result_mul[0], Complex::new(0.0, 2.0));

        // Addition of expressions: (2 + i) + 2i
        let expr_add2 = &expr_add + &expr_mul;
        let model_add2 = manager.model(&expr_add2);
        let eval_add2 = model_add2.load(&dataset);
        let result_add2 = eval_add2.evaluate(&[]);
        assert_eq!(result_add2[0], Complex::new(2.0, 3.0));

        // Multiplication of expressions: (2 + i) * 2i
        let expr_mul2 = &expr_add * &expr_mul;
        let model_mul2 = manager.model(&expr_mul2);
        let eval_mul2 = model_mul2.load(&dataset);
        let result_mul2 = eval_mul2.evaluate(&[]);
        assert_eq!(result_mul2[0], Complex::new(-2.0, 4.0));

        // Real part of an amplitude
        let expr_real = aid3.real();
        let model_real = manager.model(&expr_real);
        let eval_real = model_real.load(&dataset);
        let result_real = eval_real.evaluate(&[]);
        assert_eq!(result_real[0], Complex::new(3.0, 0.0));

        // Real part of an expression
        let expr_mul2_real = expr_mul2.real();
        let model_mul2_real = manager.model(&expr_mul2_real);
        let eval_mul2_real = model_mul2_real.load(&dataset);
        let result_mul2_real = eval_mul2_real.evaluate(&[]);
        assert_eq!(result_mul2_real[0], Complex::new(-2.0, 0.0));

        // Imaginary part of an expression
        let expr_mul2_imag = expr_mul2.imag();
        let model_mul2_imag = manager.model(&expr_mul2_imag);
        let eval_mul2_imag = model_mul2_imag.load(&dataset);
        let result_mul2_imag = eval_mul2_imag.evaluate(&[]);
        assert_eq!(result_mul2_imag[0], Complex::new(4.0, 0.0));

        // Imaginary part of an amplitude
        let expr_imag = aid3.imag();
        let model_imag = manager.model(&expr_imag);
        let eval_imag = model_imag.load(&dataset);
        let result_imag = eval_imag.evaluate(&[]);
        assert_eq!(result_imag[0], Complex::new(4.0, 0.0));

        // Squared norm of an amplitude
        let expr_norm = aid1.norm_sqr();
        let model_norm = manager.model(&expr_norm);
        let eval_norm = model_norm.load(&dataset);
        let result_norm = eval_norm.evaluate(&[]);
        assert_eq!(result_norm[0], Complex::new(4.0, 0.0));

        // Squared norm of an expression
        let expr_mul2_norm = expr_mul2.norm_sqr();
        let model_mul2_norm = manager.model(&expr_mul2_norm);
        let eval_mul2_norm = model_mul2_norm.load(&dataset);
        let result_mul2_norm = eval_mul2_norm.evaluate(&[]);
        assert_eq!(result_mul2_norm[0], Complex::new(20.0, 0.0));
    }

    #[test]
    fn test_amplitude_activation() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
        let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));

        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = &aid1 + &aid2;
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        // Both amplitudes active
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(3.0, 0.0));

        // First amplitude deactivated
        evaluator.deactivate("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(2.0, 0.0));

        // Only the first amplitude active
        evaluator.isolate("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(1.0, 0.0));

        // Everything reactivated
        evaluator.activate_all();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(3.0, 0.0));
    }

    #[test]
    fn test_gradient() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new(
            "parametric_1",
            parameter("test_param_re_1"),
            parameter("test_param_im_1"),
        );
        let amp2 = ComplexScalar::new(
            "parametric_2",
            parameter("test_param_re_2"),
            parameter("test_param_im_2"),
        );

        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let dataset = Arc::new(test_dataset());
        let params = vec![2.0, 3.0, 4.0, 5.0];

        let expr = &aid1 * &aid2;
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 5.0);
        assert_relative_eq!(gradient[0][1].re, -5.0);
        assert_relative_eq!(gradient[0][1].im, 4.0);
        assert_relative_eq!(gradient[0][2].re, 2.0);
        assert_relative_eq!(gradient[0][2].im, 3.0);
        assert_relative_eq!(gradient[0][3].re, -3.0);
        assert_relative_eq!(gradient[0][3].im, 2.0);

        let expr = (&aid1 * &aid2).real();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, -5.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 2.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, -3.0);
        assert_relative_eq!(gradient[0][3].im, 0.0);

        let expr = (&aid1 * &aid2).imag();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 5.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 4.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 3.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 2.0);
        assert_relative_eq!(gradient[0][3].im, 0.0);

        let expr = (&aid1 * &aid2).norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let gradient = evaluator.evaluate_gradient(&params);

        assert_relative_eq!(gradient[0][0].re, 164.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
        assert_relative_eq!(gradient[0][1].re, 246.0);
        assert_relative_eq!(gradient[0][1].im, 0.0);
        assert_relative_eq!(gradient[0][2].re, 104.0);
        assert_relative_eq!(gradient[0][2].im, 0.0);
        assert_relative_eq!(gradient[0][3].re, 130.0);
        assert_relative_eq!(gradient[0][3].im, 0.0);
    }

    #[test]
    fn test_zeros_and_ones() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = (aid * Expression::One + Expression::Zero).norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![2.0];
        let value = evaluator.evaluate(&params);
        let gradient = evaluator.evaluate_gradient(&params);

        // |2 + 2i|^2 = 8
        assert_relative_eq!(value[0].re, 8.0);
        assert_relative_eq!(value[0].im, 0.0);

        // d/d(re) of (re^2 + 4) at re = 2 is 4
        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    #[test]
    fn test_parameter_registration() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));

        let aid = manager.register(amp).unwrap();
        let parameters = manager.parameters();
        let model = manager.model(&aid.into());
        let model_parameters = model.parameters();
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0], "test_param_re");
        assert_eq!(model_parameters.len(), 1);
        assert_eq!(model_parameters[0], "test_param_re");
    }

    #[test]
    fn test_duplicate_amplitude_registration() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
        let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
        manager.register(amp1).unwrap();
        assert!(manager.register(amp2).is_err());
    }

    #[test]
    fn test_tree_printing() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new(
            "parametric_1",
            parameter("test_param_re_1"),
            parameter("test_param_im_1"),
        );
        let amp2 = ComplexScalar::new(
            "parametric_2",
            parameter("test_param_re_2"),
            parameter("test_param_im_2"),
        );
        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let expr = &aid1.real()
            + &aid2.imag()
            + Expression::One * Expression::Zero
            + (&aid1 * &aid2).norm_sqr();
        // Each nesting level is indented by three columns, matching the connector width.
        assert_eq!(
            expr.to_string(),
            "+
├─ +
│  ├─ +
│  │  ├─ Re
│  │  │  └─ parametric_1(id=0)
│  │  └─ Im
│  │     └─ parametric_2(id=1)
│  └─ *
│     ├─ 1
│     └─ 0
└─ NormSqr
   └─ *
      ├─ parametric_1(id=0)
      └─ parametric_2(id=1)
"
        );
    }
}