use dyn_clone::DynClone;
use itertools::Itertools;
use nalgebra::Complex;
use parking_lot::RwLock;
use rayon::prelude::*;
use std::{
collections::HashSet,
fmt::{Debug, Display},
ops::{Add, Mul},
sync::Arc,
};
use tracing::{debug, info};
use crate::{
convert,
dataset::{Dataset, Event},
errors::RustitudeError,
Field,
};
/// A named parameter of one amplitude, together with its fit bookkeeping.
#[derive(Clone)]
pub struct Parameter<F: Field> {
    /// Name of the amplitude that owns this parameter.
    pub amplitude: String,
    /// Name of the parameter itself.
    pub name: String,
    /// Index among the free parameters; `None` marks the parameter as fixed.
    pub index: Option<usize>,
    /// Index among the fixed parameters, if any has been assigned.
    pub fixed_index: Option<usize>,
    /// Initial value (also the value used while the parameter is fixed).
    pub initial: F,
    /// `(lower, upper)` bounds.
    pub bounds: (F, F),
}

impl<F: Field> Parameter<F> {
    /// Creates a free parameter at `index` with initial value one and
    /// bounds of `(-inf, +inf)`.
    pub fn new(amplitude: &str, name: &str, index: usize) -> Self {
        let unbounded = (F::neg_infinity(), F::infinity());
        Self {
            amplitude: amplitude.to_owned(),
            name: name.to_owned(),
            index: Some(index),
            fixed_index: None,
            initial: F::one(),
            bounds: unbounded,
        }
    }

    /// Returns `true` when the parameter has a free index.
    pub const fn is_free(&self) -> bool {
        matches!(self.index, Some(_))
    }

    /// Returns `true` when the parameter has no free index.
    pub const fn is_fixed(&self) -> bool {
        matches!(self.index, None)
    }
}
/// Debug output lists name, value, bounds and parent amplitude; fixed
/// parameters get a `(fixed)` marker after the value.
impl<F: Field> Debug for Parameter<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (lower, upper) = (&self.bounds.0, &self.bounds.1);
        match self.index {
            None => write!(
                f,
                "Parameter(name={}, value={} (fixed), bounds=({}, {}), parent={})",
                self.name, self.initial, lower, upper, self.amplitude
            ),
            Some(_) => write!(
                f,
                "Parameter(name={}, value={}, bounds=({}, {}), parent={})",
                self.name, self.initial, lower, upper, self.amplitude
            ),
        }
    }
}
/// Displays only the parameter's name.
impl<F: Field> Display for Parameter<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: like the nested `write!` form, no outer
        // width/fill flags are applied to the name.
        f.write_str(&self.name)
    }
}
/// Behaviour of a single amplitude node: optional per-dataset
/// precalculation plus a per-event complex value.
pub trait Node<F: Field>: Sync + Send + DynClone {
    /// Hook run once per dataset before any `calculate` call.
    /// The default does nothing and succeeds.
    fn precalculate(&mut self, _dataset: &Dataset<F>) -> Result<(), RustitudeError> {
        Ok(())
    }

    /// Evaluates the node for `event` using `parameters`.
    fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError>;

    /// Names of the parameters this node consumes, in order.
    /// The default is an empty list.
    fn parameters(&self) -> Vec<String> {
        Vec::new()
    }

    /// Wraps this node in an [`Amplitude`] called `name`.
    fn into_amplitude(self, name: &str) -> Amplitude<F>
    where
        Self: Sized + 'static,
    {
        Amplitude::new(name, self)
    }

    /// Alias for [`Node::into_amplitude`].
    fn named(self, name: &str) -> Amplitude<F>
    where
        Self: Sized + 'static,
    {
        self.into_amplitude(name)
    }

    /// Whether the node is a Python node; `false` unless overridden.
    fn is_python_node(&self) -> bool {
        false
    }
}

// Makes `Box<dyn Node<F>>` cloneable.
dyn_clone::clone_trait_object!(<F> Node<F>);
/// Anything that can sit in an amplitude expression tree and be
/// evaluated from a cache of per-amplitude values.
pub trait AmpLike<F: Field>: Send + Sync + Debug + Display + AsTree + DynClone {
    /// Returns clones of every [`Amplitude`] reachable from this term.
    fn walk(&self) -> Vec<Amplitude<F>>;

    /// Returns mutable references to every [`Amplitude`] reachable from this term.
    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>>;

    /// Evaluates this term from `cache`.
    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>>;

    /// Clones of the direct child terms, for terms that hold a list of
    /// children; `None` by default.
    fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
        None
    }

    /// Wraps a clone of this term in [`Real`].
    fn real(&self) -> Real<F>
    where
        Self: Sized + 'static,
    {
        let boxed: Box<dyn AmpLike<F>> = dyn_clone::clone_box(self);
        Real(boxed)
    }

    /// Wraps a clone of this term in [`Imag`].
    fn imag(&self) -> Imag<F>
    where
        Self: Sized + 'static,
    {
        let boxed: Box<dyn AmpLike<F>> = dyn_clone::clone_box(self);
        Imag(boxed)
    }

    /// Builds a [`Product`] from clones of `als`.
    fn prod(als: &Vec<Box<dyn AmpLike<F>>>) -> Product<F>
    where
        Self: Sized + 'static,
    {
        Product(als.clone())
    }

    /// Builds a [`Sum`] from clones of `als`.
    fn sum(als: &Vec<Box<dyn AmpLike<F>>>) -> Sum<F>
    where
        Self: Sized + 'static,
    {
        Sum(als.clone())
    }
}

// Makes `Box<dyn AmpLike<F>>` cloneable.
dyn_clone::clone_trait_object!(<F> AmpLike<F>);
/// Renders a value as a text tree.
pub trait AsTree {
    /// Renders the tree starting from an empty indentation state.
    fn get_tree(&self) -> String {
        let mut root_bits = Vec::new();
        self._get_tree(&mut root_bits)
    }

    /// Builds the indentation prefix: one segment per entry of `bits`,
    /// a vertical bar segment for `true` and a blank segment for `false`.
    fn _get_indent(&self, bits: Vec<bool>) -> String {
        bits.into_iter()
            .map(|continues| if continues { " ┃ " } else { " " })
            .collect()
    }

    /// Connector used in front of a child that has later siblings.
    fn _get_intermediate(&self) -> String {
        " ┣━".to_string()
    }

    /// Connector used in front of the last child.
    fn _get_end(&self) -> String {
        " ┗━".to_string()
    }

    /// Renders this node given the indentation state `bits`.
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String;
}
/// A named [`Node`] plus the bookkeeping a [`Model`] needs to evaluate it.
#[derive(Clone)]
pub struct Amplitude<F: Field> {
    /// Name of the amplitude.
    pub name: String,
    /// The wrapped node that does the actual calculation.
    pub node: Box<dyn Node<F>>,
    /// Whether the amplitude is currently active.
    pub active: bool,
    /// Parameter names captured from the node at construction.
    pub parameters: Vec<String>,
    /// Slot of this amplitude in the per-event cache.
    pub cache_position: usize,
    /// Offset of this amplitude's first parameter in the full parameter slice.
    pub parameter_index_start: usize,
}

/// Debug output is just the amplitude's name.
impl<F: Field> Debug for Amplitude<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)
    }
}

/// Multi-line summary of the amplitude's bookkeeping fields.
impl<F: Field> Display for Amplitude<F> {
    #[rustfmt::skip]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { name, active, cache_position, parameter_index_start, .. } = self;
        writeln!(f, "Amplitude")?;
        writeln!(f, " Name: {}", name)?;
        writeln!(f, " Active: {}", active)?;
        writeln!(f, " Cache Position: {}", cache_position)?;
        writeln!(f, " Index of First Parameter: {}", parameter_index_start)
    }
}
/// Leaf rendering: `name(par, par, ...)` on one line. Inactive amplitudes
/// are wrapped in `/* ... */`, and at most seven parameter names are listed.
impl<F: Field> AsTree for Amplitude<F> {
    fn _get_tree(&self, _bits: &mut Vec<bool>) -> String {
        let label = if self.active {
            self.name.clone()
        } else {
            format!("/* {} */", self.name)
        };
        let par_names = self.parameters();
        if par_names.len() > 7 {
            // Too many to list: show the first seven followed by an ellipsis.
            format!(" {}({},...)\n", label, par_names[..7].join(", "))
        } else {
            format!(" {}({})\n", label, par_names.join(", "))
        }
    }
}
impl<F: Field> Amplitude<F> {
    /// Wraps `node` in an active amplitude called `name`, with cache
    /// position and parameter offset both zero.
    pub fn new(name: &str, node: impl Node<F> + 'static) -> Self {
        info!("Created new amplitude named {name}");
        // Capture the names before the node is moved into the box.
        let parameters = node.parameters();
        let node: Box<dyn Node<F>> = Box::new(node);
        Self {
            name: name.to_owned(),
            node,
            active: true,
            parameters,
            cache_position: 0,
            parameter_index_start: 0,
        }
    }

    /// Stores the cache position and parameter offset, then runs the
    /// node's precalculation on `dataset`.
    ///
    /// # Errors
    /// Propagates any error returned by the precalculation.
    pub fn register(
        &mut self,
        cache_position: usize,
        parameter_index_start: usize,
        dataset: &Dataset<F>,
    ) -> Result<(), RustitudeError> {
        self.parameter_index_start = parameter_index_start;
        self.cache_position = cache_position;
        self.precalculate(dataset)
    }
}
/// Delegates to the wrapped node, adding debug logging and slicing the
/// full parameter list down to this amplitude's own parameters.
impl<F: Field> Node<F> for Amplitude<F> {
    /// Runs the wrapped node's precalculation and logs on success.
    ///
    /// # Errors
    /// Propagates any error returned by the wrapped node.
    fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
        self.node.precalculate(dataset)?;
        debug!("Precalculated amplitude {}", self.name);
        Ok(())
    }

    /// Evaluates the wrapped node on this amplitude's slice of `parameters`.
    ///
    /// # Panics
    /// Panics if `parameters` is shorter than
    /// `parameter_index_start + self.parameters.len()`.
    ///
    /// # Errors
    /// Propagates any error returned by the wrapped node.
    fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
        let start = self.parameter_index_start;
        // Slice once and reuse it for both the call and the log line.
        let own_parameters = &parameters[start..start + self.parameters.len()];
        let res = self.node.calculate(own_parameters, event);
        debug!(
            "{}({:?}, event #{}) = {}",
            self.name,
            own_parameters,
            event.index,
            res.as_ref()
                .map_or_else(|e| e.to_string(), |c| c.to_string())
        );
        res
    }

    /// Parameter names as reported by the wrapped node.
    fn parameters(&self) -> Vec<String> {
        self.node.parameters()
    }
}
/// An amplitude is a leaf of the expression tree.
impl<F: Field> AmpLike<F> for Amplitude<F> {
    /// A leaf walks to a clone of itself.
    fn walk(&self) -> Vec<Self> {
        vec![Clone::clone(self)]
    }

    /// A leaf walks to itself.
    fn walk_mut(&mut self) -> Vec<&mut Self> {
        vec![self]
    }

    /// Reads this amplitude's value from its cache slot.
    ///
    /// # Panics
    /// Panics if `cache_position` is out of bounds for `cache`.
    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
        let cached = cache[self.cache_position];
        debug!(
            "Computing {} from cache: {:?}",
            self.name,
            cached.as_ref().map(|c| c.to_string())
        );
        cached
    }
}
/// Takes the real part of the wrapped term.
#[derive(Clone)]
pub struct Real<F: Field>(Box<dyn AmpLike<F>>);

impl<F: Field> Debug for Real<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(f, "Real [ {:?} ]", inner)
    }
}

/// Displays the rendered tree followed by a newline.
impl<F: Field> Display for Real<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tree = self.get_tree();
        writeln!(f, "{}", tree)
    }
}
impl<F: Field> AmpLike<F> for Real<F> {
fn walk(&self) -> Vec<Amplitude<F>> {
self.0.walk()
}
fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
self.0.walk_mut()
}
fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
let res: Option<Complex<F>> = self.0.compute(cache).map(|r| r.re.into());
debug!(
"Computing {:?} from cache: {:?}",
self,
res.as_ref().map(|c| c.to_string())
);
res
}
}
/// Renders a `[ real ]` header with the wrapped term as its only child.
impl<F: Field> AsTree for Real<F> {
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
        let mut out = String::from("[ real ]\n");
        out.push_str(&self._get_indent(bits.clone()));
        out.push_str(&self._get_end());
        // The only child is also the last one: no bar continues below it.
        let mut child_bits = bits.clone();
        child_bits.push(false);
        out.push_str(&self.0._get_tree(&mut child_bits));
        out
    }
}
/// Takes the imaginary part of the wrapped term.
#[derive(Clone)]
pub struct Imag<F: Field>(Box<dyn AmpLike<F>>);

impl<F: Field> Debug for Imag<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(f, "Imag [ {:?} ]", inner)
    }
}

/// Displays the rendered tree followed by a newline.
impl<F: Field> Display for Imag<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tree = self.get_tree();
        writeln!(f, "{}", tree)
    }
}
impl<F: Field> AmpLike<F> for Imag<F> {
fn walk(&self) -> Vec<Amplitude<F>> {
self.0.walk()
}
fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
self.0.walk_mut()
}
fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
let res: Option<Complex<F>> = self.0.compute(cache).map(|r| r.im.into());
debug!(
"Computing {:?} from cache: {:?}",
self,
res.as_ref().map(|c| c.to_string())
);
res
}
}
/// Renders an `[ imag ]` header with the wrapped term as its only child.
impl<F: Field> AsTree for Imag<F> {
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
        let mut out = String::from("[ imag ]\n");
        out.push_str(&self._get_indent(bits.clone()));
        out.push_str(&self._get_end());
        // The only child is also the last one: no bar continues below it.
        let mut child_bits = bits.clone();
        child_bits.push(false);
        out.push_str(&self.0._get_tree(&mut child_bits));
        out
    }
}
/// Product of the contained terms.
#[derive(Clone)]
pub struct Product<F: Field>(Vec<Box<dyn AmpLike<F>>>);

impl<F: Field> Debug for Product<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Product [ ")?;
        self.0.iter().try_for_each(|op| write!(f, "{:?} ", op))?;
        write!(f, "]")
    }
}

/// Displays the rendered tree followed by a newline.
impl<F: Field> Display for Product<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tree = self.get_tree();
        writeln!(f, "{}", tree)
    }
}
/// Renders a `[ * ]` header followed by one subtree per factor.
impl<F: Field> AsTree for Product<F> {
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
        let n_terms = self.0.len();
        let mut out = String::from("[ * ]\n");
        for (position, term) in self.0.iter().enumerate() {
            out.push_str(&self._get_indent(bits.clone()));
            let is_last = position + 1 == n_terms;
            let connector = if is_last {
                self._get_end()
            } else {
                self._get_intermediate()
            };
            out.push_str(&connector);
            // Children of a non-last term keep a vertical bar in this column.
            let mut child_bits = bits.clone();
            child_bits.push(!is_last);
            out.push_str(&term._get_tree(&mut child_bits));
        }
        out
    }
}
impl<F: Field> AmpLike<F> for Product<F> {
    /// Clones of the factors.
    fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
        let Self(terms) = self;
        Some(terms.clone())
    }

    /// Amplitudes of every factor, in order.
    fn walk(&self) -> Vec<Amplitude<F>> {
        let mut found = Vec::new();
        for term in &self.0 {
            found.extend(term.walk());
        }
        found
    }

    /// Mutable amplitudes of every factor, in order.
    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
        let mut found = Vec::new();
        for term in self.0.iter_mut() {
            found.extend(term.walk_mut());
        }
        found
    }

    /// Multiplies the factors that evaluate to `Some`. If none do, the
    /// result is `Complex::default()`; the return value is always `Some`.
    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
        let factors: Vec<Complex<F>> = self
            .0
            .iter()
            .filter_map(|op| op.compute(cache))
            .collect();
        let value: Complex<F> = if factors.is_empty() {
            Complex::default()
        } else {
            factors.into_iter().product()
        };
        let res: Option<Complex<F>> = Some(value);
        debug!(
            "Computing {:?} from cache: {:?}",
            self,
            res.as_ref().map(|c| c.to_string())
        );
        res
    }
}
/// Sum of the contained terms.
#[derive(Clone)]
pub struct Sum<F: Field>(pub Vec<Box<dyn AmpLike<F>>>);

impl<F: Field> Debug for Sum<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Sum [ ")?;
        self.0.iter().try_for_each(|op| write!(f, "{:?} ", op))?;
        write!(f, "]")
    }
}

/// Displays the rendered tree followed by a newline.
impl<F: Field> Display for Sum<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tree = self.get_tree();
        writeln!(f, "{}", tree)
    }
}
/// Renders a `[ + ]` header followed by one subtree per summand.
impl<F: Field> AsTree for Sum<F> {
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
        let n_terms = self.0.len();
        let mut out = String::from("[ + ]\n");
        for (position, term) in self.0.iter().enumerate() {
            out.push_str(&self._get_indent(bits.clone()));
            let is_last = position + 1 == n_terms;
            let connector = if is_last {
                self._get_end()
            } else {
                self._get_intermediate()
            };
            out.push_str(&connector);
            // Children of a non-last term keep a vertical bar in this column.
            let mut child_bits = bits.clone();
            child_bits.push(!is_last);
            out.push_str(&term._get_tree(&mut child_bits));
        }
        out
    }
}
impl<F: Field> AmpLike<F> for Sum<F> {
    /// Clones of the summands.
    fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
        let Self(terms) = self;
        Some(terms.clone())
    }

    /// Amplitudes of every summand, in order.
    fn walk(&self) -> Vec<Amplitude<F>> {
        let mut found = Vec::new();
        for term in &self.0 {
            found.extend(term.walk());
        }
        found
    }

    /// Mutable amplitudes of every summand, in order.
    fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
        let mut found = Vec::new();
        for term in self.0.iter_mut() {
            found.extend(term.walk_mut());
        }
        found
    }

    /// Adds up the summands that evaluate to `Some`; always returns `Some`.
    fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
        let total = self
            .0
            .iter()
            .filter_map(|al| al.compute(cache))
            .sum::<Complex<F>>();
        let res = Some(total);
        debug!(
            "Computing {:?} from cache: {:?}",
            self,
            res.as_ref().map(|c| c.to_string())
        );
        res
    }
}
/// Squared norm of the wrapped term.
#[derive(Clone)]
pub struct NormSqr<F: Field>(pub Box<dyn AmpLike<F>>);

impl<F: Field> Debug for NormSqr<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(f, "NormSqr[ {:?} ]", inner)
    }
}

/// Displays the rendered tree followed by a newline.
impl<F: Field> Display for NormSqr<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tree = self.get_tree();
        writeln!(f, "{}", tree)
    }
}
/// Renders a `[ |_|^2 ]` header with the wrapped term as its only child.
impl<F: Field> AsTree for NormSqr<F> {
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
        let mut out = String::from("[ |_|^2 ]\n");
        out.push_str(&self._get_indent(bits.clone()));
        out.push_str(&self._get_end());
        // The only child is also the last one: no bar continues below it.
        let mut child_bits = bits.clone();
        child_bits.push(false);
        out.push_str(&self.0._get_tree(&mut child_bits));
        out
    }
}
impl<F: Field> NormSqr<F> {
    /// Evaluates the wrapped term and returns its squared norm, or `None`
    /// when the wrapped term yields `None`.
    pub fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<F> {
        match self.0.compute(cache) {
            Some(value) => Some(value.norm_sqr()),
            None => None,
        }
    }

    /// Forwards to the wrapped term.
    pub fn walk(&self) -> Vec<Amplitude<F>> {
        let Self(inner) = self;
        inner.walk()
    }

    /// Forwards to the wrapped term.
    pub fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
        let Self(inner) = self;
        inner.walk_mut()
    }
}
/// A set of coherent sums together with the amplitudes and parameters
/// they use. Note that `Clone` shares the `amplitudes` `Arc`.
#[derive(Clone)]
pub struct Model<F: Field> {
    /// One squared-norm term per coherent sum.
    pub cohsums: Vec<NormSqr<F>>,
    /// Amplitudes with unique names, behind a shared lock.
    pub amplitudes: Arc<RwLock<Vec<Amplitude<F>>>>,
    /// One entry per amplitude parameter.
    pub parameters: Vec<Parameter<F>>,
    /// Whether any amplitude wraps a Python node.
    pub contains_python_amplitudes: bool,
}

impl<F: Field> Debug for Model<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Model [ ")?;
        self.cohsums
            .iter()
            .try_for_each(|op| write!(f, "{:?} ", op))?;
        write!(f, "]")
    }
}

/// Displays the rendered tree followed by a newline.
impl<F: Field> Display for Model<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tree = self.get_tree();
        writeln!(f, "{}", tree)
    }
}
/// Renders a `[ + ]` header followed by one subtree per coherent sum.
impl<F: Field> AsTree for Model<F> {
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
        let n_terms = self.cohsums.len();
        let mut out = String::from("[ + ]\n");
        for (position, term) in self.cohsums.iter().enumerate() {
            out.push_str(&self._get_indent(bits.clone()));
            let is_last = position + 1 == n_terms;
            let connector = if is_last {
                self._get_end()
            } else {
                self._get_intermediate()
            };
            out.push_str(&connector);
            // Children of a non-last term keep a vertical bar in this column.
            let mut child_bits = bits.clone();
            child_bits.push(!is_last);
            out.push_str(&term._get_tree(&mut child_bits));
        }
        out
    }
}
impl<F: Field> Model<F> {
pub fn new(amps: &[Box<dyn AmpLike<F>>]) -> Self {
let mut amp_names = HashSet::new();
let amplitudes: Vec<Amplitude<F>> = amps
.iter()
.flat_map(|cohsum| cohsum.walk())
.filter_map(|amp| {
if amp_names.insert(amp.name.clone()) {
Some(amp)
} else {
None
}
})
.collect();
let parameter_tags: Vec<(String, String)> = amplitudes
.iter()
.flat_map(|amp| {
amp.parameters()
.iter()
.map(|p| (amp.name.clone(), p.clone()))
.collect::<Vec<_>>()
})
.collect();
let parameters = parameter_tags
.iter()
.enumerate()
.map(|(i, (amp_name, par_name))| Parameter::new(amp_name, par_name, i))
.collect();
let contains_python_amplitudes = amplitudes.iter().any(|amp| amp.node.is_python_node());
Self {
cohsums: amps.iter().map(|inner| NormSqr(inner.clone())).collect(),
amplitudes: Arc::new(RwLock::new(amplitudes)),
parameters,
contains_python_amplitudes,
}
}
pub fn deep_clone(&self) -> Self {
Self {
cohsums: self.cohsums.clone(),
amplitudes: Arc::new(RwLock::new(self.amplitudes.read().clone())),
parameters: self.parameters.clone(),
contains_python_amplitudes: self.contains_python_amplitudes,
}
}
/// Evaluates the model for one event.
///
/// Every active amplitude in `amplitudes` is calculated into a cache
/// (inactive ones contribute `None`); the result is the sum of the
/// coherent sums that evaluate to `Some`.
///
/// # Errors
/// Returns the first error produced by an amplitude calculation.
pub fn compute(
    &self,
    amplitudes: &[Amplitude<F>],
    parameters: &[F],
    event: &Event<F>,
) -> Result<F, RustitudeError> {
    let mut cache: Vec<Option<Complex<F>>> = Vec::with_capacity(amplitudes.len());
    for amp in amplitudes {
        let entry = if amp.active {
            Some(amp.calculate(parameters, event)?)
        } else {
            None
        };
        cache.push(entry);
    }
    let total = self
        .cohsums
        .iter()
        .filter_map(|cohsum| cohsum.compute(&cache))
        .sum::<F>();
    Ok(total)
}
/// Registers every amplitude against `dataset`.
///
/// Each amplitude gets consecutive cache positions and parameter
/// offsets, and the same values are copied onto the amplitudes of the
/// same name inside the coherent sums.
///
/// # Errors
/// Stops at, and returns, the first registration error.
pub fn load(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
    let mut parameter_index = 0;
    let mut registered = self.amplitudes.write();
    for (cache_pos, amp) in registered.iter_mut().enumerate() {
        amp.register(cache_pos, parameter_index, dataset)?;
        for cohsum in self.cohsums.iter_mut() {
            for r_amp in cohsum.walk_mut() {
                if r_amp.name == amp.name {
                    r_amp.cache_position = cache_pos;
                    r_amp.parameter_index_start = parameter_index;
                }
            }
        }
        parameter_index += amp.parameters().len();
    }
    Ok(())
}
/// Returns a clone of the amplitude called `amplitude_name`.
///
/// # Errors
/// `AmplitudeNotFoundError` when no amplitude has that name.
pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
    let amplitudes = self.amplitudes.read();
    for candidate in amplitudes.iter() {
        if candidate.name == amplitude_name {
            return Ok(candidate.clone());
        }
    }
    Err(RustitudeError::AmplitudeNotFoundError(
        amplitude_name.to_string(),
    ))
}

/// Returns a clone of parameter `parameter_name` of amplitude `amplitude_name`.
///
/// # Errors
/// `AmplitudeNotFoundError` when the amplitude does not exist, otherwise
/// `ParameterNotFoundError` when it has no such parameter.
pub fn get_parameter(
    &self,
    amplitude_name: &str,
    parameter_name: &str,
) -> Result<Parameter<F>, RustitudeError> {
    // Checked first so a missing amplitude is reported as such.
    self.get_amplitude(amplitude_name)?;
    for candidate in &self.parameters {
        if candidate.amplitude == amplitude_name && candidate.name == parameter_name {
            return Ok(candidate.clone());
        }
    }
    Err(RustitudeError::ParameterNotFoundError(
        parameter_name.to_string(),
    ))
}
pub fn print_parameters(&self) {
let any_fixed = if self.any_fixed() { 1 } else { 0 };
if self.any_fixed() {
println!(
"Fixed: {}",
self.group_by_index()[0]
.iter()
.map(|p| format!("{:?}", p))
.join(", ")
);
}
for (i, group) in self.group_by_index().iter().skip(any_fixed).enumerate() {
println!(
"{}: {}",
i,
group.iter().map(|p| format!("{:?}", p)).join(", ")
);
}
}
/// Clones of all parameters for which `is_free()` holds.
pub fn free_parameters(&self) -> Vec<Parameter<F>> {
    let mut free = Vec::new();
    for par in &self.parameters {
        if par.is_free() {
            free.push(par.clone());
        }
    }
    free
}

/// Clones of all parameters for which `is_fixed()` holds.
pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
    let mut fixed = Vec::new();
    for par in &self.parameters {
        if par.is_fixed() {
            fixed.push(par.clone());
        }
    }
    fixed
}
/// Ties two parameters together so they share one index and value.
///
/// The parameter with the lower `index` wins (`None`, i.e. fixed, sorts
/// below every free index). Every parameter sharing the other one's index
/// takes over the winner's `index`, `initial` and `fixed_index`, and the
/// free indices are then renumbered.
///
/// Constraining a parameter to itself, or to one it is already tied to,
/// is a no-op.
///
/// # Errors
/// Returns the lookup error when either amplitude or parameter is unknown.
///
/// # Panics
/// Panics when both parameters are fixed but belong to different fixed
/// groups; which value should win is not defined for that case.
pub fn constrain(
    &mut self,
    amplitude_1: &str,
    parameter_1: &str,
    amplitude_2: &str,
    parameter_2: &str,
) -> Result<(), RustitudeError> {
    let p1 = self.get_parameter(amplitude_1, parameter_1)?;
    let p2 = self.get_parameter(amplitude_2, parameter_2)?;
    // Decide once which parameter is kept and which group is absorbed.
    let (kept, absorbed) = match p1.index.cmp(&p2.index) {
        std::cmp::Ordering::Less => (p1, p2),
        std::cmp::Ordering::Greater => (p2, p1),
        std::cmp::Ordering::Equal => {
            if p1.index.is_some() || p1.fixed_index == p2.fixed_index {
                // Same free index or same fixed group: already tied.
                return Ok(());
            }
            unimplemented!(
                "constraining two parameters fixed to different groups is not supported"
            )
        }
    };
    // `absorbed.index` is always `Some` here, so only that free group matches.
    for par in self
        .parameters
        .iter_mut()
        .filter(|par| par.index == absorbed.index)
    {
        par.index = kept.index;
        par.initial = kept.initial;
        par.fixed_index = kept.fixed_index;
    }
    self.reindex_parameters();
    Ok(())
}
/// Fixes a parameter, and everything constrained to it, at `value`.
///
/// A free parameter's whole index group loses its free index, takes
/// `value` as its `initial`, and receives a fresh `fixed_index`; the free
/// indices are then renumbered. If the parameter is already fixed, only
/// the value of its own fixed group is updated.
///
/// # Errors
/// Returns the lookup error when the amplitude or parameter is unknown.
pub fn fix(
    &mut self,
    amplitude: &str,
    parameter: &str,
    value: F,
) -> Result<(), RustitudeError> {
    let search_par = self.get_parameter(amplitude, parameter)?;
    if search_par.is_fixed() {
        // Matching on `index` here would compare `None == None` and hit
        // every fixed parameter; restrict to this parameter's fixed group.
        for par in self
            .parameters
            .iter_mut()
            .filter(|par| par.is_fixed() && par.fixed_index == search_par.fixed_index)
        {
            par.initial = value;
        }
        return Ok(());
    }
    let fixed_index = self.get_min_fixed_index();
    for par in self
        .parameters
        .iter_mut()
        .filter(|par| par.index == search_par.index)
    {
        par.index = None;
        par.initial = value;
        par.fixed_index = fixed_index;
    }
    self.reindex_parameters();
    Ok(())
}
/// Frees a fixed parameter and everything in its fixed group.
///
/// The group receives a new free index, loses its `fixed_index`, and the
/// free indices are renumbered. Freeing a parameter that is already free
/// is a no-op.
///
/// # Errors
/// Returns the lookup error when the amplitude or parameter is unknown.
pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
    let search_par = self.get_parameter(amplitude, parameter)?;
    if search_par.is_free() {
        // A free parameter has `fixed_index == None`; matching on that
        // would collapse every free parameter onto one index.
        return Ok(());
    }
    let index = self.get_min_free_index();
    for par in self
        .parameters
        .iter_mut()
        .filter(|par| par.fixed_index == search_par.fixed_index)
    {
        par.index = index;
        par.fixed_index = None;
    }
    self.reindex_parameters();
    Ok(())
}
/// Sets the bounds of a parameter and of everything in its group
/// (same free index, or same fixed index when it is fixed).
///
/// # Errors
/// Returns the lookup error when the amplitude or parameter is unknown.
pub fn set_bounds(
    &mut self,
    amplitude: &str,
    parameter: &str,
    bounds: (F, F),
) -> Result<(), RustitudeError> {
    let target = self.get_parameter(amplitude, parameter)?;
    for par in self.parameters.iter_mut() {
        let same_group = match target.index {
            Some(_) => par.index == target.index,
            None => par.fixed_index == target.fixed_index,
        };
        if same_group {
            par.bounds = bounds;
        }
    }
    Ok(())
}

/// Sets the initial value of a parameter and of everything in its group
/// (same free index, or same fixed index when it is fixed).
///
/// # Errors
/// Returns the lookup error when the amplitude or parameter is unknown.
pub fn set_initial(
    &mut self,
    amplitude: &str,
    parameter: &str,
    initial: F,
) -> Result<(), RustitudeError> {
    let target = self.get_parameter(amplitude, parameter)?;
    for par in self.parameters.iter_mut() {
        let same_group = match target.index {
            Some(_) => par.index == target.index,
            None => par.fixed_index == target.fixed_index,
        };
        if same_group {
            par.initial = initial;
        }
    }
    Ok(())
}
/// Bounds of each free index group, taken from the group's first member.
pub fn get_bounds(&self) -> Vec<(F, F)> {
    // The fixed group, when present, sorts first and is skipped.
    let n_skip = usize::from(self.any_fixed());
    let mut bounds = Vec::new();
    for group in self.group_by_index().iter().skip(n_skip) {
        if let Some(par) = group.first() {
            bounds.push(par.bounds);
        }
    }
    bounds
}

/// Initial value of each free index group, taken from the group's first member.
pub fn get_initial(&self) -> Vec<F> {
    // The fixed group, when present, sorts first and is skipped.
    let n_skip = usize::from(self.any_fixed());
    let mut initials = Vec::new();
    for group in self.group_by_index().iter().skip(n_skip) {
        if let Some(par) = group.first() {
            initials.push(par.initial);
        }
    }
    initials
}

/// Number of free indices: one past the largest free index, or zero.
pub fn get_n_free(&self) -> usize {
    match self.get_min_free_index() {
        Some(next_free) => next_free,
        None => 0,
    }
}
/// Marks every amplitude called `amplitude` as active, both in the
/// amplitude list and inside the coherent sums.
///
/// # Errors
/// `AmplitudeNotFoundError` when no amplitude has that name.
pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
    // The read guard is released at the end of this statement, before
    // the write lock below is taken.
    let known = self.amplitudes.read().iter().any(|a| a.name == amplitude);
    if !known {
        return Err(RustitudeError::AmplitudeNotFoundError(
            amplitude.to_string(),
        ));
    }
    for amp in self.amplitudes.write().iter_mut() {
        if amp.name == amplitude {
            amp.active = true;
        }
    }
    for cohsum in self.cohsums.iter_mut() {
        for amp in cohsum.walk_mut() {
            if amp.name == amplitude {
                amp.active = true;
            }
        }
    }
    Ok(())
}

/// Marks every amplitude as active.
pub fn activate_all(&mut self) {
    for amp in self.amplitudes.write().iter_mut() {
        amp.active = true;
    }
    for cohsum in self.cohsums.iter_mut() {
        for amp in cohsum.walk_mut() {
            amp.active = true;
        }
    }
}
/// Deactivates everything, then activates the named amplitudes in order.
///
/// # Errors
/// Stops at the first unknown name and returns its error; amplitudes
/// activated before that point stay active.
pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
    self.deactivate_all();
    amplitudes
        .into_iter()
        .try_for_each(|name| self.activate(name))
}
/// Marks every amplitude called `amplitude` as inactive, both in the
/// amplitude list and inside the coherent sums.
///
/// # Errors
/// `AmplitudeNotFoundError` when no amplitude has that name.
pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
    // The read guard is released at the end of this statement, before
    // the write lock below is taken.
    let known = self.amplitudes.read().iter().any(|a| a.name == amplitude);
    if !known {
        return Err(RustitudeError::AmplitudeNotFoundError(
            amplitude.to_string(),
        ));
    }
    for amp in self.amplitudes.write().iter_mut() {
        if amp.name == amplitude {
            amp.active = false;
        }
    }
    for cohsum in self.cohsums.iter_mut() {
        for amp in cohsum.walk_mut() {
            if amp.name == amplitude {
                amp.active = false;
            }
        }
    }
    Ok(())
}

/// Marks every amplitude as inactive.
pub fn deactivate_all(&mut self) {
    for amp in self.amplitudes.write().iter_mut() {
        amp.active = false;
    }
    for cohsum in self.cohsums.iter_mut() {
        for amp in cohsum.walk_mut() {
            amp.active = false;
        }
    }
}
/// Groups the parameters by `index`, in ascending index order. Fixed
/// parameters (`None`) sort first and share one group.
fn group_by_index(&self) -> Vec<Vec<&Parameter<F>>> {
    let mut ordered: Vec<&Parameter<F>> = self.parameters.iter().collect();
    // Stable sort: members of a group keep their original relative order.
    ordered.sort_by_key(|par| par.index);
    let mut groups: Vec<Vec<&Parameter<F>>> = Vec::new();
    for par in ordered {
        let joins_last = groups
            .last()
            .map_or(false, |group| group[0].index == par.index);
        if joins_last {
            if let Some(group) = groups.last_mut() {
                group.push(par);
            }
        } else {
            groups.push(vec![par]);
        }
    }
    groups
}

/// Mutable counterpart of `group_by_index`, with the same grouping and order.
fn group_by_index_mut(&mut self) -> Vec<Vec<&mut Parameter<F>>> {
    let mut ordered: Vec<&mut Parameter<F>> = self.parameters.iter_mut().collect();
    // Stable sort: members of a group keep their original relative order.
    ordered.sort_by_key(|par| par.index);
    let mut groups: Vec<Vec<&mut Parameter<F>>> = Vec::new();
    for par in ordered {
        let joins_last = groups
            .last()
            .map_or(false, |group| group[0].index == par.index);
        if joins_last {
            if let Some(group) = groups.last_mut() {
                group.push(par);
            }
        } else {
            groups.push(vec![par]);
        }
    }
    groups
}
/// Whether at least one parameter has no free index.
fn any_fixed(&self) -> bool {
    self.parameters.iter().any(|p| matches!(p.index, None))
}

/// Renumbers the free index groups as `0, 1, 2, ...` in their current
/// order, leaving the fixed group untouched.
fn reindex_parameters(&mut self) {
    // The fixed group, when present, sorts first and is skipped.
    let n_skip = usize::from(self.any_fixed());
    let mut groups = self.group_by_index_mut();
    for (new_index, group) in groups.iter_mut().skip(n_skip).enumerate() {
        for par in group.iter_mut() {
            par.index = Some(new_index);
        }
    }
}
/// Smallest unused free index: one past the largest `index` in use, or
/// zero when no parameter is free. Always `Some`.
fn get_min_free_index(&self) -> Option<usize> {
    let largest = self.parameters.iter().filter_map(|p| p.index).max();
    Some(match largest {
        Some(max) => max + 1,
        None => 0,
    })
}

/// Smallest unused fixed index: one past the largest `fixed_index` in
/// use, or zero when no parameter is fixed. Always `Some`.
fn get_min_fixed_index(&self) -> Option<usize> {
    let largest = self.parameters.iter().filter_map(|p| p.fixed_index).max();
    Some(match largest {
        Some(max) => max + 1,
        None => 0,
    })
}
}
/// Node with a single real parameter, `value`.
#[derive(Clone)]
pub struct Scalar;

impl<F: Field> Node<F> for Scalar {
    fn parameters(&self) -> Vec<String> {
        vec![String::from("value")]
    }

    /// Returns `parameters[0] + 0i`, independent of the event.
    fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
        let value = parameters[0];
        Ok(Complex::new(value, F::zero()))
    }
}

/// Creates an [`Amplitude`] called `name` wrapping a [`Scalar`].
pub fn scalar<F: Field>(name: &str) -> Amplitude<F> {
    Amplitude::new(name, Scalar)
}
/// Node with two parameters, `real` and `imag`.
#[derive(Clone)]
pub struct ComplexScalar;

impl<F: Field> Node<F> for ComplexScalar {
    /// Returns `parameters[0] + parameters[1] i`, independent of the event.
    fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
        let (re, im) = (parameters[0], parameters[1]);
        Ok(Complex::new(re, im))
    }

    fn parameters(&self) -> Vec<String> {
        vec![String::from("real"), String::from("imag")]
    }
}

/// Creates an [`Amplitude`] called `name` wrapping a [`ComplexScalar`].
pub fn cscalar<F: Field>(name: &str) -> Amplitude<F> {
    Amplitude::new(name, ComplexScalar)
}
/// Node with two parameters, `mag` and `phi`.
#[derive(Clone)]
pub struct PolarComplexScalar;

impl<F: Field> Node<F> for PolarComplexScalar {
    /// Returns `cis(parameters[1])` scaled by `parameters[0]`,
    /// independent of the event.
    fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
        let (mag, phi) = (parameters[0], parameters[1]);
        let unit: Complex<F> = Complex::cis(phi);
        Ok(unit * mag)
    }

    fn parameters(&self) -> Vec<String> {
        vec![String::from("mag"), String::from("phi")]
    }
}

/// Creates an [`Amplitude`] called `name` wrapping a [`PolarComplexScalar`].
pub fn pcscalar<F: Field>(name: &str) -> Amplitude<F> {
    Amplitude::new(name, PolarComplexScalar)
}
/// Piecewise-constant complex node: one complex value per bin of an
/// event variable.
#[derive(Clone)]
pub struct Piecewise<V, F>
where
    V: Fn(&Event<F>) -> F + Send + Sync + Copy,
    F: Field,
{
    /// `(lower, upper)` edges of each bin.
    edges: Vec<(F, F)>,
    /// Maps an event to the binned variable.
    variable: V,
    /// Value of `variable` for every event, filled in by `precalculate`.
    calculated_variable: Vec<F>,
}

impl<V, F> Piecewise<V, F>
where
    V: Fn(&Event<F>) -> F + Send + Sync + Copy,
    F: Field,
{
    /// Splits `range` into `bins` equal-width bins over `variable`.
    pub fn new(bins: usize, range: (F, F), variable: V) -> Self {
        let (start, stop) = range;
        let width = (stop - start) / convert!(bins, F);
        // Position of the i-th edge: start + i * width.
        let edge_at = |i: usize| F::mul_add(convert!(i, F), width, start);
        let edges = (0..bins).map(|i| (edge_at(i), edge_at(i + 1))).collect();
        Self {
            edges,
            variable,
            calculated_variable: Vec::new(),
        }
    }
}
impl<V, F> Node<F> for Piecewise<V, F>
where
    V: Fn(&Event<F>) -> F + Send + Sync + Copy,
    F: Field,
{
    /// Evaluates the variable for every event of `dataset` in parallel
    /// and stores the results.
    fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
        let variable = self.variable;
        self.calculated_variable = dataset.events.par_iter().map(variable).collect();
        Ok(())
    }

    /// Returns the complex value of the first bin whose edges (both
    /// inclusive) contain the event's stored variable, or
    /// `Complex::default()` when no bin matches.
    ///
    /// # Panics
    /// Panics if `event.index` is outside the stored variables or if
    /// `parameters` is too short for the matched bin.
    fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
        let val = self.calculated_variable[event.index];
        let matching_bin = self
            .edges
            .iter()
            .position(|&(lower, upper)| val >= lower && val <= upper);
        match matching_bin {
            Some(i_bin) => {
                // Bin i owns parameters 2i (real part) and 2i + 1 (imaginary part).
                let re = parameters[i_bin * 2];
                let im = parameters[(i_bin * 2) + 1];
                Ok(Complex::new(re, im))
            }
            None => Ok(Complex::default()),
        }
    }

    /// Two names per bin: `bin i re` and `bin i im`.
    fn parameters(&self) -> Vec<String> {
        let mut names = Vec::new();
        for i in 0..self.edges.len() {
            names.push(format!("bin {} re", i));
            names.push(format!("bin {} im", i));
        }
        names
    }
}
/// Creates a [`Piecewise`] amplitude binned in `.m()` of the sum of the
/// first two daughter four-momenta.
pub fn piecewise_m<F: Field + 'static>(name: &str, bins: usize, range: (F, F)) -> Amplitude<F> {
    let pair_mass = |e: &Event<F>| (e.daughter_p4s[0] + e.daughter_p4s[1]).m();
    let node = Piecewise::new(bins, range, pair_mass);
    Amplitude::new(name, node)
}
// Implements `Add` producing a two-term `Sum`.
//
// Three-argument form: `$a + $b` and `$b + $a` for two different types.
// Two-argument form: `$a + $a`.
// Each form covers all four owned/borrowed operand combinations; the
// borrowed variants clone and defer to the owned implementation.
macro_rules! impl_sum {
    ($t:ident, $a:ty, $b:ty) => {
        impl<$t: Field + 'static> Add<$b> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: $b) -> Sum<$t> {
                let terms: Vec<Box<dyn AmpLike<$t>>> = vec![Box::new(self), Box::new(rhs)];
                Sum(terms)
            }
        }
        impl<$t: Field + 'static> Add<&$b> for &$a {
            type Output = Sum<$t>;
            fn add(self, rhs: &$b) -> Sum<$t> {
                let lhs: $a = self.clone();
                let rhs_owned: $b = rhs.clone();
                <$a as Add<$b>>::add(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<&$b> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: &$b) -> Sum<$t> {
                let rhs_owned: $b = rhs.clone();
                <$a as Add<$b>>::add(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<$b> for &$a {
            type Output = Sum<$t>;
            fn add(self, rhs: $b) -> Sum<$t> {
                let lhs: $a = self.clone();
                <$a as Add<$b>>::add(lhs, rhs)
            }
        }
        impl<$t: Field + 'static> Add<$a> for $b {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Sum<$t> {
                let terms: Vec<Box<dyn AmpLike<$t>>> = vec![Box::new(self), Box::new(rhs)];
                Sum(terms)
            }
        }
        impl<$t: Field + 'static> Add<&$a> for &$b {
            type Output = Sum<$t>;
            fn add(self, rhs: &$a) -> Sum<$t> {
                let lhs: $b = self.clone();
                let rhs_owned: $a = rhs.clone();
                <$b as Add<$a>>::add(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<&$a> for $b {
            type Output = Sum<$t>;
            fn add(self, rhs: &$a) -> Sum<$t> {
                let rhs_owned: $a = rhs.clone();
                <$b as Add<$a>>::add(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<$a> for &$b {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Sum<$t> {
                let lhs: $b = self.clone();
                <$b as Add<$a>>::add(lhs, rhs)
            }
        }
    };
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Add<$a> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Sum<$t> {
                let terms: Vec<Box<dyn AmpLike<$t>>> = vec![Box::new(self), Box::new(rhs)];
                Sum(terms)
            }
        }
        impl<$t: Field + 'static> Add<&$a> for &$a {
            type Output = Sum<$t>;
            fn add(self, rhs: &$a) -> Sum<$t> {
                let lhs: $a = self.clone();
                let rhs_owned: $a = rhs.clone();
                <$a as Add<$a>>::add(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<&$a> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: &$a) -> Sum<$t> {
                let rhs_owned: $a = rhs.clone();
                <$a as Add<$a>>::add(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<$a> for &$a {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Sum<$t> {
                let lhs: $a = self.clone();
                <$a as Add<$a>>::add(lhs, rhs)
            }
        }
    };
}
// Implements `Add` between `$a` and an existing `Sum`, extending the
// sum's term list instead of nesting: `$a + Sum` prepends, `Sum + $a`
// appends. The borrowed variants clone and defer to the owned ones.
macro_rules! impl_appending_sum {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Add<Sum<$t>> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: Sum<$t>) -> Sum<$t> {
                let Sum(mut terms) = rhs;
                terms.insert(0, Box::new(self));
                Sum(terms)
            }
        }
        impl<$t: Field + 'static> Add<$a> for Sum<$t> {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Sum<$t> {
                let Sum(mut terms) = self;
                terms.push(Box::new(rhs));
                Sum(terms)
            }
        }
        impl<$t: Field + 'static> Add<&Sum<$t>> for &$a {
            type Output = Sum<$t>;
            fn add(self, rhs: &Sum<$t>) -> Sum<$t> {
                let lhs: $a = self.clone();
                let rhs_owned: Sum<$t> = rhs.clone();
                <$a as Add<Sum<$t>>>::add(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<&Sum<$t>> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: &Sum<$t>) -> Sum<$t> {
                let rhs_owned: Sum<$t> = rhs.clone();
                <$a as Add<Sum<$t>>>::add(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<Sum<$t>> for &$a {
            type Output = Sum<$t>;
            fn add(self, rhs: Sum<$t>) -> Sum<$t> {
                let lhs: $a = self.clone();
                <$a as Add<Sum<$t>>>::add(lhs, rhs)
            }
        }
        impl<$t: Field + 'static> Add<&$a> for &Sum<$t> {
            type Output = Sum<$t>;
            fn add(self, rhs: &$a) -> Sum<$t> {
                let lhs: Sum<$t> = self.clone();
                let rhs_owned: $a = rhs.clone();
                <Sum<$t> as Add<$a>>::add(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<&$a> for Sum<$t> {
            type Output = Sum<$t>;
            fn add(self, rhs: &$a) -> Sum<$t> {
                let rhs_owned: $a = rhs.clone();
                <Sum<$t> as Add<$a>>::add(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Add<$a> for &Sum<$t> {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Sum<$t> {
                let lhs: Sum<$t> = self.clone();
                <Sum<$t> as Add<$a>>::add(lhs, rhs)
            }
        }
    };
}
// Implements `Mul` producing a `Product`.
//
// Three-argument form: `$a * $b` and `$b * $a` for two different types.
// Two-argument form: `$a * $a`.
// An operand whose `get_cloned_terms()` is `Some` contributes those terms
// directly (flattening); otherwise the operand itself becomes one term.
// The borrowed variants clone and defer to the owned implementation.
macro_rules! impl_prod {
    ($t:ident, $a:ty, $b:ty) => {
        impl<$t: Field + 'static> Mul<$b> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: $b) -> Product<$t> {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(mut lhs_terms), Some(rhs_terms)) => {
                        lhs_terms.extend(rhs_terms);
                        Product(lhs_terms)
                    }
                    (None, Some(mut rhs_terms)) => {
                        rhs_terms.insert(0, Box::new(self));
                        Product(rhs_terms)
                    }
                    (Some(mut lhs_terms), None) => {
                        lhs_terms.push(Box::new(rhs));
                        Product(lhs_terms)
                    }
                    (None, None) => Product(vec![Box::new(self), Box::new(rhs)]),
                }
            }
        }
        impl<$t: Field + 'static> Mul<&$b> for &$a {
            type Output = Product<$t>;
            fn mul(self, rhs: &$b) -> Product<$t> {
                let lhs: $a = self.clone();
                let rhs_owned: $b = rhs.clone();
                <$a as Mul<$b>>::mul(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Mul<&$b> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: &$b) -> Product<$t> {
                let rhs_owned: $b = rhs.clone();
                <$a as Mul<$b>>::mul(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Mul<$b> for &$a {
            type Output = Product<$t>;
            fn mul(self, rhs: $b) -> Product<$t> {
                let lhs: $a = self.clone();
                <$a as Mul<$b>>::mul(lhs, rhs)
            }
        }
        impl<$t: Field + 'static> Mul<$a> for $b {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Product<$t> {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(mut lhs_terms), Some(rhs_terms)) => {
                        lhs_terms.extend(rhs_terms);
                        Product(lhs_terms)
                    }
                    (None, Some(mut rhs_terms)) => {
                        rhs_terms.insert(0, Box::new(self));
                        Product(rhs_terms)
                    }
                    (Some(mut lhs_terms), None) => {
                        lhs_terms.push(Box::new(rhs));
                        Product(lhs_terms)
                    }
                    (None, None) => Product(vec![Box::new(self), Box::new(rhs)]),
                }
            }
        }
        impl<$t: Field + 'static> Mul<&$a> for &$b {
            type Output = Product<$t>;
            fn mul(self, rhs: &$a) -> Product<$t> {
                let lhs: $b = self.clone();
                let rhs_owned: $a = rhs.clone();
                <$b as Mul<$a>>::mul(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Mul<&$a> for $b {
            type Output = Product<$t>;
            fn mul(self, rhs: &$a) -> Product<$t> {
                let rhs_owned: $a = rhs.clone();
                <$b as Mul<$a>>::mul(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Mul<$a> for &$b {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Product<$t> {
                let lhs: $b = self.clone();
                <$b as Mul<$a>>::mul(lhs, rhs)
            }
        }
    };
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<$a> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Product<$t> {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(mut lhs_terms), Some(rhs_terms)) => {
                        lhs_terms.extend(rhs_terms);
                        Product(lhs_terms)
                    }
                    (None, Some(mut rhs_terms)) => {
                        rhs_terms.insert(0, Box::new(self));
                        Product(rhs_terms)
                    }
                    (Some(mut lhs_terms), None) => {
                        lhs_terms.push(Box::new(rhs));
                        Product(lhs_terms)
                    }
                    (None, None) => Product(vec![Box::new(self), Box::new(rhs)]),
                }
            }
        }
        impl<$t: Field + 'static> Mul<&$a> for &$a {
            type Output = Product<$t>;
            fn mul(self, rhs: &$a) -> Product<$t> {
                let lhs: $a = self.clone();
                let rhs_owned: $a = rhs.clone();
                <$a as Mul<$a>>::mul(lhs, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Mul<&$a> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: &$a) -> Product<$t> {
                let rhs_owned: $a = rhs.clone();
                <$a as Mul<$a>>::mul(self, rhs_owned)
            }
        }
        impl<$t: Field + 'static> Mul<$a> for &$a {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Product<$t> {
                let lhs: $a = self.clone();
                <$a as Mul<$a>>::mul(lhs, rhs)
            }
        }
    };
}
// Implements `Mul` between `$a` and a boxed `dyn AmpLike`, in both orders,
// producing a `Product`.
//
// An operand whose `get_cloned_terms()` is `Some` contributes those terms
// directly; otherwise the operand itself becomes one term. Operand order
// is preserved: left-hand terms always come before right-hand terms.
//
// NOTE(review): `get_cloned_terms()` is `Some` for both `Product` and
// `Sum`, so a boxed `Sum` operand is flattened here as if its terms were
// factors. Telling them apart needs information the trait does not
// expose; confirm whether callers can hit this.
macro_rules! impl_box_prod {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<Box<dyn AmpLike<$t>>> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: Box<dyn AmpLike<$t>>) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(terms_a), Some(terms_b)) => Product([terms_a, terms_b].concat()),
                    (None, Some(mut terms)) => {
                        // `terms` belong to `rhs`: `self` goes in front.
                        terms.insert(0, Box::new(self));
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` belong to `self`: append the right operand.
                        terms.push(rhs);
                        Product(terms)
                    }
                    (None, None) => Product(vec![Box::new(self), rhs]),
                }
            }
        }
        impl<$t: Field + 'static> Mul<$a> for Box<dyn AmpLike<$t>> {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(terms_a), Some(terms_b)) => Product([terms_a, terms_b].concat()),
                    (None, Some(mut terms)) => {
                        // `terms` belong to `rhs`: `self` goes in front.
                        terms.insert(0, self);
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` belong to `self`: append the right operand.
                        terms.push(Box::new(rhs));
                        Product(terms)
                    }
                    (None, None) => Product(vec![self, Box::new(rhs)]),
                }
            }
        }
    };
}
// Generates `Add` between a concrete amplitude-like type `$a` and a boxed
// `dyn AmpLike<$t>` trait object, in both operand orders. The result is always
// a `Sum<$t>`.
//
// Flattening rule: an operand whose `get_cloned_terms()` returns `Some(terms)`
// contributes those terms directly; an operand for which it returns `None` is
// boxed (if it is not already) and contributes itself as a single term. The
// left-to-right order of the operands is preserved in the resulting term list.
macro_rules! impl_box_sum {
    ($t:ident, $a:ty) => {
        // `$a + Box<dyn AmpLike>`
        impl<$t: Field + 'static> Add<Box<dyn AmpLike<$t>>> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: Box<dyn AmpLike<$t>>) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(mut lhs_terms), Some(rhs_terms)) => {
                        // Both term lists are already fresh clones, so move
                        // them together instead of cloning a second time.
                        lhs_terms.extend(rhs_terms);
                        Sum(lhs_terms)
                    }
                    (None, Some(mut terms)) => {
                        // `self` is a single term that goes in front of the
                        // right-hand operand's terms.
                        terms.insert(0, Box::new(self));
                        Sum(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` already holds everything contributed by
                        // `self`; the right-hand operand is the term to append.
                        terms.push(rhs);
                        Sum(terms)
                    }
                    (None, None) => Sum(vec![Box::new(self), rhs]),
                }
            }
        }
        // `Box<dyn AmpLike> + $a`
        impl<$t: Field + 'static> Add<$a> for Box<dyn AmpLike<$t>> {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(mut lhs_terms), Some(rhs_terms)) => {
                        lhs_terms.extend(rhs_terms);
                        Sum(lhs_terms)
                    }
                    (None, Some(mut terms)) => {
                        // The boxed left operand is a single term placed first.
                        terms.insert(0, self);
                        Sum(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` came from the boxed left operand; append the
                        // right-hand operand as the final term.
                        terms.push(Box::new(rhs));
                        Sum(terms)
                    }
                    (None, None) => Sum(vec![self, Box::new(rhs)]),
                }
            }
        }
    };
}
// Generates multiplication between a concrete amplitude-like type `$a` and a
// `Sum<$t>`, in both operand orders and for every owned/borrowed combination.
//
// The product is distributed over the sum: each term of the `Sum` is multiplied
// by a clone of the `$a` operand (keeping the operand order), the result is
// boxed as a `dyn AmpLike<$t>`, and the boxed products form a new `Sum`.
// The borrowed variants clone whatever they borrowed and delegate to the owned
// implementations.
macro_rules! impl_dist {
    ($t:ident, $a:ty) => {
        // `$a * Sum`: `$a` is the left factor of every distributed term.
        impl<$t: Field + 'static> Mul<Sum<$t>> for $a {
            type Output = Sum<$t>;
            fn mul(self, rhs: Sum<$t>) -> Self::Output {
                Sum(rhs
                    .0
                    .into_iter()
                    .map(|term| Box::new(self.clone() * term) as Box<dyn AmpLike<$t>>)
                    .collect())
            }
        }
        // `Sum * $a`: `$a` is the right factor of every distributed term.
        impl<$t: Field + 'static> Mul<$a> for Sum<$t> {
            type Output = Sum<$t>;
            fn mul(self, rhs: $a) -> Self::Output {
                Sum(self
                    .0
                    .into_iter()
                    .map(|term| Box::new(term * rhs.clone()) as Box<dyn AmpLike<$t>>)
                    .collect())
            }
        }
        // `&Sum * &$a`
        impl<$t: Field + 'static> Mul<&$a> for &Sum<$t> {
            type Output = Sum<$t>;
            fn mul(self, rhs: &$a) -> Self::Output {
                Sum::clone(self) * <$a>::clone(rhs)
            }
        }
        // `Sum * &$a`
        impl<$t: Field + 'static> Mul<&$a> for Sum<$t> {
            type Output = Sum<$t>;
            fn mul(self, rhs: &$a) -> Self::Output {
                self * <$a>::clone(rhs)
            }
        }
        // `&Sum * $a`
        impl<$t: Field + 'static> Mul<$a> for &Sum<$t> {
            type Output = Sum<$t>;
            fn mul(self, rhs: $a) -> Self::Output {
                Sum::clone(self) * rhs
            }
        }
        // `&$a * &Sum`
        impl<$t: Field + 'static> Mul<&Sum<$t>> for &$a {
            type Output = Sum<$t>;
            fn mul(self, rhs: &Sum<$t>) -> Self::Output {
                <$a>::clone(self) * Sum::clone(rhs)
            }
        }
        // `$a * &Sum`
        impl<$t: Field + 'static> Mul<&Sum<$t>> for $a {
            type Output = Sum<$t>;
            fn mul(self, rhs: &Sum<$t>) -> Self::Output {
                self * Sum::clone(rhs)
            }
        }
        // `&$a * Sum`
        impl<$t: Field + 'static> Mul<Sum<$t>> for &$a {
            type Output = Sum<$t>;
            fn mul(self, rhs: Sum<$t>) -> Self::Output {
                <$a>::clone(self) * rhs
            }
        }
    };
}
// Operator instantiations for the concrete amplitude-like types. Every line
// below expands one of the `macro_rules!` helpers defined earlier in this file
// with `F` as the generic field parameter.
//
// --- Addition ---------------------------------------------------------------
// `impl_box_sum!` adds `Add` between the given type and `Box<dyn AmpLike<F>>`
// in both operand orders, producing a `Sum<F>`.
// NOTE(review): `impl_sum!` is defined above this section; presumably its
// one-type form is the `Add` counterpart of the one-type `impl_prod!` form —
// confirm against its definition.
impl_sum!(F, Amplitude<F>);
impl_box_sum!(F, Amplitude<F>);
impl_sum!(F, Real<F>);
impl_box_sum!(F, Real<F>);
impl_sum!(F, Imag<F>);
impl_box_sum!(F, Imag<F>);
impl_sum!(F, Product<F>);
impl_box_sum!(F, Product<F>);
// `Sum<F>` only gets the boxed-operand impls here; `Sum + Sum` is written out
// by hand further down in this file.
impl_box_sum!(F, Sum<F>);
// Two-type form: one invocation per unordered pair of distinct concrete types.
// NOTE(review): presumably covers addition between the two listed types in
// either order — confirm against the `impl_sum!` definition.
impl_sum!(F, Amplitude<F>, Real<F>);
impl_sum!(F, Amplitude<F>, Imag<F>);
impl_sum!(F, Amplitude<F>, Product<F>);
impl_sum!(F, Real<F>, Imag<F>);
impl_sum!(F, Real<F>, Product<F>);
impl_sum!(F, Imag<F>, Product<F>);
// NOTE(review): `impl_appending_sum!` is defined above this section; the name
// suggests adding one of these types onto an existing `Sum<F>` — confirm.
impl_appending_sum!(F, Amplitude<F>);
impl_appending_sum!(F, Real<F>);
impl_appending_sum!(F, Imag<F>);
impl_appending_sum!(F, Product<F>);
// --- Multiplication ---------------------------------------------------------
// One-type `impl_prod!`: `T * T` for every owned/borrowed operand combination,
// producing a `Product<F>`.
// `impl_box_prod!`: `Mul` between the given type and `Box<dyn AmpLike<F>>` in
// both operand orders, producing a `Product<F>`.
impl_prod!(F, Amplitude<F>);
impl_box_prod!(F, Amplitude<F>);
impl_prod!(F, Real<F>);
impl_box_prod!(F, Real<F>);
impl_prod!(F, Imag<F>);
impl_box_prod!(F, Imag<F>);
impl_prod!(F, Product<F>);
impl_box_prod!(F, Product<F>);
// Two-type `impl_prod!`: one invocation per unordered pair of distinct types.
// NOTE(review): only the tail of this macro arm is visible from this section
// (borrowed-operand variants delegating to the owned `Mul`); presumably it
// covers both operand orders — confirm against the full definition.
impl_prod!(F, Amplitude<F>, Real<F>);
impl_prod!(F, Amplitude<F>, Imag<F>);
impl_prod!(F, Amplitude<F>, Product<F>);
impl_prod!(F, Real<F>, Imag<F>);
impl_prod!(F, Real<F>, Product<F>);
impl_prod!(F, Imag<F>, Product<F>);
// `impl_dist!`: multiplication between the given type and a `Sum<F>` (both
// orders, owned and borrowed). The factor is distributed over every term of
// the sum, so the result is again a `Sum<F>`.
impl_dist!(F, Amplitude<F>);
impl_dist!(F, Real<F>);
impl_dist!(F, Imag<F>);
impl_dist!(F, Product<F>);
// `Sum + Sum`: concatenates the two term lists, left operand's terms first.
impl<F: Field> Add<Self> for Sum<F> {
    type Output = Self;
    fn add(mut self, rhs: Self) -> Self::Output {
        // Both operands are owned, so the right-hand terms are moved onto the
        // left-hand list; no term needs to be cloned.
        self.0.extend(rhs.0);
        self
    }
}
impl<F: Field> Add<&Sum<F>> for &Sum<F> {
type Output = <Sum<F> as Add<Sum<F>>>::Output;
fn add(self, rhs: &Sum<F>) -> Self::Output {
<Sum<F> as Add<Sum<F>>>::add(self.clone(), rhs.clone())
}
}
// `Sum + &Sum`: clones the borrowed right-hand operand and delegates to the
// owned `Sum + Sum` implementation.
impl<F: Field> Add<&Self> for Sum<F> {
    type Output = Self;
    fn add(self, rhs: &Self) -> Self::Output {
        self + Self::clone(rhs)
    }
}
impl<F: Field> Add<Sum<F>> for &Sum<F> {
type Output = <Sum<F> as Add<Sum<F>>>::Output;
fn add(self, rhs: Sum<F>) -> Self::Output {
<Sum<F> as Add<Sum<F>>>::add(self.clone(), rhs)
}
}