1use std::{
2 fmt::{Debug, Display},
3 sync::Arc,
4};
5
6use auto_ops::*;
7use dyn_clone::DynClone;
8use nalgebra::{ComplexField, DVector};
9use num::Complex;
10
11use parking_lot::RwLock;
12#[cfg(feature = "rayon")]
13use rayon::prelude::*;
14use serde::{Deserialize, Serialize};
15
16use crate::{
17 data::{Dataset, Event},
18 resources::{Cache, Parameters, Resources},
19 Float, LadduError,
20};
21
22#[cfg(feature = "mpi")]
23use crate::mpi::LadduMPI;
24
25#[cfg(feature = "mpi")]
26use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
27
/// A value an amplitude can be configured with: either a named free parameter or a fixed
/// constant.
///
/// Instances are normally built with the [`parameter`] and [`constant`] helper functions.
#[derive(Clone, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A free parameter, identified by its name.
    Parameter(String),
    /// A fixed numeric value.
    Constant(Float),
    /// Placeholder state; this is the variant produced by [`Default`].
    #[default]
    Uninit,
}
40
41pub fn parameter(name: &str) -> ParameterLike {
43 ParameterLike::Parameter(name.to_string())
44}
45
46pub fn constant(value: Float) -> ParameterLike {
48 ParameterLike::Constant(value)
49}
50
51#[typetag::serde(tag = "type")]
59pub trait Amplitude: DynClone + Send + Sync {
60 fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;
65 #[allow(unused_variables)]
70 fn precompute(&self, event: &Event, cache: &mut Cache) {}
71 #[cfg(feature = "rayon")]
73 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
74 dataset
75 .events
76 .par_iter()
77 .zip(resources.caches.par_iter_mut())
78 .for_each(|(event, cache)| {
79 self.precompute(event, cache);
80 })
81 }
82 #[cfg(not(feature = "rayon"))]
84 fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
85 dataset
86 .events
87 .iter()
88 .zip(resources.caches.iter_mut())
89 .for_each(|(event, cache)| self.precompute(event, cache))
90 }
91 fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;
100
101 fn compute_gradient(
117 &self,
118 parameters: &Parameters,
119 event: &Event,
120 cache: &Cache,
121 gradient: &mut DVector<Complex<Float>>,
122 ) {
123 self.central_difference_with_indices(
124 &Vec::from_iter(0..parameters.len()),
125 parameters,
126 event,
127 cache,
128 gradient,
129 )
130 }
131
132 fn central_difference_with_indices(
138 &self,
139 indices: &[usize],
140 parameters: &Parameters,
141 event: &Event,
142 cache: &Cache,
143 gradient: &mut DVector<Complex<Float>>,
144 ) {
145 let x = parameters.parameters.to_owned();
146 let constants = parameters.constants.to_owned();
147 let h: DVector<Float> = x
148 .iter()
149 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
150 .collect::<Vec<_>>()
151 .into();
152 for i in indices {
153 let mut x_plus = x.clone();
154 let mut x_minus = x.clone();
155 x_plus[*i] += h[*i];
156 x_minus[*i] -= h[*i];
157 let f_plus = self.compute(&Parameters::new(&x_plus, &constants), event, cache);
158 let f_minus = self.compute(&Parameters::new(&x_minus, &constants), event, cache);
159 gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
160 }
161 }
162}
163
164pub fn central_difference<F: Fn(&[Float]) -> Float>(
166 parameters: &[Float],
167 func: F,
168) -> DVector<Float> {
169 let mut gradient = DVector::zeros(parameters.len());
170 let x = parameters.to_owned();
171 let h: DVector<Float> = x
172 .iter()
173 .map(|&xi| Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0))
174 .collect::<Vec<_>>()
175 .into();
176 for i in 0..parameters.len() {
177 let mut x_plus = x.clone();
178 let mut x_minus = x.clone();
179 x_plus[i] += h[i];
180 x_minus[i] -= h[i];
181 let f_plus = func(&x_plus);
182 let f_minus = func(&x_minus);
183 gradient[i] = (f_plus - f_minus) / (2.0 * h[i]);
184 }
185 gradient
186}
187
// Makes `Box<dyn Amplitude>` clonable, which the `Clone` derives on the types that store
// boxed amplitudes rely on.
dyn_clone::clone_trait_object!(Amplitude);
189
/// The complex values of all amplitudes for a single event, looked up by the index stored in
/// an [`AmplitudeID`].
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);
193
/// The gradient vectors of all amplitudes for a single event, looked up by the index stored in
/// an [`AmplitudeID`].
#[derive(Debug)]
pub struct GradientValues(pub Vec<DVector<Complex<Float>>>);
197
/// Handle for a registered amplitude: a text label (what [`Display`] prints) and the index
/// used to look the amplitude up in [`AmplitudeValues`] and [`GradientValues`].
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
202
203impl Display for AmplitudeID {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(f, "{}", self.0)
206 }
207}
208
209impl From<AmplitudeID> for Expression {
210 fn from(value: AmplitudeID) -> Self {
211 Self::Amp(value)
212 }
213}
214
/// A tree of arithmetic operations over registered amplitudes.
///
/// Expressions are built with the `+` and `*` operators and with the `real`, `imag` and
/// `norm_sqr` methods, and are evaluated per event from the values of the amplitudes.
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    /// The constant zero; this is the [`Default`] expression.
    #[default]
    Zero,
    /// The value of a single registered amplitude.
    Amp(AmplitudeID),
    /// The sum of two expressions.
    Add(Box<Expression>, Box<Expression>),
    /// The product of two expressions.
    Mul(Box<Expression>, Box<Expression>),
    /// The real part of an expression.
    Real(Box<Expression>),
    /// The imaginary part of an expression (returned as a real-valued complex number).
    Imag(Box<Expression>),
    /// The squared norm of an expression.
    NormSqr(Box<Expression>),
}
234
235impl Debug for Expression {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 self.write_tree(f, "", "", "")
238 }
239}
240
241impl Display for Expression {
242 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243 write!(f, "{:?}", self)
244 }
245}
246
247impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression { Expression::Add(Box::new(a.clone()), Box::new(b.clone()))});
248impl_op_ex!(*|a: &Expression, b: &Expression| -> Expression {
249 Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
250});
251impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))});
252impl_op_ex_commutative!(*|a: &AmplitudeID, b: &Expression| -> Expression {
253 Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
254});
255impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(Expression::Amp(b.clone())))});
256impl_op_ex!(*|a: &AmplitudeID, b: &AmplitudeID| -> Expression {
257 Expression::Mul(
258 Box::new(Expression::Amp(a.clone())),
259 Box::new(Expression::Amp(b.clone())),
260 )
261});
262
263impl AmplitudeID {
264 pub fn real(&self) -> Expression {
266 Expression::Real(Box::new(Expression::Amp(self.clone())))
267 }
268 pub fn imag(&self) -> Expression {
270 Expression::Imag(Box::new(Expression::Amp(self.clone())))
271 }
272 pub fn norm_sqr(&self) -> Expression {
274 Expression::NormSqr(Box::new(Expression::Amp(self.clone())))
275 }
276}
277
278impl Expression {
279 pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
284 match self {
285 Expression::Amp(aid) => amplitude_values.0[aid.1],
286 Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
287 Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
288 Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
289 Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
290 Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
291 Expression::Zero => Complex::ZERO,
292 }
293 }
294 pub fn evaluate_gradient(
299 &self,
300 amplitude_values: &AmplitudeValues,
301 gradient_values: &GradientValues,
302 ) -> DVector<Complex<Float>> {
303 match self {
304 Expression::Amp(aid) => gradient_values.0[aid.1].clone(),
305 Expression::Add(a, b) => {
306 a.evaluate_gradient(amplitude_values, gradient_values)
307 + b.evaluate_gradient(amplitude_values, gradient_values)
308 }
309 Expression::Mul(a, b) => {
310 let f_a = a.evaluate(amplitude_values);
311 let f_b = b.evaluate(amplitude_values);
312 b.evaluate_gradient(amplitude_values, gradient_values)
313 .map(|g| g * f_a)
314 + a.evaluate_gradient(amplitude_values, gradient_values)
315 .map(|g| g * f_b)
316 }
317 Expression::Real(a) => a
318 .evaluate_gradient(amplitude_values, gradient_values)
319 .map(|g| Complex::new(g.re, 0.0)),
320 Expression::Imag(a) => a
321 .evaluate_gradient(amplitude_values, gradient_values)
322 .map(|g| Complex::new(g.im, 0.0)),
323 Expression::NormSqr(a) => {
324 let conj_f_a = a.evaluate(amplitude_values).conjugate();
325 a.evaluate_gradient(amplitude_values, gradient_values)
326 .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
327 }
328 Expression::Zero => DVector::zeros(0),
329 }
330 }
331 pub fn real(&self) -> Self {
333 Self::Real(Box::new(self.clone()))
334 }
335 pub fn imag(&self) -> Self {
337 Self::Imag(Box::new(self.clone()))
338 }
339 pub fn norm_sqr(&self) -> Self {
341 Self::NormSqr(Box::new(self.clone()))
342 }
343
344 fn write_tree(
346 &self,
347 f: &mut std::fmt::Formatter<'_>,
348 parent_prefix: &str,
349 immediate_prefix: &str,
350 parent_suffix: &str,
351 ) -> std::fmt::Result {
352 let display_string = match self {
353 Self::Amp(aid) => aid.0.clone(),
354 Self::Add(_, _) => "+".to_string(),
355 Self::Mul(_, _) => "*".to_string(),
356 Self::Real(_) => "Re".to_string(),
357 Self::Imag(_) => "Im".to_string(),
358 Self::NormSqr(_) => "NormSqr".to_string(),
359 Self::Zero => "0".to_string(),
360 };
361 writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, display_string)?;
362 match self {
363 Self::Amp(_) | Self::Zero => {}
364 Self::Add(a, b) | Self::Mul(a, b) => {
365 let terms = [a, b];
366 let mut it = terms.iter().peekable();
367 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
368 while let Some(child) = it.next() {
369 match it.peek() {
370 Some(_) => child.write_tree(f, &child_prefix, "├─ ", "│ "),
371 None => child.write_tree(f, &child_prefix, "└─ ", " "),
372 }?;
373 }
374 }
375 Self::Real(a) | Self::Imag(a) | Self::NormSqr(a) => {
376 let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
377 a.write_tree(f, &child_prefix, "└─ ", " ")?;
378 }
379 }
380 Ok(())
381 }
382}
383
/// Collects registered amplitudes together with the [`Resources`] they registered into.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    // Amplitudes in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    // Shared bookkeeping that the amplitudes register into.
    resources: Resources,
}
391
392impl Manager {
393 pub fn parameters(&self) -> Vec<String> {
395 self.resources.parameters.iter().cloned().collect()
396 }
397 pub fn register(&mut self, amplitude: Box<dyn Amplitude>) -> Result<AmplitudeID, LadduError> {
405 let mut amp = amplitude.clone();
406 let aid = amp.register(&mut self.resources)?;
407 self.amplitudes.push(amp);
408 Ok(aid)
409 }
410 pub fn model(&self, expression: &Expression) -> Model {
412 Model {
413 manager: self.clone(),
414 expression: expression.clone(),
415 }
416 }
417}
418
/// An [`Expression`] paired with the [`Manager`] whose amplitudes it refers to.
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    // The amplitudes and resources the expression is evaluated against.
    pub(crate) manager: Manager,
    // The expression to evaluate.
    pub(crate) expression: Expression,
}
429
430impl Model {
431 pub fn parameters(&self) -> Vec<String> {
433 self.manager.parameters()
434 }
435 pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
439 let loaded_resources = Arc::new(RwLock::new(self.manager.resources.clone()));
440 loaded_resources.write().reserve_cache(dataset.n_events());
441 for amplitude in &self.manager.amplitudes {
442 amplitude.precompute_all(dataset, &mut loaded_resources.write());
443 }
444 Evaluator {
445 amplitudes: self.manager.amplitudes.clone(),
446 resources: loaded_resources.clone(),
447 dataset: dataset.clone(),
448 expression: self.expression.clone(),
449 }
450 }
451}
452
/// Evaluates an [`Expression`] (and its gradient) over every event of a [`Dataset`].
///
/// Cloning an evaluator clones the amplitudes and the expression but shares the resources and
/// the dataset through their [`Arc`]s.
#[derive(Clone)]
pub struct Evaluator {
    /// The amplitudes whose values feed the expression.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Shared, lockable resources: parameters, constants, caches and activation flags.
    pub resources: Arc<RwLock<Resources>>,
    /// The dataset whose events are evaluated.
    pub dataset: Arc<Dataset>,
    /// The expression evaluated for each event.
    pub expression: Expression,
}
469
470impl Evaluator {
471 pub fn parameters(&self) -> Vec<String> {
474 self.resources.read().parameters.iter().cloned().collect()
475 }
476 pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
478 self.resources.write().activate(name)
479 }
480 pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
482 self.resources.write().activate_many(names)
483 }
484 pub fn activate_all(&self) {
486 self.resources.write().activate_all();
487 }
488 pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
490 self.resources.write().deactivate(name)
491 }
492 pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
494 self.resources.write().deactivate_many(names)
495 }
496 pub fn deactivate_all(&self) {
498 self.resources.write().deactivate_all();
499 }
500 pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
502 self.resources.write().isolate(name)
503 }
504 pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
506 self.resources.write().isolate_many(names)
507 }
508
509 pub fn evaluate_local(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
517 let resources = self.resources.read();
518 let parameters = Parameters::new(parameters, &resources.constants);
519 #[cfg(feature = "rayon")]
520 {
521 let amplitude_values_vec: Vec<AmplitudeValues> = self
522 .dataset
523 .events
524 .par_iter()
525 .zip(resources.caches.par_iter())
526 .map(|(event, cache)| {
527 AmplitudeValues(
528 self.amplitudes
529 .iter()
530 .zip(resources.active.iter())
531 .map(|(amp, active)| {
532 if *active {
533 amp.compute(¶meters, event, cache)
534 } else {
535 Complex::new(0.0, 0.0)
536 }
537 })
538 .collect(),
539 )
540 })
541 .collect();
542 amplitude_values_vec
543 .par_iter()
544 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
545 .collect()
546 }
547 #[cfg(not(feature = "rayon"))]
548 {
549 let amplitude_values_vec: Vec<AmplitudeValues> = self
550 .dataset
551 .events
552 .iter()
553 .zip(resources.caches.iter())
554 .map(|(event, cache)| {
555 AmplitudeValues(
556 self.amplitudes
557 .iter()
558 .zip(resources.active.iter())
559 .map(|(amp, active)| {
560 if *active {
561 amp.compute(¶meters, event, cache)
562 } else {
563 Complex::new(0.0, 0.0)
564 }
565 })
566 .collect(),
567 )
568 })
569 .collect();
570 amplitude_values_vec
571 .iter()
572 .map(|amplitude_values| self.expression.evaluate(amplitude_values))
573 .collect()
574 }
575 }
576
577 #[cfg(feature = "mpi")]
585 fn evaluate_mpi(
586 &self,
587 parameters: &[Float],
588 world: &SimpleCommunicator,
589 ) -> Vec<Complex<Float>> {
590 let local_evaluation = self.evaluate_local(parameters);
591 let n_events = self.dataset.n_events();
592 let mut buffer: Vec<Complex<Float>> = vec![Complex::ZERO; n_events];
593 let (counts, displs) = world.get_counts_displs(n_events);
594 {
595 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
596 world.all_gather_varcount_into(&local_evaluation, &mut partitioned_buffer);
597 }
598 buffer
599 }
600
601 pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
604 #[cfg(feature = "mpi")]
605 {
606 if let Some(world) = crate::mpi::get_world() {
607 return self.evaluate_mpi(parameters, &world);
608 }
609 }
610 self.evaluate_local(parameters)
611 }
612
613 pub fn evaluate_gradient_local(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
621 let resources = self.resources.read();
622 let parameters = Parameters::new(parameters, &resources.constants);
623 #[cfg(feature = "rayon")]
624 {
625 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
626 .dataset
627 .events
628 .par_iter()
629 .zip(resources.caches.par_iter())
630 .map(|(event, cache)| {
631 let mut gradient_values =
632 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
633 self.amplitudes
634 .iter()
635 .zip(resources.active.iter())
636 .zip(gradient_values.iter_mut())
637 .for_each(|((amp, active), grad)| {
638 if *active {
639 amp.compute_gradient(¶meters, event, cache, grad)
640 }
641 });
642 (
643 AmplitudeValues(
644 self.amplitudes
645 .iter()
646 .zip(resources.active.iter())
647 .map(|(amp, active)| {
648 if *active {
649 amp.compute(¶meters, event, cache)
650 } else {
651 Complex::new(0.0, 0.0)
652 }
653 })
654 .collect(),
655 ),
656 GradientValues(gradient_values),
657 )
658 })
659 .collect();
660 amplitude_values_and_gradient_vec
661 .par_iter()
662 .map(|(amplitude_values, gradient_values)| {
663 self.expression
664 .evaluate_gradient(amplitude_values, gradient_values)
665 })
666 .collect()
667 }
668 #[cfg(not(feature = "rayon"))]
669 {
670 let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
671 .dataset
672 .events
673 .iter()
674 .zip(resources.caches.iter())
675 .map(|(event, cache)| {
676 let mut gradient_values =
677 vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
678 self.amplitudes
679 .iter()
680 .zip(resources.active.iter())
681 .zip(gradient_values.iter_mut())
682 .for_each(|((amp, active), grad)| {
683 if *active {
684 amp.compute_gradient(¶meters, event, cache, grad)
685 }
686 });
687 (
688 AmplitudeValues(
689 self.amplitudes
690 .iter()
691 .zip(resources.active.iter())
692 .map(|(amp, active)| {
693 if *active {
694 amp.compute(¶meters, event, cache)
695 } else {
696 Complex::new(0.0, 0.0)
697 }
698 })
699 .collect(),
700 ),
701 GradientValues(gradient_values),
702 )
703 })
704 .collect();
705
706 amplitude_values_and_gradient_vec
707 .iter()
708 .map(|(amplitude_values, gradient_values)| {
709 self.expression
710 .evaluate_gradient(amplitude_values, gradient_values)
711 })
712 .collect()
713 }
714 }
715
716 #[cfg(feature = "mpi")]
724 fn evaluate_gradient_mpi(
725 &self,
726 parameters: &[Float],
727 world: &SimpleCommunicator,
728 ) -> Vec<DVector<Complex<Float>>> {
729 let flattened_local_evaluation = self
730 .evaluate_gradient_local(parameters)
731 .iter()
732 .flat_map(|g| g.data.as_vec().to_vec())
733 .collect::<Vec<Complex<Float>>>();
734 let n_events = self.dataset.n_events();
735 let (counts, displs) = world.get_flattened_counts_displs(n_events, parameters.len());
736 let mut flattened_result_buffer = vec![Complex::ZERO; n_events * parameters.len()];
737 let mut partitioned_flattened_result_buffer =
738 PartitionMut::new(&mut flattened_result_buffer, counts, displs);
739 world.all_gather_varcount_into(
740 &flattened_local_evaluation,
741 &mut partitioned_flattened_result_buffer,
742 );
743 flattened_result_buffer
744 .chunks(parameters.len())
745 .map(DVector::from_row_slice)
746 .collect()
747 }
748
749 pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
752 #[cfg(feature = "mpi")]
753 {
754 if let Some(world) = crate::mpi::get_world() {
755 return self.evaluate_gradient_mpi(parameters, &world);
756 }
757 }
758 self.evaluate_gradient_local(parameters)
759 }
760}
761
#[cfg(test)]
mod tests {
    use crate::data::{test_dataset, test_event};

    use super::*;
    use crate::{
        data::Event,
        resources::{Cache, ParameterID, Parameters, Resources},
        Complex, DVector, Float, LadduError,
    };
    use approx::assert_relative_eq;
    use serde::{Deserialize, Serialize};

    /// Test amplitude whose value is `re + i * im`, where each part is either a free parameter
    /// or a constant. It ignores the event and the cache.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        pid_re: ParameterID,
        im: ParameterLike,
        pid_im: ParameterID,
    }

    impl ComplexScalar {
        /// Creates a boxed, not-yet-registered scalar amplitude.
        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
            Self {
                name: name.to_string(),
                re,
                pid_re: Default::default(),
                im,
                pid_im: Default::default(),
            }
            .into()
        }
    }

    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
            self.pid_re = resources.register_parameter(&self.re);
            self.pid_im = resources.register_parameter(&self.im);
            resources.register_amplitude(&self.name)
        }

        fn compute(
            &self,
            parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
        ) -> Complex<Float> {
            Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
        }

        // Analytic gradient: d/d(re) = 1 and d/d(im) = i, only for parts that are free
        // parameters.
        fn compute_gradient(
            &self,
            _parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
            gradient: &mut DVector<Complex<Float>>,
        ) {
            if let ParameterID::Parameter(ind) = self.pid_re {
                gradient[ind] = Complex::ONE;
            }
            if let ParameterID::Parameter(ind) = self.pid_im {
                gradient[ind] = Complex::I;
            }
        }
    }

    #[test]
    fn test_constant_amplitude() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(Dataset {
            events: vec![Arc::new(test_event())],
        });
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_parametric_amplitude() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new(
            "parametric",
            parameter("test_param_re"),
            parameter("test_param_im"),
        );
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[2.0, 3.0]);
        assert_eq!(result[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_expression_operations() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
        let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
        let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));

        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let aid3 = manager.register(amp3).unwrap();

        let dataset = Arc::new(test_dataset());

        // Builds a model for `expr`, loads the dataset and returns the value for the first
        // event.
        let evaluate_first =
            |expr: &Expression| manager.model(expr).load(&dataset).evaluate(&[])[0];

        // amp1 = 2, amp2 = i, amp3 = 3 + 4i.
        let expr_add = &aid1 + &aid2;
        assert_eq!(evaluate_first(&expr_add), Complex::new(2.0, 1.0));

        let expr_mul = &aid1 * &aid2;
        assert_eq!(evaluate_first(&expr_mul), Complex::new(0.0, 2.0));

        let expr_add2 = &expr_add + &expr_mul;
        assert_eq!(evaluate_first(&expr_add2), Complex::new(2.0, 3.0));

        let expr_mul2 = &expr_add * &expr_mul;
        assert_eq!(evaluate_first(&expr_mul2), Complex::new(-2.0, 4.0));

        let expr_real = aid3.real();
        assert_eq!(evaluate_first(&expr_real), Complex::new(3.0, 0.0));

        let expr_mul2_real = expr_mul2.real();
        assert_eq!(evaluate_first(&expr_mul2_real), Complex::new(-2.0, 0.0));

        let expr_mul2_imag = expr_mul2.imag();
        assert_eq!(evaluate_first(&expr_mul2_imag), Complex::new(4.0, 0.0));

        let expr_imag = aid3.imag();
        assert_eq!(evaluate_first(&expr_imag), Complex::new(4.0, 0.0));

        let expr_norm = aid1.norm_sqr();
        assert_eq!(evaluate_first(&expr_norm), Complex::new(4.0, 0.0));

        let expr_mul2_norm = expr_mul2.norm_sqr();
        assert_eq!(evaluate_first(&expr_mul2_norm), Complex::new(20.0, 0.0));
    }

    #[test]
    fn test_amplitude_activation() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
        let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));

        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();

        let dataset = Arc::new(test_dataset());
        let expr = &aid1 + &aid2;
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        // Both amplitudes are active after loading.
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(3.0, 0.0));

        // Only const2 contributes once const1 is deactivated.
        evaluator.deactivate("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(2.0, 0.0));

        // Isolating const1 leaves it as the only contribution.
        evaluator.isolate("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(1.0, 0.0));

        // Re-activating everything restores the full sum.
        evaluator.activate_all();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(3.0, 0.0));
    }

    #[test]
    fn test_gradient() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));

        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);

        let params = vec![2.0];
        let gradient = evaluator.evaluate_gradient(&params);

        // |p + 2i|^2 = p^2 + 4, so the derivative at p = 2 is 2p = 4.
        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    #[test]
    fn test_parameter_registration() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));

        let aid = manager.register(amp).unwrap();
        let parameters = manager.parameters();
        let model = manager.model(&aid.into());
        let model_parameters = model.parameters();
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0], "test_param_re");
        assert_eq!(model_parameters.len(), 1);
        assert_eq!(model_parameters[0], "test_param_re");
    }

    #[test]
    fn test_duplicate_amplitude_registration() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
        let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
        manager.register(amp1).unwrap();
        assert!(manager.register(amp2).is_err());
    }
}