use laddu_core::{
amplitudes::{Amplitude, AmplitudeID, ParameterLike},
data::Event,
resources::{Cache, ParameterID, Parameters, Resources},
Complex, DVector, Float, LadduError,
};
use serde::{Deserialize, Serialize};
#[cfg(feature = "python")]
use laddu_python::amplitudes::{PyAmplitude, PyParameterLike};
#[cfg(feature = "python")]
use pyo3::prelude::*;
#[derive(Clone, Serialize, Deserialize)]
pub struct Scalar {
name: String,
value: ParameterLike,
pid: ParameterID,
}
impl Scalar {
pub fn new(name: &str, value: ParameterLike) -> Box<Self> {
Self {
name: name.to_string(),
value,
pid: Default::default(),
}
.into()
}
}
#[typetag::serde]
impl Amplitude for Scalar {
fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
self.pid = resources.register_parameter(&self.value);
resources.register_amplitude(&self.name)
}
fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
Complex::new(parameters.get(self.pid), 0.0)
}
fn compute_gradient(
&self,
_parameters: &Parameters,
_event: &Event,
_cache: &Cache,
gradient: &mut DVector<Complex<Float>>,
) {
if let ParameterID::Parameter(ind) = self.pid {
gradient[ind] = Complex::ONE;
}
}
}
#[cfg(feature = "python")]
#[pyfunction(name = "Scalar")]
pub fn py_scalar(name: &str, value: PyParameterLike) -> PyAmplitude {
PyAmplitude(Scalar::new(name, value.0))
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ComplexScalar {
name: String,
re: ParameterLike,
pid_re: ParameterID,
im: ParameterLike,
pid_im: ParameterID,
}
impl ComplexScalar {
pub fn new(name: &str, re: ParameterLike, im: ParameterLike) -> Box<Self> {
Self {
name: name.to_string(),
re,
pid_re: Default::default(),
im,
pid_im: Default::default(),
}
.into()
}
}
#[typetag::serde]
impl Amplitude for ComplexScalar {
fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
self.pid_re = resources.register_parameter(&self.re);
self.pid_im = resources.register_parameter(&self.im);
resources.register_amplitude(&self.name)
}
fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
Complex::new(parameters.get(self.pid_re), parameters.get(self.pid_im))
}
fn compute_gradient(
&self,
_parameters: &Parameters,
_event: &Event,
_cache: &Cache,
gradient: &mut DVector<Complex<Float>>,
) {
if let ParameterID::Parameter(ind) = self.pid_re {
gradient[ind] = Complex::ONE;
}
if let ParameterID::Parameter(ind) = self.pid_im {
gradient[ind] = Complex::I;
}
}
}
#[cfg(feature = "python")]
#[pyfunction(name = "ComplexScalar")]
pub fn py_complex_scalar(name: &str, re: PyParameterLike, im: PyParameterLike) -> PyAmplitude {
PyAmplitude(ComplexScalar::new(name, re.0, im.0))
}
#[derive(Clone, Serialize, Deserialize)]
pub struct PolarComplexScalar {
name: String,
r: ParameterLike,
pid_r: ParameterID,
theta: ParameterLike,
pid_theta: ParameterID,
}
impl PolarComplexScalar {
pub fn new(name: &str, r: ParameterLike, theta: ParameterLike) -> Box<Self> {
Self {
name: name.to_string(),
r,
pid_r: Default::default(),
theta,
pid_theta: Default::default(),
}
.into()
}
}
#[typetag::serde]
impl Amplitude for PolarComplexScalar {
fn register(&mut self, resources: &mut Resources) -> Result<AmplitudeID, LadduError> {
self.pid_r = resources.register_parameter(&self.r);
self.pid_theta = resources.register_parameter(&self.theta);
resources.register_amplitude(&self.name)
}
fn compute(&self, parameters: &Parameters, _event: &Event, _cache: &Cache) -> Complex<Float> {
Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta))
}
fn compute_gradient(
&self,
parameters: &Parameters,
_event: &Event,
_cache: &Cache,
gradient: &mut DVector<Complex<Float>>,
) {
let exp_i_theta = Complex::cis(parameters.get(self.pid_theta));
if let ParameterID::Parameter(ind) = self.pid_r {
gradient[ind] = exp_i_theta;
}
if let ParameterID::Parameter(ind) = self.pid_theta {
gradient[ind] = Complex::<Float>::I
* Complex::from_polar(parameters.get(self.pid_r), parameters.get(self.pid_theta));
}
}
}
#[cfg(feature = "python")]
#[pyfunction(name = "PolarComplexScalar")]
pub fn py_polar_complex_scalar(
name: &str,
r: PyParameterLike,
theta: PyParameterLike,
) -> PyAmplitude {
PyAmplitude(PolarComplexScalar::new(name, r.0, theta.0))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use laddu_core::{data::test_dataset, parameter, Manager, PI};
use std::sync::Arc;
#[test]
fn test_scalar_creation_and_evaluation() {
let mut manager = Manager::default();
let amp = Scalar::new("test_scalar", parameter("test_param"));
let aid = manager.register(amp).unwrap();
let dataset = Arc::new(test_dataset());
let expr = aid.into(); let model = manager.model(&expr);
let evaluator = model.load(&dataset);
let params = vec![2.5];
let result = evaluator.evaluate(¶ms);
assert_relative_eq!(result[0].re, 2.5);
assert_relative_eq!(result[0].im, 0.0);
}
#[test]
fn test_scalar_gradient() {
let mut manager = Manager::default();
let amp = Scalar::new("test_scalar", parameter("test_param"));
let aid = manager.register(amp).unwrap();
let dataset = Arc::new(test_dataset());
let expr = aid.norm_sqr(); let model = manager.model(&expr);
let evaluator = model.load(&dataset);
let params = vec![2.0];
let gradient = evaluator.evaluate_gradient(¶ms);
assert_relative_eq!(gradient[0][0].re, 4.0);
assert_relative_eq!(gradient[0][0].im, 0.0);
}
#[test]
fn test_complex_scalar_evaluation() {
let mut manager = Manager::default();
let amp = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
let aid = manager.register(amp).unwrap();
let dataset = Arc::new(test_dataset());
let expr = aid.into();
let model = manager.model(&expr);
let evaluator = model.load(&dataset);
let params = vec![1.5, 2.5]; let result = evaluator.evaluate(¶ms);
assert_relative_eq!(result[0].re, 1.5);
assert_relative_eq!(result[0].im, 2.5);
}
#[test]
fn test_complex_scalar_gradient() {
let mut manager = Manager::default();
let amp = ComplexScalar::new("test_complex", parameter("re_param"), parameter("im_param"));
let aid = manager.register(amp).unwrap();
let dataset = Arc::new(test_dataset());
let expr = aid.norm_sqr(); let model = manager.model(&expr);
let evaluator = model.load(&dataset);
let params = vec![3.0, 4.0]; let gradient = evaluator.evaluate_gradient(¶ms);
assert_relative_eq!(gradient[0][0].re, 6.0);
assert_relative_eq!(gradient[0][0].im, 0.0);
assert_relative_eq!(gradient[0][1].re, 8.0);
assert_relative_eq!(gradient[0][1].im, 0.0);
}
#[test]
fn test_polar_complex_scalar_evaluation() {
let mut manager = Manager::default();
let amp =
PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
let aid = manager.register(amp).unwrap();
let dataset = Arc::new(test_dataset());
let expr = aid.into();
let model = manager.model(&expr);
let evaluator = model.load(&dataset);
let r = 2.0;
let theta = PI / 4.3;
let params = vec![r, theta];
let result = evaluator.evaluate(¶ms);
assert_relative_eq!(result[0].re, r * theta.cos());
assert_relative_eq!(result[0].im, r * theta.sin());
}
#[test]
fn test_polar_complex_scalar_gradient() {
let mut manager = Manager::default();
let amp =
PolarComplexScalar::new("test_polar", parameter("r_param"), parameter("theta_param"));
let aid = manager.register(amp).unwrap();
let dataset = Arc::new(test_dataset());
let expr = aid.into(); let model = manager.model(&expr);
let evaluator = model.load(&dataset);
let r = 2.0;
let theta = PI / 4.3;
let params = vec![r, theta];
let gradient = evaluator.evaluate_gradient(¶ms);
assert_relative_eq!(gradient[0][0].re, Float::cos(theta));
assert_relative_eq!(gradient[0][0].im, Float::sin(theta));
assert_relative_eq!(gradient[0][1].re, -r * Float::sin(theta));
assert_relative_eq!(gradient[0][1].im, r * Float::cos(theta));
}
}