laddu_core/
resources.rs

1use std::{array, collections::HashMap};
2
3use indexmap::IndexSet;
4use nalgebra::{SMatrix, SVector};
5use num::complex::Complex64;
6use serde::{Deserialize, Serialize};
7use serde_with::serde_as;
8
9use crate::{
10    amplitudes::{AmplitudeID, ParameterLike},
11    LadduError, LadduResult,
12};
13
/// Borrowed view over the numeric values available during a fit: the free parameters and the
/// constants. Individual values are looked up through their [`ParameterID`].
#[derive(Debug)]
pub struct Parameters<'a> {
    /// Values of the free parameters, addressed by [`ParameterID::Parameter`] indices.
    pub(crate) parameters: &'a [f64],
    /// Values of the constants, addressed by [`ParameterID::Constant`] indices.
    pub(crate) constants: &'a [f64],
}
21
22impl<'a> Parameters<'a> {
23    /// Create a new set of [`Parameters`] from a list of floating values and a list of constant values
24    pub fn new(parameters: &'a [f64], constants: &'a [f64]) -> Self {
25        Self {
26            parameters,
27            constants,
28        }
29    }
30
31    /// Obtain a parameter value or constant value from the given [`ParameterID`].
32    pub fn get(&self, pid: ParameterID) -> f64 {
33        match pid {
34            ParameterID::Parameter(index) => self.parameters[index],
35            ParameterID::Constant(index) => self.constants[index],
36            ParameterID::Uninit => panic!("Parameter has not been registered!"),
37        }
38    }
39
40    /// The number of free parameters.
41    #[allow(clippy::len_without_is_empty)]
42    pub fn len(&self) -> usize {
43        self.parameters.len()
44    }
45}
46
47/// The main resource manager for cached values, amplitudes, parameters, and constants.
48#[derive(Default, Debug, Clone, Serialize, Deserialize)]
49pub struct Resources {
50    amplitudes: HashMap<String, AmplitudeID>,
51    /// A list indicating which amplitudes are active (using [`AmplitudeID`]s as indices)
52    pub active: Vec<bool>,
53    /// The set of all registered free parameter names across registered [`Amplitude`]s
54    pub free_parameters: IndexSet<String>,
55    /// The set of all registered fixed parameter names across registered [`Amplitude`]s
56    pub fixed_parameters: IndexSet<String>,
57    /// Values of all constants/fixed parameters across registered [`Amplitude`]s
58    pub constants: Vec<f64>,
59    /// The [`Cache`] for each [`EventData`](`crate::data::EventData`)
60    pub caches: Vec<Cache>,
61    scalar_cache_names: HashMap<String, usize>,
62    complex_scalar_cache_names: HashMap<String, usize>,
63    vector_cache_names: HashMap<String, usize>,
64    complex_vector_cache_names: HashMap<String, usize>,
65    matrix_cache_names: HashMap<String, usize>,
66    complex_matrix_cache_names: HashMap<String, usize>,
67    cache_size: usize,
68    parameter_entries: HashMap<String, ParameterEntry>,
69    pub(crate) parameter_overrides: ParameterTransform,
70}
71
72/// A single cache entry corresponding to precomputed data for a particular
73/// [`EventData`](crate::data::EventData) in a [`Dataset`](crate::data::Dataset).
74#[derive(Clone, Debug, Serialize, Deserialize)]
75pub struct Cache(Vec<f64>);
76impl Cache {
77    fn new(cache_size: usize) -> Self {
78        Self(vec![0.0; cache_size])
79    }
80    /// Store a scalar value with the corresponding [`ScalarID`].
81    pub fn store_scalar(&mut self, sid: ScalarID, value: f64) {
82        self.0[sid.0] = value;
83    }
84    /// Store a complex scalar value with the corresponding [`ComplexScalarID`].
85    pub fn store_complex_scalar(&mut self, csid: ComplexScalarID, value: Complex64) {
86        self.0[csid.0] = value.re;
87        self.0[csid.1] = value.im;
88    }
89    /// Store a vector with the corresponding [`VectorID`].
90    pub fn store_vector<const R: usize>(&mut self, vid: VectorID<R>, value: SVector<f64, R>) {
91        vid.0
92            .into_iter()
93            .enumerate()
94            .for_each(|(vi, i)| self.0[i] = value[vi]);
95    }
96    /// Store a complex-valued vector with the corresponding [`ComplexVectorID`].
97    pub fn store_complex_vector<const R: usize>(
98        &mut self,
99        cvid: ComplexVectorID<R>,
100        value: SVector<Complex64, R>,
101    ) {
102        cvid.0
103            .into_iter()
104            .enumerate()
105            .for_each(|(vi, i)| self.0[i] = value[vi].re);
106        cvid.1
107            .into_iter()
108            .enumerate()
109            .for_each(|(vi, i)| self.0[i] = value[vi].im);
110    }
111    /// Store a matrix with the corresponding [`MatrixID`].
112    pub fn store_matrix<const R: usize, const C: usize>(
113        &mut self,
114        mid: MatrixID<R, C>,
115        value: SMatrix<f64, R, C>,
116    ) {
117        mid.0.into_iter().enumerate().for_each(|(vi, row)| {
118            row.into_iter()
119                .enumerate()
120                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)])
121        });
122    }
123    /// Store a complex-valued matrix with the corresponding [`ComplexMatrixID`].
124    pub fn store_complex_matrix<const R: usize, const C: usize>(
125        &mut self,
126        cmid: ComplexMatrixID<R, C>,
127        value: SMatrix<Complex64, R, C>,
128    ) {
129        cmid.0.into_iter().enumerate().for_each(|(vi, row)| {
130            row.into_iter()
131                .enumerate()
132                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].re)
133        });
134        cmid.1.into_iter().enumerate().for_each(|(vi, row)| {
135            row.into_iter()
136                .enumerate()
137                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].im)
138        });
139    }
140    /// Retrieve a scalar value from the [`Cache`].
141    pub fn get_scalar(&self, sid: ScalarID) -> f64 {
142        self.0[sid.0]
143    }
144    /// Retrieve a complex scalar value from the [`Cache`].
145    pub fn get_complex_scalar(&self, csid: ComplexScalarID) -> Complex64 {
146        Complex64::new(self.0[csid.0], self.0[csid.1])
147    }
148    /// Retrieve a vector from the [`Cache`].
149    pub fn get_vector<const R: usize>(&self, vid: VectorID<R>) -> SVector<f64, R> {
150        SVector::from_fn(|i, _| self.0[vid.0[i]])
151    }
152    /// Retrieve a complex-valued vector from the [`Cache`].
153    pub fn get_complex_vector<const R: usize>(
154        &self,
155        cvid: ComplexVectorID<R>,
156    ) -> SVector<Complex64, R> {
157        SVector::from_fn(|i, _| Complex64::new(self.0[cvid.0[i]], self.0[cvid.1[i]]))
158    }
159    /// Retrieve a matrix from the [`Cache`].
160    pub fn get_matrix<const R: usize, const C: usize>(
161        &self,
162        mid: MatrixID<R, C>,
163    ) -> SMatrix<f64, R, C> {
164        SMatrix::from_fn(|i, j| self.0[mid.0[i][j]])
165    }
166    /// Retrieve a complex-valued matrix from the [`Cache`].
167    pub fn get_complex_matrix<const R: usize, const C: usize>(
168        &self,
169        cmid: ComplexMatrixID<R, C>,
170    ) -> SMatrix<Complex64, R, C> {
171        SMatrix::from_fn(|i, j| Complex64::new(self.0[cmid.0[i][j]], self.0[cmid.1[i][j]]))
172    }
173}
174
175/// An object which acts as a tag to refer to either a free parameter or a constant value.
176#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize)]
177pub enum ParameterID {
178    /// A free parameter.
179    Parameter(usize),
180    /// A constant value.
181    Constant(usize),
182    /// An uninitialized ID
183    #[default]
184    Uninit,
185}
186
187/// A tag for retrieving or storing a scalar value in a [`Cache`].
188#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
189pub struct ScalarID(usize);
190
191/// A tag for retrieving or storing a complex scalar value in a [`Cache`].
192#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
193pub struct ComplexScalarID(usize, usize);
194
195/// A tag for retrieving or storing a vector in a [`Cache`].
196#[serde_as]
197#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
198pub struct VectorID<const R: usize>(#[serde_as(as = "[_; R]")] [usize; R]);
199
200impl<const R: usize> Default for VectorID<R> {
201    fn default() -> Self {
202        Self([0; R])
203    }
204}
205
206/// A tag for retrieving or storing a complex-valued vector in a [`Cache`].
207#[serde_as]
208#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
209pub struct ComplexVectorID<const R: usize>(
210    #[serde_as(as = "[_; R]")] [usize; R],
211    #[serde_as(as = "[_; R]")] [usize; R],
212);
213
214impl<const R: usize> Default for ComplexVectorID<R> {
215    fn default() -> Self {
216        Self([0; R], [0; R])
217    }
218}
219
220/// A tag for retrieving or storing a matrix in a [`Cache`].
221#[serde_as]
222#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
223pub struct MatrixID<const R: usize, const C: usize>(
224    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
225);
226
227impl<const R: usize, const C: usize> Default for MatrixID<R, C> {
228    fn default() -> Self {
229        Self([[0; C]; R])
230    }
231}
232
233/// A tag for retrieving or storing a complex-valued matrix in a [`Cache`].
234#[serde_as]
235#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
236pub struct ComplexMatrixID<const R: usize, const C: usize>(
237    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
238    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
239);
240
241impl<const R: usize, const C: usize> Default for ComplexMatrixID<R, C> {
242    fn default() -> Self {
243        Self([[0; C]; R], [[0; C]; R])
244    }
245}
246
247#[derive(Clone, Debug, Serialize, Deserialize)]
248struct ParameterEntry {
249    id: ParameterID,
250    fixed: Option<f64>,
251}
252
253/// Parameter transformation instructions applied during re-registration.
254#[derive(Default, Clone, Debug, Serialize, Deserialize)]
255pub struct ParameterTransform {
256    /// Mapping from old parameter names to new names.
257    pub renames: HashMap<String, String>,
258    /// Parameters to force fixed with the provided value.
259    pub fixed: HashMap<String, f64>,
260    /// Parameters to force free (ignore any fixed value).
261    pub freed: IndexSet<String>,
262}
263
264impl ParameterTransform {
265    /// Merge two transforms, with `other` overriding `self` where keys overlap.
266    pub fn merged(&self, other: &Self) -> Self {
267        let mut merged = self.clone();
268        merged.renames.extend(other.renames.clone());
269        merged.fixed.extend(other.fixed.clone());
270        merged.freed.extend(other.freed.clone());
271        merged
272    }
273}
274
275impl Resources {
276    /// Create a new [`Resources`] instance with a parameter transform applied.
277    pub fn with_transform(transform: ParameterTransform) -> Self {
278        Self {
279            parameter_overrides: transform,
280            ..Default::default()
281        }
282    }
283
284    /// The list of free parameter names.
285    pub fn free_parameter_names(&self) -> Vec<String> {
286        self.free_parameters.iter().cloned().collect()
287    }
288
289    /// The list of fixed parameter names.
290    pub fn fixed_parameter_names(&self) -> Vec<String> {
291        self.fixed_parameters.iter().cloned().collect()
292    }
293
294    /// All parameter names (free first, then fixed).
295    pub fn parameter_names(&self) -> Vec<String> {
296        self.free_parameter_names()
297            .into_iter()
298            .chain(self.fixed_parameter_names())
299            .collect()
300    }
301
302    /// Number of free parameters.
303    pub fn n_free_parameters(&self) -> usize {
304        self.free_parameters.len()
305    }
306
307    /// Number of fixed parameters.
308    pub fn n_fixed_parameters(&self) -> usize {
309        self.fixed_parameters.len()
310    }
311
312    /// Total number of parameters.
313    pub fn n_parameters(&self) -> usize {
314        self.n_free_parameters() + self.n_fixed_parameters()
315    }
316
317    #[inline]
318    fn set_activation_state(&mut self, name: &str, active: bool) -> bool {
319        if let Some(amplitude) = self.amplitudes.get(name) {
320            self.active[amplitude.1] = active;
321            true
322        } else {
323            false
324        }
325    }
326    /// Activate an [`Amplitude`](crate::amplitudes::Amplitude) by name.
327    pub fn activate<T: AsRef<str>>(&mut self, name: T) {
328        let name_ref = name.as_ref();
329        self.set_activation_state(name_ref, true);
330    }
331    /// Activate several [`Amplitude`](crate::amplitudes::Amplitude)s by name.
332    pub fn activate_many<T: AsRef<str>>(&mut self, names: &[T]) {
333        for name in names {
334            self.activate(name);
335        }
336    }
337    /// Activate an [`Amplitude`](crate::amplitudes::Amplitude) by name, returning an error if it is missing.
338    pub fn activate_strict<T: AsRef<str>>(&mut self, name: T) -> LadduResult<()> {
339        let name_ref = name.as_ref();
340        if self.set_activation_state(name_ref, true) {
341            Ok(())
342        } else {
343            Err(LadduError::AmplitudeNotFoundError {
344                name: name_ref.to_string(),
345            })
346        }
347    }
348    /// Activate several [`Amplitude`](crate::amplitudes::Amplitude)s by name, returning an error if any are missing.
349    pub fn activate_many_strict<T: AsRef<str>>(&mut self, names: &[T]) -> LadduResult<()> {
350        for name in names {
351            self.activate_strict(name)?
352        }
353        Ok(())
354    }
355    /// Activate all registered [`Amplitude`](crate::amplitudes::Amplitude)s.
356    pub fn activate_all(&mut self) {
357        self.active = vec![true; self.active.len()];
358    }
359    /// Deactivate an [`Amplitude`](crate::amplitudes::Amplitude) by name.
360    pub fn deactivate<T: AsRef<str>>(&mut self, name: T) {
361        let name_ref = name.as_ref();
362        self.set_activation_state(name_ref, false);
363    }
364    /// Deactivate several [`Amplitude`](crate::amplitudes::Amplitude)s by name.
365    pub fn deactivate_many<T: AsRef<str>>(&mut self, names: &[T]) {
366        for name in names {
367            self.deactivate(name);
368        }
369    }
370    /// Deactivate an [`Amplitude`](crate::amplitudes::Amplitude) by name, returning an error if it is missing.
371    pub fn deactivate_strict<T: AsRef<str>>(&mut self, name: T) -> LadduResult<()> {
372        let name_ref = name.as_ref();
373        if self.set_activation_state(name_ref, false) {
374            Ok(())
375        } else {
376            Err(LadduError::AmplitudeNotFoundError {
377                name: name_ref.to_string(),
378            })
379        }
380    }
381    /// Deactivate several [`Amplitude`](crate::amplitudes::Amplitude)s by name, returning an error if any are missing.
382    pub fn deactivate_many_strict<T: AsRef<str>>(&mut self, names: &[T]) -> LadduResult<()> {
383        for name in names {
384            self.deactivate_strict(name)?;
385        }
386        Ok(())
387    }
388    /// Deactivate all registered [`Amplitude`](crate::amplitudes::Amplitude)s.
389    pub fn deactivate_all(&mut self) {
390        self.active = vec![false; self.active.len()];
391    }
392    /// Isolate an [`Amplitude`](crate::amplitudes::Amplitude) by name (deactivate the rest).
393    pub fn isolate<T: AsRef<str>>(&mut self, name: T) {
394        self.deactivate_all();
395        self.activate(name);
396    }
397    /// Isolate an [`Amplitude`](crate::amplitudes::Amplitude) by name (deactivate the rest), returning an error if it is missing.
398    pub fn isolate_strict<T: AsRef<str>>(&mut self, name: T) -> LadduResult<()> {
399        self.deactivate_all();
400        self.activate_strict(name)
401    }
402    /// Isolate several [`Amplitude`](crate::amplitudes::Amplitude)s by name (deactivate the rest).
403    pub fn isolate_many<T: AsRef<str>>(&mut self, names: &[T]) {
404        self.deactivate_all();
405        self.activate_many(names);
406    }
407    /// Isolate several [`Amplitude`](crate::amplitudes::Amplitude)s by name (deactivate the rest), returning an error if any are missing.
408    pub fn isolate_many_strict<T: AsRef<str>>(&mut self, names: &[T]) -> LadduResult<()> {
409        self.deactivate_all();
410        self.activate_many_strict(names)
411    }
412    /// Register an [`Amplitude`](crate::amplitudes::Amplitude) with the [`Resources`] manager.
413    /// This method should be called at the end of the
414    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method. The
415    /// [`Amplitude`](crate::amplitudes::Amplitude) should probably obtain a name [`String`] in its
416    /// constructor.
417    ///
418    /// # Errors
419    ///
420    /// The [`Amplitude`](crate::amplitudes::Amplitude)'s name must be unique and not already
421    /// registered, else this will return a [`RegistrationError`][LadduError::RegistrationError].
422    pub fn register_amplitude(&mut self, name: &str) -> LadduResult<AmplitudeID> {
423        if self.amplitudes.contains_key(name) {
424            return Err(LadduError::RegistrationError {
425                name: name.to_string(),
426            });
427        }
428        let next_id = AmplitudeID(name.to_string(), self.amplitudes.len());
429        self.amplitudes.insert(name.to_string(), next_id.clone());
430        self.active.push(true);
431        Ok(next_id)
432    }
433
434    /// Fetch the [`AmplitudeID`] for a previously registered amplitude by name.
435    pub fn amplitude_id(&self, name: &str) -> Option<AmplitudeID> {
436        self.amplitudes.get(name).cloned()
437    }
438
439    fn apply_transform(&self, name: &str, fixed: Option<f64>) -> (String, Option<f64>) {
440        let final_name = self
441            .parameter_overrides
442            .renames
443            .get(name)
444            .cloned()
445            .unwrap_or_else(|| name.to_string());
446        let fixed_value = if let Some(value) = self.parameter_overrides.fixed.get(name) {
447            Some(*value)
448        } else if self.parameter_overrides.freed.contains(name) {
449            None
450        } else {
451            fixed
452        };
453        (final_name, fixed_value)
454    }
455
456    /// Register a parameter. This method should be called within
457    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register). The resulting
458    /// [`ParameterID`] should be stored to retrieve the value from the [`Parameters`] wrapper.
459    ///
460    /// # Errors
461    ///
462    /// Returns an error if the parameter is unnamed, if the name is reused with incompatible
463    /// fixed/free status or fixed value, or if renaming causes a conflict.
464    pub fn register_parameter(&mut self, p: &ParameterLike) -> LadduResult<ParameterID> {
465        let base_name = p.name();
466        if base_name.is_empty() {
467            return Err(LadduError::UnregisteredParameter {
468                name: "<unnamed>".to_string(),
469                reason: "Parameter was not initialized with a name".to_string(),
470            });
471        }
472        let (final_name, fixed_value) = self.apply_transform(base_name, p.fixed);
473
474        if let Some(existing) = self.parameter_entries.get(&final_name) {
475            match (existing.fixed, fixed_value) {
476                (Some(a), Some(b)) if (a - b).abs() > f64::EPSILON => {
477                    return Err(LadduError::ParameterConflict {
478                        name: final_name,
479                        reason: "conflicting fixed values for the same parameter name".to_string(),
480                    })
481                }
482                (Some(_), None) => {
483                    return Err(LadduError::ParameterConflict {
484                        name: final_name,
485                        reason: "attempted to use a fixed parameter name as free".to_string(),
486                    })
487                }
488                (None, Some(_)) => {
489                    return Err(LadduError::ParameterConflict {
490                        name: final_name,
491                        reason: "attempted to use a free parameter name as fixed".to_string(),
492                    })
493                }
494                _ => return Ok(existing.id),
495            }
496        }
497
498        let entry = if let Some(value) = fixed_value {
499            self.fixed_parameters.insert(final_name.clone());
500            self.constants.push(value);
501            ParameterEntry {
502                id: ParameterID::Constant(self.constants.len() - 1),
503                fixed: Some(value),
504            }
505        } else {
506            let (index, _) = self.free_parameters.insert_full(final_name.clone());
507            ParameterEntry {
508                id: ParameterID::Parameter(index),
509                fixed: None,
510            }
511        };
512        self.parameter_entries.insert(final_name, entry.clone());
513        Ok(entry.id)
514    }
515    pub(crate) fn reserve_cache(&mut self, num_events: usize) {
516        self.caches = vec![Cache::new(self.cache_size); num_events]
517    }
518    /// Register a scalar with an optional name (names are unique to the [`Cache`] so two different
519    /// registrations of the same type which share a name will also share values and may overwrite
520    /// each other). This method should be called within the
521    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
522    /// resulting [`ScalarID`] should be stored to use later to retrieve the value from the [`Cache`].
523    pub fn register_scalar(&mut self, name: Option<&str>) -> ScalarID {
524        let first_index = if let Some(name) = name {
525            *self
526                .scalar_cache_names
527                .entry(name.to_string())
528                .or_insert_with(|| {
529                    self.cache_size += 1;
530                    self.cache_size - 1
531                })
532        } else {
533            self.cache_size += 1;
534            self.cache_size - 1
535        };
536        ScalarID(first_index)
537    }
538    /// Register a complex scalar with an optional name (names are unique to the [`Cache`] so two different
539    /// registrations of the same type which share a name will also share values and may overwrite
540    /// each other). This method should be called within the
541    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
542    /// resulting [`ComplexScalarID`] should be stored to use later to retrieve the value from the [`Cache`].
543    pub fn register_complex_scalar(&mut self, name: Option<&str>) -> ComplexScalarID {
544        let first_index = if let Some(name) = name {
545            *self
546                .complex_scalar_cache_names
547                .entry(name.to_string())
548                .or_insert_with(|| {
549                    self.cache_size += 2;
550                    self.cache_size - 2
551                })
552        } else {
553            self.cache_size += 2;
554            self.cache_size - 2
555        };
556        ComplexScalarID(first_index, first_index + 1)
557    }
558    /// Register a vector with an optional name (names are unique to the [`Cache`] so two different
559    /// registrations of the same type which share a name will also share values and may overwrite
560    /// each other). This method should be called within the
561    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
562    /// resulting [`VectorID`] should be stored to use later to retrieve the value from the [`Cache`].
563    pub fn register_vector<const R: usize>(&mut self, name: Option<&str>) -> VectorID<R> {
564        let first_index = if let Some(name) = name {
565            *self
566                .vector_cache_names
567                .entry(name.to_string())
568                .or_insert_with(|| {
569                    self.cache_size += R;
570                    self.cache_size - R
571                })
572        } else {
573            self.cache_size += R;
574            self.cache_size - R
575        };
576        VectorID(array::from_fn(|i| first_index + i))
577    }
578    /// Register a complex-valued vector with an optional name (names are unique to the [`Cache`] so two different
579    /// registrations of the same type which share a name will also share values and may overwrite
580    /// each other). This method should be called within the
581    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
582    /// resulting [`ComplexVectorID`] should be stored to use later to retrieve the value from the [`Cache`].
583    pub fn register_complex_vector<const R: usize>(
584        &mut self,
585        name: Option<&str>,
586    ) -> ComplexVectorID<R> {
587        let first_index = if let Some(name) = name {
588            *self
589                .complex_vector_cache_names
590                .entry(name.to_string())
591                .or_insert_with(|| {
592                    self.cache_size += R * 2;
593                    self.cache_size - (R * 2)
594                })
595        } else {
596            self.cache_size += R * 2;
597            self.cache_size - (R * 2)
598        };
599        ComplexVectorID(
600            array::from_fn(|i| first_index + i),
601            array::from_fn(|i| (first_index + R) + i),
602        )
603    }
604    /// Register a matrix with an optional name (names are unique to the [`Cache`] so two different
605    /// registrations of the same type which share a name will also share values and may overwrite
606    /// each other). This method should be called within the
607    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
608    /// resulting [`MatrixID`] should be stored to use later to retrieve the value from the [`Cache`].
609    pub fn register_matrix<const R: usize, const C: usize>(
610        &mut self,
611        name: Option<&str>,
612    ) -> MatrixID<R, C> {
613        let first_index = if let Some(name) = name {
614            *self
615                .matrix_cache_names
616                .entry(name.to_string())
617                .or_insert_with(|| {
618                    self.cache_size += R * C;
619                    self.cache_size - (R * C)
620                })
621        } else {
622            self.cache_size += R * C;
623            self.cache_size - (R * C)
624        };
625        MatrixID(array::from_fn(|i| {
626            array::from_fn(|j| first_index + i * C + j)
627        }))
628    }
629    /// Register a complex-valued matrix with an optional name (names are unique to the [`Cache`] so two different
630    /// registrations of the same type which share a name will also share values and may overwrite
631    /// each other). This method should be called within the
632    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
633    /// resulting [`ComplexMatrixID`] should be stored to use later to retrieve the value from the [`Cache`].
634    pub fn register_complex_matrix<const R: usize, const C: usize>(
635        &mut self,
636        name: Option<&str>,
637    ) -> ComplexMatrixID<R, C> {
638        let first_index = if let Some(name) = name {
639            *self
640                .complex_matrix_cache_names
641                .entry(name.to_string())
642                .or_insert_with(|| {
643                    self.cache_size += 2 * R * C;
644                    self.cache_size - (2 * R * C)
645                })
646        } else {
647            self.cache_size += 2 * R * C;
648            self.cache_size - (2 * R * C)
649        };
650        ComplexMatrixID(
651            array::from_fn(|i| array::from_fn(|j| first_index + i * C + j)),
652            array::from_fn(|i| array::from_fn(|j| (first_index + R * C) + i * C + j)),
653        )
654    }
655}
656
657#[cfg(test)]
658mod tests {
659    use super::*;
660    use nalgebra::{Matrix2, Vector2};
661    use num::complex::Complex64;
662
663    #[test]
664    fn test_parameters() {
665        let parameters = vec![1.0, 2.0, 3.0];
666        let constants = vec![4.0, 5.0, 6.0];
667        let params = Parameters::new(&parameters, &constants);
668
669        assert_eq!(params.get(ParameterID::Parameter(0)), 1.0);
670        assert_eq!(params.get(ParameterID::Parameter(1)), 2.0);
671        assert_eq!(params.get(ParameterID::Parameter(2)), 3.0);
672        assert_eq!(params.get(ParameterID::Constant(0)), 4.0);
673        assert_eq!(params.get(ParameterID::Constant(1)), 5.0);
674        assert_eq!(params.get(ParameterID::Constant(2)), 6.0);
675        assert_eq!(params.len(), 3);
676    }
677
678    #[test]
679    #[should_panic(expected = "Parameter has not been registered!")]
680    fn test_uninit_parameter() {
681        let parameters = vec![1.0];
682        let constants = vec![1.0];
683        let params = Parameters::new(&parameters, &constants);
684        params.get(ParameterID::Uninit);
685    }
686
687    #[test]
688    fn test_resources_amplitude_management() {
689        let mut resources = Resources::default();
690
691        let amp1 = resources.register_amplitude("amp1").unwrap();
692        let amp2 = resources.register_amplitude("amp2").unwrap();
693
694        assert!(resources.active[amp1.1]);
695        assert!(resources.active[amp2.1]);
696
697        resources.deactivate_strict("amp1").unwrap();
698        assert!(!resources.active[amp1.1]);
699        assert!(resources.active[amp2.1]);
700
701        resources.activate_strict("amp1").unwrap();
702        assert!(resources.active[amp1.1]);
703
704        resources.deactivate_all();
705        assert!(!resources.active[amp1.1]);
706        assert!(!resources.active[amp2.1]);
707
708        resources.activate_all();
709        assert!(resources.active[amp1.1]);
710        assert!(resources.active[amp2.1]);
711
712        resources.isolate_strict("amp1").unwrap();
713        assert!(resources.active[amp1.1]);
714        assert!(!resources.active[amp2.1]);
715    }
716
717    #[test]
718    fn test_resources_parameter_registration() {
719        let mut resources = Resources::default();
720
721        let param1 = resources
722            .register_parameter(&ParameterLike::free("param1"))
723            .unwrap();
724        let const1 = resources
725            .register_parameter(&ParameterLike::fixed("const1", 1.0))
726            .unwrap();
727
728        match param1 {
729            ParameterID::Parameter(idx) => assert_eq!(idx, 0),
730            _ => panic!("Expected Parameter variant"),
731        }
732
733        match const1 {
734            ParameterID::Constant(idx) => assert_eq!(idx, 0),
735            _ => panic!("Expected Constant variant"),
736        }
737    }
738
739    #[test]
740    fn test_cache_scalar_operations() {
741        let mut resources = Resources::default();
742
743        let scalar1 = resources.register_scalar(Some("test_scalar"));
744        let scalar2 = resources.register_scalar(None);
745        let scalar3 = resources.register_scalar(Some("test_scalar"));
746
747        resources.reserve_cache(1);
748        let cache = &mut resources.caches[0];
749
750        cache.store_scalar(scalar1, 1.0);
751        cache.store_scalar(scalar2, 2.0);
752
753        assert_eq!(cache.get_scalar(scalar1), 1.0);
754        assert_eq!(cache.get_scalar(scalar2), 2.0);
755        assert_eq!(cache.get_scalar(scalar3), 1.0);
756    }
757
758    #[test]
759    fn test_cache_complex_operations() {
760        let mut resources = Resources::default();
761
762        let complex1 = resources.register_complex_scalar(Some("test_complex"));
763        let complex2 = resources.register_complex_scalar(None);
764        let complex3 = resources.register_complex_scalar(Some("test_complex"));
765
766        resources.reserve_cache(1);
767        let cache = &mut resources.caches[0];
768
769        let value1 = Complex64::new(1.0, 2.0);
770        let value2 = Complex64::new(3.0, 4.0);
771        cache.store_complex_scalar(complex1, value1);
772        cache.store_complex_scalar(complex2, value2);
773
774        assert_eq!(cache.get_complex_scalar(complex1), value1);
775        assert_eq!(cache.get_complex_scalar(complex2), value2);
776        assert_eq!(cache.get_complex_scalar(complex3), value1);
777    }
778
779    #[test]
780    fn test_cache_vector_operations() {
781        let mut resources = Resources::default();
782
783        let vector_id1: VectorID<2> = resources.register_vector(Some("test_vector"));
784        let vector_id2: VectorID<2> = resources.register_vector(None);
785        let vector_id3: VectorID<2> = resources.register_vector(Some("test_vector"));
786
787        resources.reserve_cache(1);
788        let cache = &mut resources.caches[0];
789
790        let value1 = Vector2::new(1.0, 2.0);
791        let value2 = Vector2::new(3.0, 4.0);
792        cache.store_vector(vector_id1, value1);
793        cache.store_vector(vector_id2, value2);
794
795        assert_eq!(cache.get_vector(vector_id1), value1);
796        assert_eq!(cache.get_vector(vector_id2), value2);
797        assert_eq!(cache.get_vector(vector_id3), value1);
798    }
799
800    #[test]
801    fn test_cache_complex_vector_operations() {
802        let mut resources = Resources::default();
803
804        let complex_vector_id1: ComplexVectorID<2> =
805            resources.register_complex_vector(Some("test_complex_vector"));
806        let complex_vector_id2: ComplexVectorID<2> = resources.register_complex_vector(None);
807        let complex_vector_id3: ComplexVectorID<2> =
808            resources.register_complex_vector(Some("test_complex_vector"));
809
810        resources.reserve_cache(1);
811        let cache = &mut resources.caches[0];
812
813        let value1 = Vector2::new(Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0));
814        let value2 = Vector2::new(Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0));
815        cache.store_complex_vector(complex_vector_id1, value1);
816        cache.store_complex_vector(complex_vector_id2, value2);
817
818        assert_eq!(cache.get_complex_vector(complex_vector_id1), value1);
819        assert_eq!(cache.get_complex_vector(complex_vector_id2), value2);
820        assert_eq!(cache.get_complex_vector(complex_vector_id3), value1);
821    }
822
823    #[test]
824    fn test_cache_matrix_operations() {
825        let mut resources = Resources::default();
826
827        let matrix_id1: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
828        let matrix_id2: MatrixID<2, 2> = resources.register_matrix(None);
829        let matrix_id3: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
830
831        resources.reserve_cache(1);
832        let cache = &mut resources.caches[0];
833
834        let value1 = Matrix2::new(1.0, 2.0, 3.0, 4.0);
835        let value2 = Matrix2::new(5.0, 6.0, 7.0, 8.0);
836        cache.store_matrix(matrix_id1, value1);
837        cache.store_matrix(matrix_id2, value2);
838
839        assert_eq!(cache.get_matrix(matrix_id1), value1);
840        assert_eq!(cache.get_matrix(matrix_id2), value2);
841        assert_eq!(cache.get_matrix(matrix_id3), value1);
842    }
843
844    #[test]
845    fn test_cache_complex_matrix_operations() {
846        let mut resources = Resources::default();
847
848        let complex_matrix_id1: ComplexMatrixID<2, 2> =
849            resources.register_complex_matrix(Some("test_complex_matrix"));
850        let complex_matrix_id2: ComplexMatrixID<2, 2> = resources.register_complex_matrix(None);
851        let complex_matrix_id3: ComplexMatrixID<2, 2> =
852            resources.register_complex_matrix(Some("test_complex_matrix"));
853
854        resources.reserve_cache(1);
855        let cache = &mut resources.caches[0];
856
857        let value1 = Matrix2::new(
858            Complex64::new(1.0, 2.0),
859            Complex64::new(3.0, 4.0),
860            Complex64::new(5.0, 6.0),
861            Complex64::new(7.0, 8.0),
862        );
863        let value2 = Matrix2::new(
864            Complex64::new(9.0, 10.0),
865            Complex64::new(11.0, 12.0),
866            Complex64::new(13.0, 14.0),
867            Complex64::new(15.0, 16.0),
868        );
869        cache.store_complex_matrix(complex_matrix_id1, value1);
870        cache.store_complex_matrix(complex_matrix_id2, value2);
871
872        assert_eq!(cache.get_complex_matrix(complex_matrix_id1), value1);
873        assert_eq!(cache.get_complex_matrix(complex_matrix_id2), value2);
874        assert_eq!(cache.get_complex_matrix(complex_matrix_id3), value1);
875    }
876
877    #[test]
878    fn test_uninit_parameter_registration() {
879        let mut resources = Resources::default();
880        let result = resources.register_parameter(&ParameterLike::uninit());
881        assert!(result.is_err());
882    }
883
884    #[test]
885    fn test_duplicate_named_amplitude_registration_error() {
886        let mut resources = Resources::default();
887        assert!(resources.register_amplitude("test_amp").is_ok());
888        assert!(resources.register_amplitude("test_amp").is_err());
889    }
890}