use dyn_clone::DynClone;
use nalgebra::DVector;
use num::complex::Complex64;
use crate::{
amplitude::{AmplitudeID, AmplitudeSemanticKey},
data::{Dataset, DatasetMetadata, Event},
expression::{Expression, ExpressionDependence},
resources::{Cache, Parameters, Resources},
LadduResult,
};
#[typetag::serde(tag = "type")]
pub trait Amplitude: DynClone + Send + Sync {
fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID>;
fn semantic_key(&self) -> Option<AmplitudeSemanticKey> {
None
}
fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
Ok(())
}
fn dependence_hint(&self) -> ExpressionDependence {
ExpressionDependence::Mixed
}
fn real_valued_hint(&self) -> bool {
false
}
#[allow(unused_variables)]
fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {}
#[cfg(feature = "rayon")]
fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
use rayon::prelude::*;
resources
.caches
.par_iter_mut()
.enumerate()
.for_each(|(event_index, cache)| {
let event = dataset.event_view(event_index);
self.precompute(&event, cache);
});
}
#[cfg(not(feature = "rayon"))]
fn precompute_all(&self, dataset: &Dataset, resources: &mut Resources) {
resources
.caches
.iter_mut()
.enumerate()
.for_each(|(event_index, cache)| {
let event = dataset.event_view(event_index);
self.precompute(&event, cache);
});
}
fn compute(&self, parameters: &Parameters, cache: &Cache) -> Complex64;
fn compute_gradient(
&self,
parameters: &Parameters,
cache: &Cache,
gradient: &mut DVector<Complex64>,
) {
self.central_difference_with_indices(
&Vec::from_iter(0..parameters.len()),
parameters,
cache,
gradient,
)
}
fn central_difference_with_indices(
&self,
indices: &[usize],
parameters: &Parameters,
cache: &Cache,
gradient: &mut DVector<Complex64>,
) {
let x = parameters.values().to_owned();
let h: DVector<f64> = x
.iter()
.map(|&xi| f64::cbrt(f64::EPSILON) * (xi.abs() + 1.0))
.collect::<Vec<_>>()
.into();
for i in indices {
let mut x_plus = x.clone();
let mut x_minus = x.clone();
x_plus[*i] += h[*i];
x_minus[*i] -= h[*i];
let f_plus = self.compute(¶meters.with_values(x_plus), cache);
let f_minus = self.compute(¶meters.with_values(x_minus), cache);
gradient[*i] = (f_plus - f_minus) / (2.0 * h[*i]);
}
}
fn into_expression(self) -> LadduResult<Expression>
where
Self: Sized + 'static,
{
Expression::from_amplitude(Box::new(self))
}
}
dyn_clone::clone_trait_object!(Amplitude);