Skip to main content

laddu_core/
resources.rs

1use std::{array, collections::HashMap};
2
3use indexmap::IndexSet;
4use nalgebra::{SMatrix, SVector};
5use num::complex::Complex64;
6use serde::{Deserialize, Serialize};
7use serde_with::serde_as;
8
9use crate::{
10    amplitudes::{AmplitudeID, ParameterLike},
11    parameter_manager::ParameterTransform,
12    LadduError, LadduResult,
13};
14
15/// This struct holds references to the constants and free parameters used in the fit so that they
16/// may be obtained from their corresponding [`ParameterID`].
17#[derive(Debug)]
18pub struct Parameters<'a> {
19    pub(crate) parameters: &'a [f64],
20    pub(crate) constants: &'a [f64],
21}
22
23impl<'a> Parameters<'a> {
24    /// Create a new set of [`Parameters`] from a list of floating values and a list of constant values
25    pub fn new(parameters: &'a [f64], constants: &'a [f64]) -> Self {
26        Self {
27            parameters,
28            constants,
29        }
30    }
31
32    /// Obtain a parameter value or constant value from the given [`ParameterID`].
33    pub fn get(&self, pid: ParameterID) -> f64 {
34        match pid {
35            ParameterID::Parameter(index) => self.parameters[index],
36            ParameterID::Constant(index) => self.constants[index],
37            ParameterID::Uninit => panic!("Parameter has not been registered!"),
38        }
39    }
40
41    /// The number of free parameters.
42    #[allow(clippy::len_without_is_empty)]
43    pub fn len(&self) -> usize {
44        self.parameters.len()
45    }
46}
47
48/// The main resource manager for cached values, amplitudes, parameters, and constants.
49#[derive(Default, Debug, Clone, Serialize, Deserialize)]
50pub struct Resources {
51    amplitudes: HashMap<String, AmplitudeID>,
52    /// A list indicating which amplitudes are active (using [`AmplitudeID`]s as indices)
53    pub active: Vec<bool>,
54    #[serde(default)]
55    active_indices: Vec<usize>,
56    /// The set of all registered free parameter names across registered [`Amplitude`]s
57    pub free_parameters: IndexSet<String>,
58    /// The set of all registered fixed parameter names across registered [`Amplitude`]s
59    pub fixed_parameters: IndexSet<String>,
60    /// Values of all constants/fixed parameters across registered [`Amplitude`]s
61    pub constants: Vec<f64>,
62    /// The [`Cache`] for each [`EventData`](`crate::data::EventData`)
63    pub caches: Vec<Cache>,
64    scalar_cache_names: HashMap<String, usize>,
65    complex_scalar_cache_names: HashMap<String, usize>,
66    vector_cache_names: HashMap<String, usize>,
67    complex_vector_cache_names: HashMap<String, usize>,
68    matrix_cache_names: HashMap<String, usize>,
69    complex_matrix_cache_names: HashMap<String, usize>,
70    cache_size: usize,
71    parameter_entries: HashMap<String, ParameterEntry>,
72    pub(crate) parameter_overrides: ParameterTransform,
73}
74
75/// A single cache entry corresponding to precomputed data for a particular
76/// [`EventData`](crate::data::EventData) in a [`Dataset`](crate::data::Dataset).
77#[derive(Clone, Debug, Serialize, Deserialize)]
78pub struct Cache(Vec<f64>);
79impl Cache {
80    fn new(cache_size: usize) -> Self {
81        Self(vec![0.0; cache_size])
82    }
83    /// Store a scalar value with the corresponding [`ScalarID`].
84    pub fn store_scalar(&mut self, sid: ScalarID, value: f64) {
85        self.0[sid.0] = value;
86    }
87    /// Store a complex scalar value with the corresponding [`ComplexScalarID`].
88    pub fn store_complex_scalar(&mut self, csid: ComplexScalarID, value: Complex64) {
89        self.0[csid.0] = value.re;
90        self.0[csid.1] = value.im;
91    }
92    /// Store a vector with the corresponding [`VectorID`].
93    pub fn store_vector<const R: usize>(&mut self, vid: VectorID<R>, value: SVector<f64, R>) {
94        vid.0
95            .into_iter()
96            .enumerate()
97            .for_each(|(vi, i)| self.0[i] = value[vi]);
98    }
99    /// Store a complex-valued vector with the corresponding [`ComplexVectorID`].
100    pub fn store_complex_vector<const R: usize>(
101        &mut self,
102        cvid: ComplexVectorID<R>,
103        value: SVector<Complex64, R>,
104    ) {
105        cvid.0
106            .into_iter()
107            .enumerate()
108            .for_each(|(vi, i)| self.0[i] = value[vi].re);
109        cvid.1
110            .into_iter()
111            .enumerate()
112            .for_each(|(vi, i)| self.0[i] = value[vi].im);
113    }
114    /// Store a matrix with the corresponding [`MatrixID`].
115    pub fn store_matrix<const R: usize, const C: usize>(
116        &mut self,
117        mid: MatrixID<R, C>,
118        value: SMatrix<f64, R, C>,
119    ) {
120        mid.0.into_iter().enumerate().for_each(|(vi, row)| {
121            row.into_iter()
122                .enumerate()
123                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)])
124        });
125    }
126    /// Store a complex-valued matrix with the corresponding [`ComplexMatrixID`].
127    pub fn store_complex_matrix<const R: usize, const C: usize>(
128        &mut self,
129        cmid: ComplexMatrixID<R, C>,
130        value: SMatrix<Complex64, R, C>,
131    ) {
132        cmid.0.into_iter().enumerate().for_each(|(vi, row)| {
133            row.into_iter()
134                .enumerate()
135                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].re)
136        });
137        cmid.1.into_iter().enumerate().for_each(|(vi, row)| {
138            row.into_iter()
139                .enumerate()
140                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].im)
141        });
142    }
143    /// Retrieve a scalar value from the [`Cache`].
144    pub fn get_scalar(&self, sid: ScalarID) -> f64 {
145        self.0[sid.0]
146    }
147    /// Retrieve a complex scalar value from the [`Cache`].
148    pub fn get_complex_scalar(&self, csid: ComplexScalarID) -> Complex64 {
149        Complex64::new(self.0[csid.0], self.0[csid.1])
150    }
151    /// Retrieve a vector from the [`Cache`].
152    pub fn get_vector<const R: usize>(&self, vid: VectorID<R>) -> SVector<f64, R> {
153        SVector::from_fn(|i, _| self.0[vid.0[i]])
154    }
155    /// Retrieve a complex-valued vector from the [`Cache`].
156    pub fn get_complex_vector<const R: usize>(
157        &self,
158        cvid: ComplexVectorID<R>,
159    ) -> SVector<Complex64, R> {
160        SVector::from_fn(|i, _| Complex64::new(self.0[cvid.0[i]], self.0[cvid.1[i]]))
161    }
162    /// Retrieve a matrix from the [`Cache`].
163    pub fn get_matrix<const R: usize, const C: usize>(
164        &self,
165        mid: MatrixID<R, C>,
166    ) -> SMatrix<f64, R, C> {
167        SMatrix::from_fn(|i, j| self.0[mid.0[i][j]])
168    }
169    /// Retrieve a complex-valued matrix from the [`Cache`].
170    pub fn get_complex_matrix<const R: usize, const C: usize>(
171        &self,
172        cmid: ComplexMatrixID<R, C>,
173    ) -> SMatrix<Complex64, R, C> {
174        SMatrix::from_fn(|i, j| Complex64::new(self.0[cmid.0[i][j]], self.0[cmid.1[i][j]]))
175    }
176}
177
178/// An object which acts as a tag to refer to either a free parameter or a constant value.
179#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize)]
180pub enum ParameterID {
181    /// A free parameter.
182    Parameter(usize),
183    /// A constant value.
184    Constant(usize),
185    /// An uninitialized ID
186    #[default]
187    Uninit,
188}
189
190/// A tag for retrieving or storing a scalar value in a [`Cache`].
191#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
192pub struct ScalarID(usize);
193
194/// A tag for retrieving or storing a complex scalar value in a [`Cache`].
195#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
196pub struct ComplexScalarID(usize, usize);
197
198/// A tag for retrieving or storing a vector in a [`Cache`].
199#[serde_as]
200#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
201pub struct VectorID<const R: usize>(#[serde_as(as = "[_; R]")] [usize; R]);
202
203impl<const R: usize> Default for VectorID<R> {
204    fn default() -> Self {
205        Self([0; R])
206    }
207}
208
209/// A tag for retrieving or storing a complex-valued vector in a [`Cache`].
210#[serde_as]
211#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
212pub struct ComplexVectorID<const R: usize>(
213    #[serde_as(as = "[_; R]")] [usize; R],
214    #[serde_as(as = "[_; R]")] [usize; R],
215);
216
217impl<const R: usize> Default for ComplexVectorID<R> {
218    fn default() -> Self {
219        Self([0; R], [0; R])
220    }
221}
222
223/// A tag for retrieving or storing a matrix in a [`Cache`].
224#[serde_as]
225#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
226pub struct MatrixID<const R: usize, const C: usize>(
227    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
228);
229
230impl<const R: usize, const C: usize> Default for MatrixID<R, C> {
231    fn default() -> Self {
232        Self([[0; C]; R])
233    }
234}
235
236/// A tag for retrieving or storing a complex-valued matrix in a [`Cache`].
237#[serde_as]
238#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
239pub struct ComplexMatrixID<const R: usize, const C: usize>(
240    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
241    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
242);
243
244impl<const R: usize, const C: usize> Default for ComplexMatrixID<R, C> {
245    fn default() -> Self {
246        Self([[0; C]; R], [[0; C]; R])
247    }
248}
249
250#[derive(Clone, Debug, Serialize, Deserialize)]
251struct ParameterEntry {
252    id: ParameterID,
253    fixed: Option<f64>,
254}
255
256impl Resources {
257    /// Create a new [`Resources`] instance with a parameter transform applied.
258    pub fn with_transform(transform: ParameterTransform) -> Self {
259        Self {
260            parameter_overrides: transform,
261            ..Default::default()
262        }
263    }
264
265    /// The list of free parameter names.
266    pub fn free_parameter_names(&self) -> Vec<String> {
267        self.free_parameters.iter().cloned().collect()
268    }
269
270    /// The list of fixed parameter names.
271    pub fn fixed_parameter_names(&self) -> Vec<String> {
272        self.fixed_parameters.iter().cloned().collect()
273    }
274
275    /// Map from fixed parameter names to their values.
276    pub fn fixed_parameter_values(&self) -> HashMap<String, f64> {
277        self.parameter_entries
278            .iter()
279            .filter_map(|(name, entry)| entry.fixed.map(|value| (name.clone(), value)))
280            .collect()
281    }
282
283    /// All parameter names (free first, then fixed).
284    pub fn parameter_names(&self) -> Vec<String> {
285        self.free_parameter_names()
286            .into_iter()
287            .chain(self.fixed_parameter_names())
288            .collect()
289    }
290
291    /// Number of free parameters.
292    pub fn n_free_parameters(&self) -> usize {
293        self.free_parameters.len()
294    }
295
296    /// Number of fixed parameters.
297    pub fn n_fixed_parameters(&self) -> usize {
298        self.fixed_parameters.len()
299    }
300
301    /// Total number of parameters.
302    pub fn n_parameters(&self) -> usize {
303        self.n_free_parameters() + self.n_fixed_parameters()
304    }
305
306    fn rebuild_active_indices(&mut self) {
307        self.active_indices.clear();
308        self.active_indices.extend(
309            self.active
310                .iter()
311                .enumerate()
312                .filter_map(|(idx, &is_active)| if is_active { Some(idx) } else { None }),
313        );
314    }
315
316    pub(crate) fn refresh_active_indices(&mut self) {
317        self.rebuild_active_indices();
318    }
319
320    /// Return the indices of active amplitudes.
321    pub fn active_indices(&self) -> &[usize] {
322        &self.active_indices
323    }
324
325    #[inline]
326    fn set_activation_state(&mut self, name: &str, active: bool) -> Option<bool> {
327        self.amplitudes.get(name).map(|amplitude| {
328            let idx = amplitude.1;
329            let changed = self.active[idx] != active;
330            self.active[idx] = active;
331            changed
332        })
333    }
334    /// Activate an [`Amplitude`](crate::amplitudes::Amplitude) by name.
335    pub fn activate<T: AsRef<str>>(&mut self, name: T) {
336        if self
337            .set_activation_state(name.as_ref(), true)
338            .unwrap_or(false)
339        {
340            self.rebuild_active_indices();
341        }
342    }
343    /// Activate several [`Amplitude`](crate::amplitudes::Amplitude)s by name.
344    pub fn activate_many<T: AsRef<str>>(&mut self, names: &[T]) {
345        let mut changed = false;
346        for name in names {
347            if self
348                .set_activation_state(name.as_ref(), true)
349                .unwrap_or(false)
350            {
351                changed = true;
352            }
353        }
354        if changed {
355            self.rebuild_active_indices();
356        }
357    }
358    /// Activate an [`Amplitude`](crate::amplitudes::Amplitude) by name, returning an error if it is missing.
359    pub fn activate_strict<T: AsRef<str>>(&mut self, name: T) -> LadduResult<()> {
360        let name_ref = name.as_ref();
361        match self.set_activation_state(name_ref, true) {
362            Some(changed) => {
363                if changed {
364                    self.rebuild_active_indices();
365                }
366                Ok(())
367            }
368            None => Err(LadduError::AmplitudeNotFoundError {
369                name: name_ref.to_string(),
370            }),
371        }
372    }
373    /// Activate several [`Amplitude`](crate::amplitudes::Amplitude)s by name, returning an error if any are missing.
374    pub fn activate_many_strict<T: AsRef<str>>(&mut self, names: &[T]) -> LadduResult<()> {
375        let mut changed = false;
376        for name in names {
377            let name_ref = name.as_ref();
378            match self.set_activation_state(name_ref, true) {
379                Some(state_changed) => {
380                    if state_changed {
381                        changed = true;
382                    }
383                }
384                None => {
385                    return Err(LadduError::AmplitudeNotFoundError {
386                        name: name_ref.to_string(),
387                    })
388                }
389            }
390        }
391        if changed {
392            self.rebuild_active_indices();
393        }
394        Ok(())
395    }
396    /// Activate all registered [`Amplitude`](crate::amplitudes::Amplitude)s.
397    pub fn activate_all(&mut self) {
398        let mut changed = false;
399        for active in self.active.iter_mut() {
400            if !*active {
401                *active = true;
402                changed = true;
403            }
404        }
405        if changed {
406            self.rebuild_active_indices();
407        }
408    }
409    /// Deactivate an [`Amplitude`](crate::amplitudes::Amplitude) by name.
410    pub fn deactivate<T: AsRef<str>>(&mut self, name: T) {
411        if self
412            .set_activation_state(name.as_ref(), false)
413            .unwrap_or(false)
414        {
415            self.rebuild_active_indices();
416        }
417    }
418    /// Deactivate several [`Amplitude`](crate::amplitudes::Amplitude)s by name.
419    pub fn deactivate_many<T: AsRef<str>>(&mut self, names: &[T]) {
420        let mut changed = false;
421        for name in names {
422            if self
423                .set_activation_state(name.as_ref(), false)
424                .unwrap_or(false)
425            {
426                changed = true;
427            }
428        }
429        if changed {
430            self.rebuild_active_indices();
431        }
432    }
433    /// Deactivate an [`Amplitude`](crate::amplitudes::Amplitude) by name, returning an error if it is missing.
434    pub fn deactivate_strict<T: AsRef<str>>(&mut self, name: T) -> LadduResult<()> {
435        let name_ref = name.as_ref();
436        match self.set_activation_state(name_ref, false) {
437            Some(changed) => {
438                if changed {
439                    self.rebuild_active_indices();
440                }
441                Ok(())
442            }
443            None => Err(LadduError::AmplitudeNotFoundError {
444                name: name_ref.to_string(),
445            }),
446        }
447    }
448    /// Deactivate several [`Amplitude`](crate::amplitudes::Amplitude)s by name, returning an error if any are missing.
449    pub fn deactivate_many_strict<T: AsRef<str>>(&mut self, names: &[T]) -> LadduResult<()> {
450        let mut changed = false;
451        for name in names {
452            let name_ref = name.as_ref();
453            match self.set_activation_state(name_ref, false) {
454                Some(state_changed) => {
455                    if state_changed {
456                        changed = true;
457                    }
458                }
459                None => {
460                    return Err(LadduError::AmplitudeNotFoundError {
461                        name: name_ref.to_string(),
462                    })
463                }
464            }
465        }
466        if changed {
467            self.rebuild_active_indices();
468        }
469        Ok(())
470    }
471    /// Deactivate all registered [`Amplitude`](crate::amplitudes::Amplitude)s.
472    pub fn deactivate_all(&mut self) {
473        let mut changed = false;
474        for active in self.active.iter_mut() {
475            if *active {
476                *active = false;
477                changed = true;
478            }
479        }
480        if changed {
481            self.rebuild_active_indices();
482        }
483    }
484    /// Isolate an [`Amplitude`](crate::amplitudes::Amplitude) by name (deactivate the rest).
485    pub fn isolate<T: AsRef<str>>(&mut self, name: T) {
486        self.deactivate_all();
487        self.activate(name);
488    }
489    /// Isolate an [`Amplitude`](crate::amplitudes::Amplitude) by name (deactivate the rest), returning an error if it is missing.
490    pub fn isolate_strict<T: AsRef<str>>(&mut self, name: T) -> LadduResult<()> {
491        self.deactivate_all();
492        self.activate_strict(name)
493    }
494    /// Isolate several [`Amplitude`](crate::amplitudes::Amplitude)s by name (deactivate the rest).
495    pub fn isolate_many<T: AsRef<str>>(&mut self, names: &[T]) {
496        self.deactivate_all();
497        self.activate_many(names);
498    }
499    /// Isolate several [`Amplitude`](crate::amplitudes::Amplitude)s by name (deactivate the rest), returning an error if any are missing.
500    pub fn isolate_many_strict<T: AsRef<str>>(&mut self, names: &[T]) -> LadduResult<()> {
501        self.deactivate_all();
502        self.activate_many_strict(names)
503    }
504    /// Register an [`Amplitude`](crate::amplitudes::Amplitude) with the [`Resources`] manager.
505    /// This method should be called at the end of the
506    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method. The
507    /// [`Amplitude`](crate::amplitudes::Amplitude) should probably obtain a name [`String`] in its
508    /// constructor.
509    ///
510    /// # Errors
511    ///
512    /// The [`Amplitude`](crate::amplitudes::Amplitude)'s name must be unique and not already
513    /// registered, else this will return a [`RegistrationError`][LadduError::RegistrationError].
514    pub fn register_amplitude(&mut self, name: &str) -> LadduResult<AmplitudeID> {
515        if self.amplitudes.contains_key(name) {
516            return Err(LadduError::RegistrationError {
517                name: name.to_string(),
518            });
519        }
520        let next_id = AmplitudeID(name.to_string(), self.amplitudes.len());
521        self.amplitudes.insert(name.to_string(), next_id.clone());
522        self.active.push(true);
523        self.rebuild_active_indices();
524        Ok(next_id)
525    }
526
527    /// Fetch the [`AmplitudeID`] for a previously registered amplitude by name.
528    pub fn amplitude_id(&self, name: &str) -> Option<AmplitudeID> {
529        self.amplitudes.get(name).cloned()
530    }
531
532    fn apply_transform(&self, name: &str, fixed: Option<f64>) -> (String, Option<f64>) {
533        let final_name = self
534            .parameter_overrides
535            .renames
536            .get(name)
537            .cloned()
538            .unwrap_or_else(|| name.to_string());
539        let fixed_value = if let Some(value) = self.parameter_overrides.fixed.get(name) {
540            Some(*value)
541        } else if self.parameter_overrides.freed.contains(name) {
542            None
543        } else {
544            fixed
545        };
546        (final_name, fixed_value)
547    }
548
549    /// Register a parameter. This method should be called within
550    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register). The resulting
551    /// [`ParameterID`] should be stored to retrieve the value from the [`Parameters`] wrapper.
552    ///
553    /// # Errors
554    ///
555    /// Returns an error if the parameter is unnamed, if the name is reused with incompatible
556    /// fixed/free status or fixed value, or if renaming causes a conflict.
557    pub fn register_parameter(&mut self, p: &ParameterLike) -> LadduResult<ParameterID> {
558        let base_name = p.name();
559        if base_name.is_empty() {
560            return Err(LadduError::UnregisteredParameter {
561                name: "<unnamed>".to_string(),
562                reason: "Parameter was not initialized with a name".to_string(),
563            });
564        }
565        let (final_name, fixed_value) = self.apply_transform(base_name, p.fixed);
566
567        if let Some(existing) = self.parameter_entries.get(&final_name) {
568            match (existing.fixed, fixed_value) {
569                (Some(a), Some(b)) if (a - b).abs() > f64::EPSILON => {
570                    return Err(LadduError::ParameterConflict {
571                        name: final_name,
572                        reason: "conflicting fixed values for the same parameter name".to_string(),
573                    })
574                }
575                (Some(_), None) => {
576                    return Err(LadduError::ParameterConflict {
577                        name: final_name,
578                        reason: "attempted to use a fixed parameter name as free".to_string(),
579                    })
580                }
581                (None, Some(_)) => {
582                    return Err(LadduError::ParameterConflict {
583                        name: final_name,
584                        reason: "attempted to use a free parameter name as fixed".to_string(),
585                    })
586                }
587                _ => return Ok(existing.id),
588            }
589        }
590
591        let entry = if let Some(value) = fixed_value {
592            self.fixed_parameters.insert(final_name.clone());
593            self.constants.push(value);
594            ParameterEntry {
595                id: ParameterID::Constant(self.constants.len() - 1),
596                fixed: Some(value),
597            }
598        } else {
599            let (index, _) = self.free_parameters.insert_full(final_name.clone());
600            ParameterEntry {
601                id: ParameterID::Parameter(index),
602                fixed: None,
603            }
604        };
605        self.parameter_entries.insert(final_name, entry.clone());
606        Ok(entry.id)
607    }
608    pub(crate) fn reserve_cache(&mut self, num_events: usize) {
609        self.caches = vec![Cache::new(self.cache_size); num_events]
610    }
611    /// Register a scalar with an optional name (names are unique to the [`Cache`] so two different
612    /// registrations of the same type which share a name will also share values and may overwrite
613    /// each other). This method should be called within the
614    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
615    /// resulting [`ScalarID`] should be stored to use later to retrieve the value from the [`Cache`].
616    pub fn register_scalar(&mut self, name: Option<&str>) -> ScalarID {
617        let first_index = if let Some(name) = name {
618            *self
619                .scalar_cache_names
620                .entry(name.to_string())
621                .or_insert_with(|| {
622                    self.cache_size += 1;
623                    self.cache_size - 1
624                })
625        } else {
626            self.cache_size += 1;
627            self.cache_size - 1
628        };
629        ScalarID(first_index)
630    }
631    /// Register a complex scalar with an optional name (names are unique to the [`Cache`] so two different
632    /// registrations of the same type which share a name will also share values and may overwrite
633    /// each other). This method should be called within the
634    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
635    /// resulting [`ComplexScalarID`] should be stored to use later to retrieve the value from the [`Cache`].
636    pub fn register_complex_scalar(&mut self, name: Option<&str>) -> ComplexScalarID {
637        let first_index = if let Some(name) = name {
638            *self
639                .complex_scalar_cache_names
640                .entry(name.to_string())
641                .or_insert_with(|| {
642                    self.cache_size += 2;
643                    self.cache_size - 2
644                })
645        } else {
646            self.cache_size += 2;
647            self.cache_size - 2
648        };
649        ComplexScalarID(first_index, first_index + 1)
650    }
651    /// Register a vector with an optional name (names are unique to the [`Cache`] so two different
652    /// registrations of the same type which share a name will also share values and may overwrite
653    /// each other). This method should be called within the
654    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
655    /// resulting [`VectorID`] should be stored to use later to retrieve the value from the [`Cache`].
656    pub fn register_vector<const R: usize>(&mut self, name: Option<&str>) -> VectorID<R> {
657        let first_index = if let Some(name) = name {
658            *self
659                .vector_cache_names
660                .entry(name.to_string())
661                .or_insert_with(|| {
662                    self.cache_size += R;
663                    self.cache_size - R
664                })
665        } else {
666            self.cache_size += R;
667            self.cache_size - R
668        };
669        VectorID(array::from_fn(|i| first_index + i))
670    }
671    /// Register a complex-valued vector with an optional name (names are unique to the [`Cache`] so two different
672    /// registrations of the same type which share a name will also share values and may overwrite
673    /// each other). This method should be called within the
674    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
675    /// resulting [`ComplexVectorID`] should be stored to use later to retrieve the value from the [`Cache`].
676    pub fn register_complex_vector<const R: usize>(
677        &mut self,
678        name: Option<&str>,
679    ) -> ComplexVectorID<R> {
680        let first_index = if let Some(name) = name {
681            *self
682                .complex_vector_cache_names
683                .entry(name.to_string())
684                .or_insert_with(|| {
685                    self.cache_size += R * 2;
686                    self.cache_size - (R * 2)
687                })
688        } else {
689            self.cache_size += R * 2;
690            self.cache_size - (R * 2)
691        };
692        ComplexVectorID(
693            array::from_fn(|i| first_index + i),
694            array::from_fn(|i| (first_index + R) + i),
695        )
696    }
697    /// Register a matrix with an optional name (names are unique to the [`Cache`] so two different
698    /// registrations of the same type which share a name will also share values and may overwrite
699    /// each other). This method should be called within the
700    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
701    /// resulting [`MatrixID`] should be stored to use later to retrieve the value from the [`Cache`].
702    pub fn register_matrix<const R: usize, const C: usize>(
703        &mut self,
704        name: Option<&str>,
705    ) -> MatrixID<R, C> {
706        let first_index = if let Some(name) = name {
707            *self
708                .matrix_cache_names
709                .entry(name.to_string())
710                .or_insert_with(|| {
711                    self.cache_size += R * C;
712                    self.cache_size - (R * C)
713                })
714        } else {
715            self.cache_size += R * C;
716            self.cache_size - (R * C)
717        };
718        MatrixID(array::from_fn(|i| {
719            array::from_fn(|j| first_index + i * C + j)
720        }))
721    }
722    /// Register a complex-valued matrix with an optional name (names are unique to the [`Cache`] so two different
723    /// registrations of the same type which share a name will also share values and may overwrite
724    /// each other). This method should be called within the
725    /// [`Amplitude::register`](crate::amplitudes::Amplitude::register) method, and the
726    /// resulting [`ComplexMatrixID`] should be stored to use later to retrieve the value from the [`Cache`].
727    pub fn register_complex_matrix<const R: usize, const C: usize>(
728        &mut self,
729        name: Option<&str>,
730    ) -> ComplexMatrixID<R, C> {
731        let first_index = if let Some(name) = name {
732            *self
733                .complex_matrix_cache_names
734                .entry(name.to_string())
735                .or_insert_with(|| {
736                    self.cache_size += 2 * R * C;
737                    self.cache_size - (2 * R * C)
738                })
739        } else {
740            self.cache_size += 2 * R * C;
741            self.cache_size - (2 * R * C)
742        };
743        ComplexMatrixID(
744            array::from_fn(|i| array::from_fn(|j| first_index + i * C + j)),
745            array::from_fn(|i| array::from_fn(|j| (first_index + R * C) + i * C + j)),
746        )
747    }
748}
749
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{Matrix2, Vector2};
    use num::complex::Complex64;

    // --- Parameter lookup -------------------------------------------------

    #[test]
    fn test_parameters() {
        // Free values and constants live in separate slices; each ParameterID
        // variant must index into its own slice.
        let free: [f64; 3] = [1.0, 2.0, 3.0];
        let fixed: [f64; 3] = [4.0, 5.0, 6.0];
        let params = Parameters::new(&free, &fixed);

        for (idx, want) in free.iter().enumerate() {
            assert_eq!(params.get(ParameterID::Parameter(idx)), *want);
        }
        for (idx, want) in fixed.iter().enumerate() {
            assert_eq!(params.get(ParameterID::Constant(idx)), *want);
        }
        // `len` only counts the free parameters.
        assert_eq!(params.len(), free.len());
    }

    #[test]
    #[should_panic(expected = "Parameter has not been registered!")]
    fn test_uninit_parameter() {
        let only: [f64; 1] = [1.0];
        let _ = Parameters::new(&only, &only).get(ParameterID::Uninit);
    }

    // --- Amplitude and parameter registration -----------------------------

    #[test]
    fn test_resources_amplitude_management() {
        let mut resources = Resources::default();
        let first = resources.register_amplitude("amp1").unwrap();
        let second = resources.register_amplitude("amp2").unwrap();

        // Snapshot of (amp1 active, amp2 active) so each step is one assertion.
        let flags = |r: &Resources| (r.active[first.1], r.active[second.1]);

        assert_eq!(flags(&resources), (true, true));

        resources.deactivate_strict("amp1").unwrap();
        assert_eq!(flags(&resources), (false, true));

        resources.activate_strict("amp1").unwrap();
        assert!(flags(&resources).0);

        resources.deactivate_all();
        assert_eq!(flags(&resources), (false, false));

        resources.activate_all();
        assert_eq!(flags(&resources), (true, true));

        resources.isolate_strict("amp1").unwrap();
        assert_eq!(flags(&resources), (true, false));
    }

    #[test]
    fn test_resources_parameter_registration() {
        let mut resources = Resources::default();
        let free_id = resources
            .register_parameter(&ParameterLike::free("param1"))
            .unwrap();
        let fixed_id = resources
            .register_parameter(&ParameterLike::fixed("const1", 1.0))
            .unwrap();

        // Free parameters and constants are numbered independently, so both
        // first registrations land at index 0 of their respective lists.
        assert!(
            matches!(free_id, ParameterID::Parameter(0)),
            "Expected Parameter variant"
        );
        assert!(
            matches!(fixed_id, ParameterID::Constant(0)),
            "Expected Constant variant"
        );
    }

    // --- Cache round-trips ------------------------------------------------
    // Each test registers a named slot, an anonymous slot, and then the same
    // name again. The repeated name must resolve to the original slot.

    #[test]
    fn test_cache_scalar_operations() {
        let mut resources = Resources::default();
        let named = resources.register_scalar(Some("test_scalar"));
        let unnamed = resources.register_scalar(None);
        let named_again = resources.register_scalar(Some("test_scalar"));
        resources.reserve_cache(1);

        let cache = &mut resources.caches[0];
        for (id, value) in [(named, 1.0), (unnamed, 2.0)] {
            cache.store_scalar(id, value);
        }
        for (id, want) in [(named, 1.0), (unnamed, 2.0), (named_again, 1.0)] {
            assert_eq!(cache.get_scalar(id), want);
        }
    }

    #[test]
    fn test_cache_complex_operations() {
        let mut resources = Resources::default();
        let named = resources.register_complex_scalar(Some("test_complex"));
        let unnamed = resources.register_complex_scalar(None);
        let named_again = resources.register_complex_scalar(Some("test_complex"));
        resources.reserve_cache(1);

        let (first, second) = (Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0));
        let cache = &mut resources.caches[0];
        for (id, value) in [(named, first), (unnamed, second)] {
            cache.store_complex_scalar(id, value);
        }
        for (id, want) in [(named, first), (unnamed, second), (named_again, first)] {
            assert_eq!(cache.get_complex_scalar(id), want);
        }
    }

    #[test]
    fn test_cache_vector_operations() {
        let mut resources = Resources::default();
        let named: VectorID<2> = resources.register_vector(Some("test_vector"));
        let unnamed: VectorID<2> = resources.register_vector(None);
        let named_again: VectorID<2> = resources.register_vector(Some("test_vector"));
        resources.reserve_cache(1);

        let first = Vector2::new(1.0, 2.0);
        let second = Vector2::new(3.0, 4.0);
        let cache = &mut resources.caches[0];
        for (id, value) in [(named, first), (unnamed, second)] {
            cache.store_vector(id, value);
        }
        for (id, want) in [(named, first), (unnamed, second), (named_again, first)] {
            assert_eq!(cache.get_vector(id), want);
        }
    }

    #[test]
    fn test_cache_complex_vector_operations() {
        let mut resources = Resources::default();
        let named: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));
        let unnamed: ComplexVectorID<2> = resources.register_complex_vector(None);
        let named_again: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));
        resources.reserve_cache(1);

        let cplx = |re: f64, im: f64| Complex64::new(re, im);
        let first = Vector2::new(cplx(1.0, 2.0), cplx(3.0, 4.0));
        let second = Vector2::new(cplx(5.0, 6.0), cplx(7.0, 8.0));
        let cache = &mut resources.caches[0];
        for (id, value) in [(named, first), (unnamed, second)] {
            cache.store_complex_vector(id, value);
        }
        for (id, want) in [(named, first), (unnamed, second), (named_again, first)] {
            assert_eq!(cache.get_complex_vector(id), want);
        }
    }

    #[test]
    fn test_cache_matrix_operations() {
        let mut resources = Resources::default();
        let named: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
        let unnamed: MatrixID<2, 2> = resources.register_matrix(None);
        let named_again: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
        resources.reserve_cache(1);

        let first = Matrix2::new(1.0, 2.0, 3.0, 4.0);
        let second = Matrix2::new(5.0, 6.0, 7.0, 8.0);
        let cache = &mut resources.caches[0];
        for (id, value) in [(named, first), (unnamed, second)] {
            cache.store_matrix(id, value);
        }
        for (id, want) in [(named, first), (unnamed, second), (named_again, first)] {
            assert_eq!(cache.get_matrix(id), want);
        }
    }

    #[test]
    fn test_cache_complex_matrix_operations() {
        let mut resources = Resources::default();
        let named: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));
        let unnamed: ComplexMatrixID<2, 2> = resources.register_complex_matrix(None);
        let named_again: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));
        resources.reserve_cache(1);

        let cplx = |re: f64, im: f64| Complex64::new(re, im);
        let first = Matrix2::new(cplx(1.0, 2.0), cplx(3.0, 4.0), cplx(5.0, 6.0), cplx(7.0, 8.0));
        let second = Matrix2::new(
            cplx(9.0, 10.0),
            cplx(11.0, 12.0),
            cplx(13.0, 14.0),
            cplx(15.0, 16.0),
        );
        let cache = &mut resources.caches[0];
        for (id, value) in [(named, first), (unnamed, second)] {
            cache.store_complex_matrix(id, value);
        }
        for (id, want) in [(named, first), (unnamed, second), (named_again, first)] {
            assert_eq!(cache.get_complex_matrix(id), want);
        }
    }

    // --- Registration errors ----------------------------------------------

    #[test]
    fn test_uninit_parameter_registration() {
        let mut resources = Resources::default();
        assert!(resources
            .register_parameter(&ParameterLike::uninit())
            .is_err());
    }

    #[test]
    fn test_duplicate_named_amplitude_registration_error() {
        let mut resources = Resources::default();
        let first = resources.register_amplitude("test_amp");
        let second = resources.register_amplitude("test_amp");
        assert!(first.is_ok());
        assert!(second.is_err());
    }
}