use std::{
fmt::{Debug, Display},
sync::Arc,
};
use auto_ops::*;
use dyn_clone::DynClone;
use nalgebra::{ComplexField, DVector};
use num::Complex;
use parking_lot::RwLock;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::{
data::{Dataset, Event},
resources::{Cache, Parameters, Resources},
Float, LadduError,
};
/// Describes one input of an amplitude: either a named free parameter or a fixed constant.
///
/// Amplitudes pass these to `Resources::register_parameter` during registration. The
/// [`Default`] value is [`ParameterLike::Uninit`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum ParameterLike {
    /// A free parameter identified by its name.
    Parameter(String),
    /// A fixed numeric value.
    Constant(Float),
    /// The default variant; presumably means "not yet set" — how it is treated on
    /// registration is defined in `Resources` (TODO confirm).
    #[default]
    Uninit,
}
/// Shorthand for a [`ParameterLike::Parameter`] referring to the free parameter `name`.
pub fn parameter(name: &str) -> ParameterLike {
    let owned_name = name.to_owned();
    ParameterLike::Parameter(owned_name)
}
/// Shorthand for a [`ParameterLike::Constant`] holding the fixed value `value`.
pub fn constant(value: Float) -> ParameterLike {
    ParameterLike::Constant(value)
}
/// An amplitude that can be registered with a [`Manager`] and evaluated event-by-event.
///
/// Trait objects are (de)serialized through `typetag`, tagged by a `"type"` field, and
/// `Box<dyn Amplitude>` is cloneable through the `DynClone` supertrait.
#[typetag::serde(tag = "type")]
pub trait Amplitude: DynClone + Send + Sync {
    /// Registers this amplitude's inputs with `resources` and returns its [`AmplitudeID`].
    ///
    /// Called by [`Manager::register`].
    ///
    /// # Errors
    /// Implementation-defined; typically propagated from `resources` (for example, the
    /// tests in this file expect a duplicate amplitude name to be rejected).
    fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError>;

    /// Optional per-event precomputation whose results are stored in `cache`.
    ///
    /// The default does nothing.
    #[allow(unused_variables)]
    fn precompute(&self, event: &Event, cache: &mut Cache) {}

    /// Runs [`Amplitude::precompute`] for every event in `dataset`, pairing each event
    /// with the cache at the same position in `resources.caches` (in parallel).
    #[cfg(feature = "rayon")]
    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
        dataset
            .events
            .par_iter()
            .zip(resources.caches.par_iter_mut())
            .for_each(|(event, cache)| {
                self.precompute(event, cache);
            })
    }

    /// Runs [`Amplitude::precompute`] for every event in `dataset`, pairing each event
    /// with the cache at the same position in `resources.caches` (sequentially).
    #[cfg(not(feature = "rayon"))]
    fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
        dataset
            .events
            .iter()
            .zip(resources.caches.iter_mut())
            .for_each(|(event, cache)| self.precompute(event, cache))
    }

    /// Evaluates the amplitude for a single event.
    fn compute(&self, parameters: &Parameters, event: &Event, cache: &Cache) -> Complex<Float>;

    /// Writes the gradient of [`Amplitude::compute`] with respect to the free parameters
    /// into `gradient`.
    ///
    /// The default estimates every component numerically with
    /// [`Amplitude::central_difference_with_indices`]; implementors with an analytic
    /// gradient should override it.
    fn compute_gradient(
        &self,
        parameters: &Parameters,
        event: &Event,
        cache: &Cache,
        gradient: &mut DVector<Complex<Float>>,
    ) {
        let indices: Vec<usize> = (0..parameters.len()).collect();
        self.central_difference_with_indices(&indices, parameters, event, cache, gradient)
    }

    /// Estimates `d compute / d x_i` for each `i` in `indices` by central differences.
    ///
    /// The step is `h_i = cbrt(EPSILON) * (|x_i| + 1)`, and the result is
    /// `(f(x + h_i e_i) - f(x - h_i e_i)) / (2 h_i)`. Entries of `gradient` not listed in
    /// `indices` are left untouched.
    ///
    /// # Panics
    /// Panics if an index is out of range for the parameters or for `gradient`.
    fn central_difference_with_indices(
        &self,
        indices: &[usize],
        parameters: &Parameters,
        event: &Event,
        cache: &Cache,
        gradient: &mut DVector<Complex<Float>>,
    ) {
        // One scratch copy of the free parameters is perturbed in place and restored after
        // each index, instead of cloning two full vectors per index. Constants are never
        // perturbed, so they are borrowed rather than copied.
        let mut x = parameters.parameters.to_owned();
        for &i in indices {
            let xi = x[i];
            let h = Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0);
            x[i] = xi + h;
            let f_plus = self.compute(&Parameters::new(&x, &parameters.constants), event, cache);
            x[i] = xi - h;
            let f_minus = self.compute(&Parameters::new(&x, &parameters.constants), event, cache);
            x[i] = xi;
            gradient[i] = (f_plus - f_minus) / (2.0 * h);
        }
    }
}
/// Numerically estimates the gradient of a real scalar function by central differences.
///
/// For each component `i`, the step is `h_i = cbrt(EPSILON) * (|x_i| + 1)` and the
/// estimate is `(func(x + h_i e_i) - func(x - h_i e_i)) / (2 h_i)`. `func` is called
/// `2 * parameters.len()` times. An empty `parameters` yields an empty gradient.
pub fn central_difference<F: Fn(&[Float]) -> Float>(
    parameters: &[Float],
    func: F,
) -> DVector<Float> {
    let mut gradient = DVector::zeros(parameters.len());
    // A single scratch buffer is perturbed in place and restored after each component,
    // instead of allocating two full copies of the parameters per component.
    let mut x = parameters.to_vec();
    for (i, &xi) in parameters.iter().enumerate() {
        let h = Float::cbrt(Float::EPSILON) * (xi.abs() + 1.0);
        x[i] = xi + h;
        let f_plus = func(&x);
        x[i] = xi - h;
        let f_minus = func(&x);
        x[i] = xi;
        gradient[i] = (f_plus - f_minus) / (2.0 * h);
    }
    gradient
}
// Implements `Clone` for `Box<dyn Amplitude>` (via the `DynClone` supertrait), which the
// `Clone` derives on `Manager` and `Evaluator` rely on.
dyn_clone::clone_trait_object!(Amplitude);
/// Values of every registered amplitude for one event; [`Expression::evaluate`] indexes
/// it with the index stored in an [`AmplitudeID`].
#[derive(Debug)]
pub struct AmplitudeValues(pub Vec<Complex<Float>>);
/// Gradients (with respect to the free parameters) of every registered amplitude for one
/// event; [`Expression::evaluate_gradient`] indexes it like [`AmplitudeValues`].
#[derive(Debug)]
pub struct GradientValues(pub Vec<DVector<Complex<Float>>>);
/// Handle for a registered amplitude.
///
/// Field `0` is the label shown by `Display` and by expression trees (presumably the
/// amplitude's registered name); field `1` is the index into [`AmplitudeValues`] and
/// [`GradientValues`].
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct AmplitudeID(pub(crate) String, pub(crate) usize);
/// Displays the amplitude's label verbatim (formatter width/fill options are not applied).
impl Display for AmplitudeID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
impl From<AmplitudeID> for Expression {
fn from(value: AmplitudeID) -> Self {
Self::Amp(value)
}
}
/// A symbolic combination of registered amplitudes, evaluated per event.
///
/// The [`Default`] value is [`Expression::Zero`].
#[derive(Clone, Serialize, Deserialize, Default)]
pub enum Expression {
    /// The constant `0`.
    #[default]
    Zero,
    /// A single registered amplitude.
    Amp(AmplitudeID),
    /// Sum of two sub-expressions.
    Add(Box<Expression>, Box<Expression>),
    /// Product of two sub-expressions.
    Mul(Box<Expression>, Box<Expression>),
    /// Real part of a sub-expression, as a complex number with zero imaginary part.
    Real(Box<Expression>),
    /// Imaginary part of a sub-expression, as a complex number with zero imaginary part.
    Imag(Box<Expression>),
    /// Squared modulus `|z|^2` of a sub-expression, as a complex number with zero
    /// imaginary part.
    NormSqr(Box<Expression>),
}
/// Renders the expression as an indented tree, one node per line.
impl Debug for Expression {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The root has no prefixes; children get theirs from `write_tree`.
        self.write_tree(f, "", "", "")
    }
}
impl Display for Expression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
// Operator overloads (via `auto_ops`) for building expression trees. `impl_op_ex!`
// generates the owned/borrowed operand combinations; the operands are cloned into the
// new node, so the inputs remain usable.
// `Expression + Expression` -> `Add` node.
impl_op_ex!(+ |a: &Expression, b: &Expression| -> Expression { Expression::Add(Box::new(a.clone()), Box::new(b.clone()))});
// `Expression * Expression` -> `Mul` node.
impl_op_ex!(*|a: &Expression, b: &Expression| -> Expression {
    Expression::Mul(Box::new(a.clone()), Box::new(b.clone()))
});
// Mixed `AmplitudeID`/`Expression` operands, in either order. The commutative form always
// puts the amplitude on the left of the resulting node.
impl_op_ex_commutative!(+ |a: &AmplitudeID, b: &Expression| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))});
impl_op_ex_commutative!(*|a: &AmplitudeID, b: &Expression| -> Expression {
    Expression::Mul(Box::new(Expression::Amp(a.clone())), Box::new(b.clone()))
});
// Two `AmplitudeID` operands, each wrapped as an `Amp` leaf.
impl_op_ex!(+ |a: &AmplitudeID, b: &AmplitudeID| -> Expression { Expression::Add(Box::new(Expression::Amp(a.clone())), Box::new(Expression::Amp(b.clone())))});
impl_op_ex!(*|a: &AmplitudeID, b: &AmplitudeID| -> Expression {
    Expression::Mul(
        Box::new(Expression::Amp(a.clone())),
        Box::new(Expression::Amp(b.clone())),
    )
});
/// Convenience constructors that wrap a single amplitude in a unary [`Expression`].
impl AmplitudeID {
    /// Expression for the real part of this amplitude.
    pub fn real(&self) -> Expression {
        Expression::Real(Box::new(self.clone().into()))
    }
    /// Expression for the imaginary part of this amplitude.
    pub fn imag(&self) -> Expression {
        Expression::Imag(Box::new(self.clone().into()))
    }
    /// Expression for the squared modulus of this amplitude.
    pub fn norm_sqr(&self) -> Expression {
        Expression::NormSqr(Box::new(self.clone().into()))
    }
}
impl Expression {
    /// Evaluates the expression for one event from the values of all registered amplitudes.
    ///
    /// # Panics
    /// Panics if an [`AmplitudeID`] index is out of range for `amplitude_values`.
    pub fn evaluate(&self, amplitude_values: &AmplitudeValues) -> Complex<Float> {
        match self {
            Expression::Amp(aid) => amplitude_values.0[aid.1],
            Expression::Add(a, b) => a.evaluate(amplitude_values) + b.evaluate(amplitude_values),
            Expression::Mul(a, b) => a.evaluate(amplitude_values) * b.evaluate(amplitude_values),
            Expression::Real(a) => Complex::new(a.evaluate(amplitude_values).re, 0.0),
            Expression::Imag(a) => Complex::new(a.evaluate(amplitude_values).im, 0.0),
            Expression::NormSqr(a) => Complex::new(a.evaluate(amplitude_values).norm_sqr(), 0.0),
            Expression::Zero => Complex::ZERO,
        }
    }

    /// Evaluates the gradient of the expression with respect to the free parameters for
    /// one event.
    ///
    /// `gradient_values` holds each amplitude's gradient. `amplitude_values` is needed
    /// because the product rule and `NormSqr` depend on the sub-expression values.
    ///
    /// # Panics
    /// Panics if an [`AmplitudeID`] index is out of range, or if the amplitude gradients
    /// have differing lengths.
    pub fn evaluate_gradient(
        &self,
        amplitude_values: &AmplitudeValues,
        gradient_values: &GradientValues,
    ) -> DVector<Complex<Float>> {
        match self {
            Expression::Amp(aid) => gradient_values.0[aid.1].clone(),
            Expression::Add(a, b) => {
                a.evaluate_gradient(amplitude_values, gradient_values)
                    + b.evaluate_gradient(amplitude_values, gradient_values)
            }
            Expression::Mul(a, b) => {
                // Product rule: (ab)' = a' b + a b'.
                let f_a = a.evaluate(amplitude_values);
                let f_b = b.evaluate(amplitude_values);
                b.evaluate_gradient(amplitude_values, gradient_values)
                    .map(|g| g * f_a)
                    + a.evaluate_gradient(amplitude_values, gradient_values)
                        .map(|g| g * f_b)
            }
            Expression::Real(a) => a
                .evaluate_gradient(amplitude_values, gradient_values)
                .map(|g| Complex::new(g.re, 0.0)),
            Expression::Imag(a) => a
                .evaluate_gradient(amplitude_values, gradient_values)
                .map(|g| Complex::new(g.im, 0.0)),
            Expression::NormSqr(a) => {
                // For real parameters x: d|f|^2/dx = 2 Re(f' * conj(f)).
                let conj_f_a = a.evaluate(amplitude_values).conjugate();
                a.evaluate_gradient(amplitude_values, gradient_values)
                    .map(|g| Complex::new(2.0 * (g * conj_f_a).re, 0.0))
            }
            // Size the zero gradient like the amplitude gradients. A zero-length vector
            // would make any enclosing `Add`/`Mul` panic on a dimension mismatch
            // (e.g. `Expression::default() + amp`).
            Expression::Zero => {
                DVector::zeros(gradient_values.0.first().map_or(0, |g| g.len()))
            }
        }
    }

    /// Wraps this expression in [`Expression::Real`].
    pub fn real(&self) -> Self {
        Self::Real(Box::new(self.clone()))
    }

    /// Wraps this expression in [`Expression::Imag`].
    pub fn imag(&self) -> Self {
        Self::Imag(Box::new(self.clone()))
    }

    /// Wraps this expression in [`Expression::NormSqr`].
    pub fn norm_sqr(&self) -> Self {
        Self::NormSqr(Box::new(self.clone()))
    }

    /// Writes this node and its children as a tree, one node per line.
    ///
    /// Each line is `parent_prefix + immediate_prefix + label`. Children are indented
    /// with `parent_prefix + parent_suffix`. Within binary nodes, the left child uses the
    /// "more siblings follow" connector and the right child the "last child" connector.
    fn write_tree(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        parent_prefix: &str,
        immediate_prefix: &str,
        parent_suffix: &str,
    ) -> std::fmt::Result {
        let label: &str = match self {
            Self::Amp(aid) => &aid.0,
            Self::Add(_, _) => "+",
            Self::Mul(_, _) => "*",
            Self::Real(_) => "Re",
            Self::Imag(_) => "Im",
            Self::NormSqr(_) => "NormSqr",
            Self::Zero => "0",
        };
        writeln!(f, "{}{}{}", parent_prefix, immediate_prefix, label)?;
        match self {
            Self::Amp(_) | Self::Zero => {}
            Self::Add(a, b) | Self::Mul(a, b) => {
                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
                a.write_tree(f, &child_prefix, "├─ ", "│ ")?;
                b.write_tree(f, &child_prefix, "└─ ", " ")?;
            }
            Self::Real(a) | Self::Imag(a) | Self::NormSqr(a) => {
                let child_prefix = format!("{}{}", parent_prefix, parent_suffix);
                a.write_tree(f, &child_prefix, "└─ ", " ")?;
            }
        }
        Ok(())
    }
}
/// Collects registered amplitudes together with the [`Resources`] they register, and
/// builds [`Model`]s from them.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Manager {
    /// Registered amplitudes, in registration order.
    amplitudes: Vec<Box<dyn Amplitude>>,
    /// Parameters, constants and other bookkeeping registered by the amplitudes.
    resources: Resources,
}
impl Manager {
    /// Returns the names of all free parameters registered so far.
    pub fn parameters(&self) -> Vec<String> {
        self.resources.parameters.iter().cloned().collect()
    }

    /// Registers `amplitude` and returns the [`AmplitudeID`] used to refer to it in an
    /// [`Expression`].
    ///
    /// # Errors
    /// Propagates any error from the amplitude's [`Amplitude::register`]. In this file's
    /// tests, registering a second amplitude with an existing name is one such error. On
    /// error, the manager is left unchanged.
    pub fn register(
        &mut self,
        mut amplitude: Box<dyn Amplitude>,
    ) -> Result<AmplitudeID, LadduError> {
        // `Amplitude::register` may add parameters before it fails (e.g. on a duplicate
        // name). Registering against a scratch copy and committing only on success keeps
        // `self.resources` free of entries from failed registrations.
        let mut resources = self.resources.clone();
        let aid = amplitude.register(&mut resources)?;
        self.resources = resources;
        self.amplitudes.push(amplitude);
        Ok(aid)
    }

    /// Builds a [`Model`] from a snapshot (clone) of this manager and `expression`.
    ///
    /// Amplitudes registered after this call are not part of the returned model.
    pub fn model(&self, expression: &Expression) -> Model {
        Model {
            manager: self.clone(),
            expression: expression.clone(),
        }
    }
}
/// A snapshot of a [`Manager`] together with the [`Expression`] to evaluate.
///
/// Created by [`Manager::model`] and bound to data with [`Model::load`].
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
    /// The amplitudes and resources this model was built from.
    pub(crate) manager: Manager,
    /// The expression combining the amplitudes.
    pub(crate) expression: Expression,
}
impl Model {
    /// Returns the names of the free parameters registered in the underlying [`Manager`].
    pub fn parameters(&self) -> Vec<String> {
        self.manager.parameters()
    }

    /// Binds this model to `dataset` and returns an [`Evaluator`].
    ///
    /// A private copy of the manager's resources is prepared with
    /// `reserve_cache(dataset.len())`. Each amplitude's
    /// [`Amplitude::precompute_all`] is then run over the dataset before the resources
    /// are placed behind the evaluator's shared lock.
    pub fn load(&self, dataset: &Arc<Dataset>) -> Evaluator {
        // Nothing else can observe the resources yet, so they are prepared without a lock
        // and only wrapped once fully populated.
        let mut resources = self.manager.resources.clone();
        resources.reserve_cache(dataset.len());
        self.manager
            .amplitudes
            .iter()
            .for_each(|amplitude| amplitude.precompute_all(dataset, &mut resources));
        Evaluator {
            amplitudes: self.manager.amplitudes.clone(),
            resources: Arc::new(RwLock::new(resources)),
            dataset: Arc::clone(dataset),
            expression: self.expression.clone(),
        }
    }
}
/// A [`Model`] bound to a [`Dataset`] (see [`Model::load`]), ready to be evaluated per event.
///
/// Cloning an `Evaluator` shares the same `Arc<RwLock<Resources>>`, so activation changes
/// made through one clone are visible through the others.
#[derive(Clone)]
pub struct Evaluator {
    /// The registered amplitudes, in registration order.
    pub amplitudes: Vec<Box<dyn Amplitude>>,
    /// Shared resources: parameter names, constants, activation flags and per-event caches.
    pub resources: Arc<RwLock<Resources>>,
    /// The events being evaluated.
    pub dataset: Arc<Dataset>,
    /// The expression combining the amplitudes.
    pub expression: Expression,
}
impl Evaluator {
pub fn parameters(&self) -> Vec<String> {
self.resources.read().parameters.iter().cloned().collect()
}
pub fn activate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
self.resources.write().activate(name)
}
pub fn activate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
self.resources.write().activate_many(names)
}
pub fn activate_all(&self) {
self.resources.write().activate_all();
}
pub fn deactivate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
self.resources.write().deactivate(name)
}
pub fn deactivate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
self.resources.write().deactivate_many(names)
}
pub fn deactivate_all(&self) {
self.resources.write().deactivate_all();
}
pub fn isolate<T: AsRef<str>>(&self, name: T) -> Result<(), LadduError> {
self.resources.write().isolate(name)
}
pub fn isolate_many<T: AsRef<str>>(&self, names: &[T]) -> Result<(), LadduError> {
self.resources.write().isolate_many(names)
}
#[cfg(feature = "rayon")]
pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
let resources = self.resources.read();
let parameters = Parameters::new(parameters, &resources.constants);
let amplitude_values_vec: Vec<AmplitudeValues> = self
.dataset
.events
.par_iter()
.zip(resources.caches.par_iter())
.map(|(event, cache)| {
AmplitudeValues(
self.amplitudes
.iter()
.zip(resources.active.iter())
.map(|(amp, active)| {
if *active {
amp.compute(¶meters, event, cache)
} else {
Complex::new(0.0, 0.0)
}
})
.collect(),
)
})
.collect();
amplitude_values_vec
.par_iter()
.map(|amplitude_values| self.expression.evaluate(amplitude_values))
.collect()
}
#[cfg(not(feature = "rayon"))]
pub fn evaluate(&self, parameters: &[Float]) -> Vec<Complex<Float>> {
let resources = self.resources.read();
let parameters = Parameters::new(parameters, &resources.constants);
let amplitude_values_vec: Vec<AmplitudeValues> = self
.dataset
.events
.iter()
.zip(resources.caches.iter())
.map(|(event, cache)| {
AmplitudeValues(
self.amplitudes
.iter()
.zip(resources.active.iter())
.map(|(amp, active)| {
if *active {
amp.compute(¶meters, event, cache)
} else {
Complex::new(0.0, 0.0)
}
})
.collect(),
)
})
.collect();
amplitude_values_vec
.iter()
.map(|amplitude_values| self.expression.evaluate(amplitude_values))
.collect()
}
#[cfg(feature = "rayon")]
pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
let resources = self.resources.read();
let parameters = Parameters::new(parameters, &resources.constants);
let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
.dataset
.events
.par_iter()
.zip(resources.caches.par_iter())
.map(|(event, cache)| {
let mut gradient_values =
vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
self.amplitudes
.iter()
.zip(resources.active.iter())
.zip(gradient_values.iter_mut())
.for_each(|((amp, active), grad)| {
if *active {
amp.compute_gradient(¶meters, event, cache, grad)
}
});
(
AmplitudeValues(
self.amplitudes
.iter()
.zip(resources.active.iter())
.map(|(amp, active)| {
if *active {
amp.compute(¶meters, event, cache)
} else {
Complex::new(0.0, 0.0)
}
})
.collect(),
),
GradientValues(gradient_values),
)
})
.collect();
amplitude_values_and_gradient_vec
.par_iter()
.map(|(amplitude_values, gradient_values)| {
self.expression
.evaluate_gradient(amplitude_values, gradient_values)
})
.collect()
}
#[cfg(not(feature = "rayon"))]
pub fn evaluate_gradient(&self, parameters: &[Float]) -> Vec<DVector<Complex<Float>>> {
let resources = self.resources.read();
let parameters = Parameters::new(parameters, &resources.constants);
let amplitude_values_and_gradient_vec: Vec<(AmplitudeValues, GradientValues)> = self
.dataset
.events
.iter()
.zip(resources.caches.iter())
.map(|(event, cache)| {
let mut gradient_values =
vec![DVector::zeros(parameters.len()); self.amplitudes.len()];
self.amplitudes
.iter()
.zip(resources.active.iter())
.zip(gradient_values.iter_mut())
.for_each(|((amp, active), grad)| {
if *active {
amp.compute_gradient(¶meters, event, cache, grad)
}
});
(
AmplitudeValues(
self.amplitudes
.iter()
.zip(resources.active.iter())
.map(|(amp, active)| {
if *active {
amp.compute(¶meters, event, cache)
} else {
Complex::new(0.0, 0.0)
}
})
.collect(),
),
GradientValues(gradient_values),
)
})
.collect();
amplitude_values_and_gradient_vec
.iter()
.map(|(amplitude_values, gradient_values)| {
self.expression
.evaluate_gradient(amplitude_values, gradient_values)
})
.collect()
}
}
#[cfg(test)]
mod tests {
    use crate::data::{test_dataset, test_event};

    use super::*;
    use crate::{
        data::Event,
        resources::{Cache, ParameterID, Parameters, Resources},
        Complex, DVector, Float, LadduError,
    };
    use approx::assert_relative_eq;
    use serde::{Deserialize, Serialize};

    /// Test amplitude equal to `re + i*im`, where each part is a free parameter or a
    /// constant.
    #[derive(Clone, Serialize, Deserialize)]
    pub struct ComplexScalar {
        name: String,
        re: ParameterLike,
        pid_re: ParameterID,
        im: ParameterLike,
        pid_im: ParameterID,
    }

    impl ComplexScalar {
        pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
            Self {
                name: name.to_string(),
                re,
                pid_re: Default::default(),
                im,
                pid_im: Default::default(),
            }
            .into()
        }
    }

    #[typetag::serde]
    impl Amplitude for ComplexScalar {
        fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
            self.pid_re = resources.register_parameter(&self.re);
            self.pid_im = resources.register_parameter(&self.im);
            resources.register_amplitude(&self.name)
        }

        fn compute(
            &self,
            parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
        ) -> Complex<Float> {
            Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
        }

        // Analytic gradient: d/d(re) = 1, d/d(im) = i; constants contribute nothing.
        fn compute_gradient(
            &self,
            _parameters: &Parameters,
            _event: &Event,
            _cache: &Cache,
            gradient: &mut DVector<Complex<Float>>,
        ) {
            if let ParameterID::Parameter(ind) = self.pid_re {
                gradient[ind] = Complex::ONE;
            }
            if let ParameterID::Parameter(ind) = self.pid_im {
                gradient[ind] = Complex::I;
            }
        }
    }

    #[test]
    fn test_constant_amplitude() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("constant", constant(2.0), constant(3.0));
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(Dataset {
            events: vec![Arc::new(test_event())],
        });
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_parametric_amplitude() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new(
            "parametric",
            parameter("test_param_re"),
            parameter("test_param_im"),
        );
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = Expression::Amp(aid);
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[2.0, 3.0]);
        assert_eq!(result[0], Complex::new(2.0, 3.0));
    }

    #[test]
    fn test_expression_operations() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("const1", constant(2.0), constant(0.0));
        let amp2 = ComplexScalar::new("const2", constant(0.0), constant(1.0));
        let amp3 = ComplexScalar::new("const3", constant(3.0), constant(4.0));
        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let aid3 = manager.register(amp3).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr_add = &aid1 + &aid2;
        let model_add = manager.model(&expr_add);
        let eval_add = model_add.load(&dataset);
        let result_add = eval_add.evaluate(&[]);
        assert_eq!(result_add[0], Complex::new(2.0, 1.0));
        let expr_mul = &aid1 * &aid2;
        let model_mul = manager.model(&expr_mul);
        let eval_mul = model_mul.load(&dataset);
        let result_mul = eval_mul.evaluate(&[]);
        assert_eq!(result_mul[0], Complex::new(0.0, 2.0));
        let expr_add2 = &expr_add + &expr_mul;
        let model_add2 = manager.model(&expr_add2);
        let eval_add2 = model_add2.load(&dataset);
        let result_add2 = eval_add2.evaluate(&[]);
        assert_eq!(result_add2[0], Complex::new(2.0, 3.0));
        let expr_mul2 = &expr_add * &expr_mul;
        let model_mul2 = manager.model(&expr_mul2);
        let eval_mul2 = model_mul2.load(&dataset);
        let result_mul2 = eval_mul2.evaluate(&[]);
        assert_eq!(result_mul2[0], Complex::new(-2.0, 4.0));
        let expr_real = aid3.real();
        let model_real = manager.model(&expr_real);
        let eval_real = model_real.load(&dataset);
        let result_real = eval_real.evaluate(&[]);
        assert_eq!(result_real[0], Complex::new(3.0, 0.0));
        let expr_mul2_real = expr_mul2.real();
        let model_mul2_real = manager.model(&expr_mul2_real);
        let eval_mul2_real = model_mul2_real.load(&dataset);
        let result_mul2_real = eval_mul2_real.evaluate(&[]);
        assert_eq!(result_mul2_real[0], Complex::new(-2.0, 0.0));
        let expr_mul2_imag = expr_mul2.imag();
        let model_mul2_imag = manager.model(&expr_mul2_imag);
        let eval_mul2_imag = model_mul2_imag.load(&dataset);
        let result_mul2_imag = eval_mul2_imag.evaluate(&[]);
        assert_eq!(result_mul2_imag[0], Complex::new(4.0, 0.0));
        let expr_imag = aid3.imag();
        let model_imag = manager.model(&expr_imag);
        let eval_imag = model_imag.load(&dataset);
        let result_imag = eval_imag.evaluate(&[]);
        assert_eq!(result_imag[0], Complex::new(4.0, 0.0));
        let expr_norm = aid1.norm_sqr();
        let model_norm = manager.model(&expr_norm);
        let eval_norm = model_norm.load(&dataset);
        let result_norm = eval_norm.evaluate(&[]);
        assert_eq!(result_norm[0], Complex::new(4.0, 0.0));
        let expr_mul2_norm = expr_mul2.norm_sqr();
        let model_mul2_norm = manager.model(&expr_mul2_norm);
        let eval_mul2_norm = model_mul2_norm.load(&dataset);
        let result_mul2_norm = eval_mul2_norm.evaluate(&[]);
        assert_eq!(result_mul2_norm[0], Complex::new(20.0, 0.0));
    }

    #[test]
    fn test_amplitude_activation() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("const1", constant(1.0), constant(0.0));
        let amp2 = ComplexScalar::new("const2", constant(2.0), constant(0.0));
        let aid1 = manager.register(amp1).unwrap();
        let aid2 = manager.register(amp2).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = &aid1 + &aid2;
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(3.0, 0.0));
        evaluator.deactivate("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(2.0, 0.0));
        evaluator.isolate("const1").unwrap();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(1.0, 0.0));
        evaluator.activate_all();
        let result = evaluator.evaluate(&[]);
        assert_eq!(result[0], Complex::new(3.0, 0.0));
    }

    #[test]
    fn test_gradient() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
        let aid = manager.register(amp).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = aid.norm_sqr();
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        let params = vec![2.0];
        let gradient = evaluator.evaluate_gradient(&params);
        // |x + 2i|^2 = x^2 + 4, so d/dx = 2x = 4 at x = 2.
        assert_relative_eq!(gradient[0][0].re, 4.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    #[test]
    fn test_gradient_product_rule() {
        let mut manager = Manager::default();
        let amp_a = ComplexScalar::new("a", parameter("a_re"), constant(2.0));
        let amp_b = ComplexScalar::new("b", constant(3.0), constant(0.0));
        let aid_a = manager.register(amp_a).unwrap();
        let aid_b = manager.register(amp_b).unwrap();
        let dataset = Arc::new(test_dataset());
        let expr = &aid_a * &aid_b;
        let model = manager.model(&expr);
        let evaluator = model.load(&dataset);
        // f(x) = (x + 2i) * 3, so df/dx = 3 for every x.
        let gradient = evaluator.evaluate_gradient(&[1.5]);
        assert_relative_eq!(gradient[0][0].re, 3.0);
        assert_relative_eq!(gradient[0][0].im, 0.0);
    }

    #[test]
    fn test_parameter_registration() {
        let mut manager = Manager::default();
        let amp = ComplexScalar::new("parametric", parameter("test_param_re"), constant(2.0));
        let aid = manager.register(amp).unwrap();
        let parameters = manager.parameters();
        let model = manager.model(&aid.into());
        let model_parameters = model.parameters();
        assert_eq!(parameters.len(), 1);
        assert_eq!(parameters[0], "test_param_re");
        assert_eq!(model_parameters.len(), 1);
        assert_eq!(model_parameters[0], "test_param_re");
    }

    #[test]
    fn test_duplicate_amplitude_registration() {
        let mut manager = Manager::default();
        let amp1 = ComplexScalar::new("same_name", constant(1.0), constant(0.0));
        let amp2 = ComplexScalar::new("same_name", constant(2.0), constant(0.0));
        manager.register(amp1).unwrap();
        assert!(manager.register(amp2).is_err());
    }
}