use std::{array, collections::HashMap};
use indexmap::IndexSet;
use nalgebra::{SMatrix, SVector};
use num::Complex;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use crate::{
amplitudes::{AmplitudeID, ParameterLike},
Float, LadduError,
};
/// A borrowed view of the free-parameter values and the registered constant values that
/// amplitudes read during evaluation.
#[derive(Debug)]
pub struct Parameters<'a> {
    /// Free-parameter values, indexed by [`ParameterID::Parameter`].
    pub(crate) parameters: &'a [Float],
    /// Constant values, indexed by [`ParameterID::Constant`].
    pub(crate) constants: &'a [Float],
}
impl<'a> Parameters<'a> {
    /// Creates a view over the given free-parameter and constant slices.
    pub fn new(parameters: &'a [Float], constants: &'a [Float]) -> Self {
        Self {
            parameters,
            constants,
        }
    }

    /// Returns the value referred to by `pid`.
    ///
    /// # Panics
    ///
    /// Panics if `pid` is [`ParameterID::Uninit`], or if its index is out of bounds for the
    /// slice it refers to.
    pub fn get(&self, pid: ParameterID) -> Float {
        match pid {
            ParameterID::Parameter(index) => self.parameters[index],
            ParameterID::Constant(index) => self.constants[index],
            ParameterID::Uninit => panic!("Parameter has not been registered!"),
        }
    }

    /// Returns the number of free parameters. Constants are not counted.
    pub fn len(&self) -> usize {
        self.parameters.len()
    }

    /// Returns `true` if there are no free parameters. Constants are not counted, which matches
    /// [`Parameters::len`].
    pub fn is_empty(&self) -> bool {
        self.parameters.is_empty()
    }
}
/// Registry of what amplitudes need at evaluation time: amplitude names and their activation
/// flags, free-parameter names, constant values, and the slot layout of the per-event
/// [`Cache`]s.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Resources {
    /// Registered amplitude names mapped to their IDs. The ID's index field selects the
    /// matching entry in `active`.
    amplitudes: HashMap<String, AmplitudeID>,
    /// One activation flag per registered amplitude, indexed by the amplitude ID's index.
    /// Newly registered amplitudes start as `true`.
    pub active: Vec<bool>,
    /// Free-parameter names. A name's position is its `ParameterID::Parameter` index.
    pub parameters: IndexSet<String>,
    /// Constant values. A value's position is its `ParameterID::Constant` index.
    pub constants: Vec<Float>,
    /// Caches built by `reserve_cache`: one per requested event, each `cache_size` slots long.
    pub caches: Vec<Cache>,
    /// Maps a scalar's registration name to its first cache slot, so re-registering the same
    /// name reuses those slots.
    scalar_cache_names: HashMap<String, usize>,
    /// Same as `scalar_cache_names`, for complex scalars.
    complex_scalar_cache_names: HashMap<String, usize>,
    /// Same as `scalar_cache_names`, for vectors.
    vector_cache_names: HashMap<String, usize>,
    /// Same as `scalar_cache_names`, for complex vectors.
    complex_vector_cache_names: HashMap<String, usize>,
    /// Same as `scalar_cache_names`, for matrices.
    matrix_cache_names: HashMap<String, usize>,
    /// Same as `scalar_cache_names`, for complex matrices.
    complex_matrix_cache_names: HashMap<String, usize>,
    /// Total number of cache slots reserved so far by the `register_*` methods.
    cache_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Cache(Vec<Float>);
impl Cache {
fn new(cache_size: usize) -> Self {
Self(vec![0.0; cache_size])
}
pub fn store_scalar(&mut self, sid: ScalarID, value: Float) {
self.0[sid.0] = value;
}
pub fn store_complex_scalar(&mut self, csid: ComplexScalarID, value: Complex<Float>) {
self.0[csid.0] = value.re;
self.0[csid.1] = value.im;
}
pub fn store_vector<const R: usize>(&mut self, vid: VectorID<R>, value: SVector<Float, R>) {
vid.0
.into_iter()
.enumerate()
.for_each(|(vi, i)| self.0[i] = value[vi]);
}
pub fn store_complex_vector<const R: usize>(
&mut self,
cvid: ComplexVectorID<R>,
value: SVector<Complex<Float>, R>,
) {
cvid.0
.into_iter()
.enumerate()
.for_each(|(vi, i)| self.0[i] = value[vi].re);
cvid.1
.into_iter()
.enumerate()
.for_each(|(vi, i)| self.0[i] = value[vi].im);
}
pub fn store_matrix<const R: usize, const C: usize>(
&mut self,
mid: MatrixID<R, C>,
value: SMatrix<Float, R, C>,
) {
mid.0.into_iter().enumerate().for_each(|(vi, row)| {
row.into_iter()
.enumerate()
.for_each(|(vj, k)| self.0[k] = value[(vi, vj)])
});
}
pub fn store_complex_matrix<const R: usize, const C: usize>(
&mut self,
cmid: ComplexMatrixID<R, C>,
value: SMatrix<Complex<Float>, R, C>,
) {
cmid.0.into_iter().enumerate().for_each(|(vi, row)| {
row.into_iter()
.enumerate()
.for_each(|(vj, k)| self.0[k] = value[(vi, vj)].re)
});
cmid.1.into_iter().enumerate().for_each(|(vi, row)| {
row.into_iter()
.enumerate()
.for_each(|(vj, k)| self.0[k] = value[(vi, vj)].im)
});
}
pub fn get_scalar(&self, sid: ScalarID) -> Float {
self.0[sid.0]
}
pub fn get_complex_scalar(&self, csid: ComplexScalarID) -> Complex<Float> {
Complex::new(self.0[csid.0], self.0[csid.1])
}
pub fn get_vector<const R: usize>(&self, vid: VectorID<R>) -> SVector<Float, R> {
SVector::from_fn(|i, _| self.0[vid.0[i]])
}
pub fn get_complex_vector<const R: usize>(
&self,
cvid: ComplexVectorID<R>,
) -> SVector<Complex<Float>, R> {
SVector::from_fn(|i, _| Complex::new(self.0[cvid.0[i]], self.0[cvid.1[i]]))
}
pub fn get_matrix<const R: usize, const C: usize>(
&self,
mid: MatrixID<R, C>,
) -> SMatrix<Float, R, C> {
SMatrix::from_fn(|i, j| self.0[mid.0[i][j]])
}
pub fn get_complex_matrix<const R: usize, const C: usize>(
&self,
cmid: ComplexMatrixID<R, C>,
) -> SMatrix<Complex<Float>, R, C> {
SMatrix::from_fn(|i, j| Complex::new(self.0[cmid.0[i][j]], self.0[cmid.1[i][j]]))
}
}
/// Says where a parameter's value is found when it is resolved through [`Parameters::get`].
#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize)]
pub enum ParameterID {
    /// Index into the free-parameter slice of a [`Parameters`].
    Parameter(usize),
    /// Index into the constant slice of a [`Parameters`].
    Constant(usize),
    /// Placeholder for a parameter that has not been registered. This is the `Default` value,
    /// and [`Parameters::get`] panics if asked to resolve it.
    #[default]
    Uninit,
}
/// Handle to a single real-valued [`Cache`] slot, as returned by `Resources::register_scalar`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct ScalarID(usize);
/// Handle to a complex scalar stored in two [`Cache`] slots. The first field is the slot of the
/// real part and the second is the slot of the imaginary part.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct ComplexScalarID(usize, usize);
/// Handle to the `R` [`Cache`] slots of a length-`R` real vector. Entry `k` is the slot of
/// component `k`.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct VectorID<const R: usize>(#[serde_as(as = "[_; R]")] [usize; R]);

impl<const R: usize> Default for VectorID<R> {
    /// Every component points at slot `0`.
    fn default() -> Self {
        VectorID([0usize; R])
    }
}
/// Handle to a length-`R` complex vector in a [`Cache`]. The first array holds the slots of the
/// real parts and the second holds the slots of the imaginary parts.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct ComplexVectorID<const R: usize>(
    #[serde_as(as = "[_; R]")] [usize; R],
    #[serde_as(as = "[_; R]")] [usize; R],
);

impl<const R: usize> Default for ComplexVectorID<R> {
    /// Every real and imaginary component points at slot `0`.
    fn default() -> Self {
        let zeros = [0usize; R];
        ComplexVectorID(zeros, zeros)
    }
}
/// Handle to an `R x C` real matrix in a [`Cache`]. Entry `[row][col]` is the slot of matrix
/// element `(row, col)`.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct MatrixID<const R: usize, const C: usize>(
    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
);

impl<const R: usize, const C: usize> Default for MatrixID<R, C> {
    /// Every element points at slot `0`.
    fn default() -> Self {
        MatrixID([[0usize; C]; R])
    }
}
/// Handle to an `R x C` complex matrix in a [`Cache`]. The first grid holds the slots of the
/// real parts and the second holds the slots of the imaginary parts.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct ComplexMatrixID<const R: usize, const C: usize>(
    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
);

impl<const R: usize, const C: usize> Default for ComplexMatrixID<R, C> {
    /// Every real and imaginary element points at slot `0`.
    fn default() -> Self {
        let zeros = [[0usize; C]; R];
        ComplexMatrixID(zeros, zeros)
    }
}
impl Resources {
    /// Resolves the amplitude registered as `name` to its index in [`Resources::active`].
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] if no amplitude is registered under
    /// `name`. The error is only built on the failure path.
    fn amplitude_index(&self, name: &str) -> Result<usize, LadduError> {
        self.amplitudes
            .get(name)
            .map(|amplitude_id| amplitude_id.1)
            .ok_or_else(|| LadduError::AmplitudeNotFoundError {
                name: name.to_string(),
            })
    }

    /// Resolves every name in `names`, failing on the first one that is not registered.
    ///
    /// The bulk operations resolve everything up front so that an unknown name leaves the
    /// activation flags untouched instead of half-applied.
    fn amplitude_indices<T: AsRef<str>>(&self, names: &[T]) -> Result<Vec<usize>, LadduError> {
        names
            .iter()
            .map(|name| self.amplitude_index(name.as_ref()))
            .collect()
    }

    /// Sets the activation flag of every amplitude in `indices` to `state`.
    fn set_active(&mut self, indices: &[usize], state: bool) {
        for &index in indices {
            self.active[index] = state;
        }
    }

    /// Marks the amplitude registered as `name` as active.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] if `name` is not registered.
    pub fn activate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
        let index = self.amplitude_index(name.as_ref())?;
        self.active[index] = true;
        Ok(())
    }

    /// Marks every amplitude in `names` as active.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] for the first unregistered name. In that
    /// case no flag is changed.
    pub fn activate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
        let indices = self.amplitude_indices(names)?;
        self.set_active(&indices, true);
        Ok(())
    }

    /// Marks every registered amplitude as active.
    pub fn activate_all(&mut self) {
        self.active.fill(true);
    }

    /// Marks the amplitude registered as `name` as inactive.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] if `name` is not registered.
    pub fn deactivate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
        let index = self.amplitude_index(name.as_ref())?;
        self.active[index] = false;
        Ok(())
    }

    /// Marks every amplitude in `names` as inactive.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] for the first unregistered name. In that
    /// case no flag is changed.
    pub fn deactivate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
        let indices = self.amplitude_indices(names)?;
        self.set_active(&indices, false);
        Ok(())
    }

    /// Marks every registered amplitude as inactive.
    pub fn deactivate_all(&mut self) {
        self.active.fill(false);
    }

    /// Makes `name` the only active amplitude.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] if `name` is not registered. The name is
    /// validated before anything is deactivated, so on error the flags are left untouched.
    pub fn isolate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
        let index = self.amplitude_index(name.as_ref())?;
        self.deactivate_all();
        self.active[index] = true;
        Ok(())
    }

    /// Makes the amplitudes in `names` the only active ones.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::AmplitudeNotFoundError`] for the first unregistered name. All names
    /// are validated before anything is deactivated, so on error the flags are left untouched.
    pub fn isolate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
        let indices = self.amplitude_indices(names)?;
        self.deactivate_all();
        self.set_active(&indices, true);
        Ok(())
    }

    /// Registers a new amplitude under `name` and returns its ID. The amplitude starts active.
    ///
    /// # Errors
    ///
    /// Returns [`LadduError::RegistrationError`] if `name` is already registered.
    pub fn register_amplitude(&mut self, name: &str) -> Result<AmplitudeID, LadduError> {
        if self.amplitudes.contains_key(name) {
            return Err(LadduError::RegistrationError {
                name: name.to_string(),
            });
        }
        let next_id = AmplitudeID(name.to_string(), self.amplitudes.len());
        self.amplitudes.insert(name.to_string(), next_id.clone());
        self.active.push(true);
        Ok(next_id)
    }

    /// Registers `pl` and returns where its value will be found in a [`Parameters`].
    ///
    /// Free parameters are keyed by name, so registering the same name again returns the same
    /// index. Every constant gets a new index.
    ///
    /// # Panics
    ///
    /// Panics if `pl` is `ParameterLike::Uninit`.
    pub fn register_parameter(&mut self, pl: &ParameterLike) -> ParameterID {
        match pl {
            ParameterLike::Parameter(name) => {
                let (index, _) = self.parameters.insert_full(name.to_string());
                ParameterID::Parameter(index)
            }
            ParameterLike::Constant(value) => {
                self.constants.push(*value);
                ParameterID::Constant(self.constants.len() - 1)
            }
            ParameterLike::Uninit => panic!("Parameter was not initialized!"),
        }
    }

    /// Replaces [`Resources::caches`] with `num_events` zero-filled caches. Each cache has room
    /// for every slot registered so far.
    pub(crate) fn reserve_cache(&mut self, num_events: usize) {
        self.caches = vec![Cache::new(self.cache_size); num_events];
    }

    /// Returns the first slot of a `width`-slot span in the cache layout.
    ///
    /// If `name` is `Some` and already present in `names`, the previously reserved span is
    /// reused. Otherwise a new span is appended by advancing `cache_size`, and it is recorded
    /// under `name` when one is given.
    ///
    /// NOTE: the lookup key is the name alone. Re-registering a name with a larger `width`
    /// than its first registration returns a span that overlaps later allocations.
    fn reserve_slots(
        names: &mut HashMap<String, usize>,
        cache_size: &mut usize,
        name: Option<&str>,
        width: usize,
    ) -> usize {
        let mut allocate = || {
            let first_index = *cache_size;
            *cache_size += width;
            first_index
        };
        match name {
            Some(name) => *names.entry(name.to_string()).or_insert_with(allocate),
            None => allocate(),
        }
    }

    /// Reserves one cache slot for a real scalar. A named scalar reuses its earlier slot.
    pub fn register_scalar(&mut self, name: Option<&str>) -> ScalarID {
        let first_index =
            Self::reserve_slots(&mut self.scalar_cache_names, &mut self.cache_size, name, 1);
        ScalarID(first_index)
    }

    /// Reserves two adjacent cache slots (real, imaginary) for a complex scalar. A named
    /// scalar reuses its earlier slots.
    pub fn register_complex_scalar(&mut self, name: Option<&str>) -> ComplexScalarID {
        let first_index = Self::reserve_slots(
            &mut self.complex_scalar_cache_names,
            &mut self.cache_size,
            name,
            2,
        );
        ComplexScalarID(first_index, first_index + 1)
    }

    /// Reserves `R` contiguous cache slots for a real vector. A named vector reuses its earlier
    /// slots.
    pub fn register_vector<const R: usize>(&mut self, name: Option<&str>) -> VectorID<R> {
        let first_index =
            Self::reserve_slots(&mut self.vector_cache_names, &mut self.cache_size, name, R);
        VectorID(array::from_fn(|i| first_index + i))
    }

    /// Reserves `2 * R` contiguous cache slots for a complex vector: `R` real parts followed by
    /// `R` imaginary parts. A named vector reuses its earlier slots.
    pub fn register_complex_vector<const R: usize>(
        &mut self,
        name: Option<&str>,
    ) -> ComplexVectorID<R> {
        let first_index = Self::reserve_slots(
            &mut self.complex_vector_cache_names,
            &mut self.cache_size,
            name,
            R * 2,
        );
        ComplexVectorID(
            array::from_fn(|i| first_index + i),
            array::from_fn(|i| (first_index + R) + i),
        )
    }

    /// Reserves `R * C` contiguous cache slots for a real matrix, laid out row-major. A named
    /// matrix reuses its earlier slots.
    pub fn register_matrix<const R: usize, const C: usize>(
        &mut self,
        name: Option<&str>,
    ) -> MatrixID<R, C> {
        let first_index = Self::reserve_slots(
            &mut self.matrix_cache_names,
            &mut self.cache_size,
            name,
            R * C,
        );
        MatrixID(array::from_fn(|i| {
            array::from_fn(|j| first_index + i * C + j)
        }))
    }

    /// Reserves `2 * R * C` contiguous cache slots for a complex matrix: a row-major block of
    /// real parts followed by a row-major block of imaginary parts. A named matrix reuses its
    /// earlier slots.
    pub fn register_complex_matrix<const R: usize, const C: usize>(
        &mut self,
        name: Option<&str>,
    ) -> ComplexMatrixID<R, C> {
        let first_index = Self::reserve_slots(
            &mut self.complex_matrix_cache_names,
            &mut self.cache_size,
            name,
            2 * R * C,
        );
        ComplexMatrixID(
            array::from_fn(|i| array::from_fn(|j| first_index + i * C + j)),
            array::from_fn(|i| array::from_fn(|j| (first_index + R * C) + i * C + j)),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{Matrix2, Vector2};
    use num::Complex;

    #[test]
    fn test_parameters() {
        let parameters = vec![1.0, 2.0, 3.0];
        let constants = vec![4.0, 5.0, 6.0];
        let params = Parameters::new(&parameters, &constants);
        assert_eq!(params.get(ParameterID::Parameter(0)), 1.0);
        assert_eq!(params.get(ParameterID::Parameter(1)), 2.0);
        assert_eq!(params.get(ParameterID::Parameter(2)), 3.0);
        assert_eq!(params.get(ParameterID::Constant(0)), 4.0);
        assert_eq!(params.get(ParameterID::Constant(1)), 5.0);
        assert_eq!(params.get(ParameterID::Constant(2)), 6.0);
        assert_eq!(params.len(), 3);
    }

    #[test]
    #[should_panic(expected = "Parameter has not been registered!")]
    fn test_uninit_parameter() {
        let parameters = vec![1.0];
        let constants = vec![1.0];
        let params = Parameters::new(&parameters, &constants);
        params.get(ParameterID::Uninit);
    }

    #[test]
    fn test_resources_amplitude_management() {
        let mut resources = Resources::default();
        let amp1 = resources.register_amplitude("amp1").unwrap();
        let amp2 = resources.register_amplitude("amp2").unwrap();
        assert!(resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);
        resources.deactivate("amp1").unwrap();
        assert!(!resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);
        resources.activate("amp1").unwrap();
        assert!(resources.active[amp1.1]);
        resources.deactivate_all();
        assert!(!resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);
        resources.activate_all();
        assert!(resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);
        resources.isolate("amp1").unwrap();
        assert!(resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);
    }

    #[test]
    fn test_unknown_amplitude_errors() {
        let mut resources = Resources::default();
        resources.register_amplitude("amp1").unwrap();
        assert!(resources.activate("missing").is_err());
        assert!(resources.deactivate("missing").is_err());
        assert!(resources.isolate("missing").is_err());
        assert!(resources.activate_many(&["amp1", "missing"]).is_err());
        assert!(resources.deactivate_many(&["amp1", "missing"]).is_err());
        assert!(resources.isolate_many(&["amp1", "missing"]).is_err());
    }

    #[test]
    fn test_resources_parameter_registration() {
        let mut resources = Resources::default();
        let param1 = resources.register_parameter(&ParameterLike::Parameter("param1".to_string()));
        let const1 = resources.register_parameter(&ParameterLike::Constant(1.0));
        match param1 {
            ParameterID::Parameter(idx) => assert_eq!(idx, 0),
            _ => panic!("Expected Parameter variant"),
        }
        match const1 {
            ParameterID::Constant(idx) => assert_eq!(idx, 0),
            _ => panic!("Expected Constant variant"),
        }
    }

    #[test]
    fn test_cache_scalar_operations() {
        let mut resources = Resources::default();
        let scalar1 = resources.register_scalar(Some("test_scalar"));
        let scalar2 = resources.register_scalar(None);
        let scalar3 = resources.register_scalar(Some("test_scalar"));
        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];
        cache.store_scalar(scalar1, 1.0);
        cache.store_scalar(scalar2, 2.0);
        assert_eq!(cache.get_scalar(scalar1), 1.0);
        assert_eq!(cache.get_scalar(scalar2), 2.0);
        assert_eq!(cache.get_scalar(scalar3), 1.0);
    }

    #[test]
    fn test_cache_complex_operations() {
        let mut resources = Resources::default();
        let complex1 = resources.register_complex_scalar(Some("test_complex"));
        let complex2 = resources.register_complex_scalar(None);
        let complex3 = resources.register_complex_scalar(Some("test_complex"));
        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];
        let value1 = Complex::new(1.0, 2.0);
        let value2 = Complex::new(3.0, 4.0);
        cache.store_complex_scalar(complex1, value1);
        cache.store_complex_scalar(complex2, value2);
        assert_eq!(cache.get_complex_scalar(complex1), value1);
        assert_eq!(cache.get_complex_scalar(complex2), value2);
        assert_eq!(cache.get_complex_scalar(complex3), value1);
    }

    #[test]
    fn test_cache_vector_operations() {
        let mut resources = Resources::default();
        let vector_id1: VectorID<2> = resources.register_vector(Some("test_vector"));
        let vector_id2: VectorID<2> = resources.register_vector(None);
        let vector_id3: VectorID<2> = resources.register_vector(Some("test_vector"));
        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];
        let value1 = Vector2::new(1.0, 2.0);
        let value2 = Vector2::new(3.0, 4.0);
        cache.store_vector(vector_id1, value1);
        cache.store_vector(vector_id2, value2);
        assert_eq!(cache.get_vector(vector_id1), value1);
        assert_eq!(cache.get_vector(vector_id2), value2);
        assert_eq!(cache.get_vector(vector_id3), value1);
    }

    #[test]
    fn test_cache_complex_vector_operations() {
        let mut resources = Resources::default();
        let complex_vector_id1: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));
        let complex_vector_id2: ComplexVectorID<2> = resources.register_complex_vector(None);
        let complex_vector_id3: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));
        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];
        let value1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
        let value2 = Vector2::new(Complex::new(5.0, 6.0), Complex::new(7.0, 8.0));
        cache.store_complex_vector(complex_vector_id1, value1);
        cache.store_complex_vector(complex_vector_id2, value2);
        assert_eq!(cache.get_complex_vector(complex_vector_id1), value1);
        assert_eq!(cache.get_complex_vector(complex_vector_id2), value2);
        assert_eq!(cache.get_complex_vector(complex_vector_id3), value1);
    }

    #[test]
    fn test_cache_matrix_operations() {
        let mut resources = Resources::default();
        let matrix_id1: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
        let matrix_id2: MatrixID<2, 2> = resources.register_matrix(None);
        let matrix_id3: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];
        let value1 = Matrix2::new(1.0, 2.0, 3.0, 4.0);
        let value2 = Matrix2::new(5.0, 6.0, 7.0, 8.0);
        cache.store_matrix(matrix_id1, value1);
        cache.store_matrix(matrix_id2, value2);
        assert_eq!(cache.get_matrix(matrix_id1), value1);
        assert_eq!(cache.get_matrix(matrix_id2), value2);
        assert_eq!(cache.get_matrix(matrix_id3), value1);
    }

    #[test]
    fn test_cache_complex_matrix_operations() {
        let mut resources = Resources::default();
        let complex_matrix_id1: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));
        let complex_matrix_id2: ComplexMatrixID<2, 2> = resources.register_complex_matrix(None);
        let complex_matrix_id3: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));
        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];
        let value1 = Matrix2::new(
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        );
        let value2 = Matrix2::new(
            Complex::new(9.0, 10.0),
            Complex::new(11.0, 12.0),
            Complex::new(13.0, 14.0),
            Complex::new(15.0, 16.0),
        );
        cache.store_complex_matrix(complex_matrix_id1, value1);
        cache.store_complex_matrix(complex_matrix_id2, value2);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id1), value1);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id2), value2);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id3), value1);
    }

    #[test]
    #[should_panic(expected = "Parameter was not initialized!")]
    fn test_uninit_parameter_registration() {
        let mut resources = Resources::default();
        resources.register_parameter(&ParameterLike::Uninit);
    }

    #[test]
    fn test_duplicate_named_amplitude_registration_error() {
        let mut resources = Resources::default();
        assert!(resources.register_amplitude("test_amp").is_ok());
        assert!(resources.register_amplitude("test_amp").is_err());
    }
}