1use std::{array, collections::HashMap};
2
3use indexmap::IndexSet;
4use nalgebra::{SMatrix, SVector};
5use num::Complex;
6use serde::{Deserialize, Serialize};
7use serde_with::serde_as;
8
9use crate::{
10    amplitudes::{AmplitudeID, ParameterLike},
11    Float, LadduError,
12};
13
/// Borrowed view of the numeric values that a [`ParameterID`] resolves against.
///
/// [`ParameterID::Parameter`] indexes into `parameters` and
/// [`ParameterID::Constant`] indexes into `constants` (see [`Parameters::get`]).
#[derive(Debug)]
pub struct Parameters<'a> {
    /// Values of the named (non-constant) parameters.
    pub(crate) parameters: &'a [Float],
    /// Values of the registered constants.
    pub(crate) constants: &'a [Float],
}
21
22impl<'a> Parameters<'a> {
23    pub fn new(parameters: &'a [Float], constants: &'a [Float]) -> Self {
25        Self {
26            parameters,
27            constants,
28        }
29    }
30
31    pub fn get(&self, pid: ParameterID) -> Float {
33        match pid {
34            ParameterID::Parameter(index) => self.parameters[index],
35            ParameterID::Constant(index) => self.constants[index],
36            ParameterID::Uninit => panic!("Parameter has not been registered!"),
37        }
38    }
39
40    #[allow(clippy::len_without_is_empty)]
42    pub fn len(&self) -> usize {
43        self.parameters.len()
44    }
45}
46
/// Registry of amplitudes, parameters, constants, and the per-event cache layout.
///
/// Amplitudes are registered by name and switched on or off through `active`.
/// Parameters and constants are registered into `parameters` / `constants`.
/// Cached quantities of each shape are assigned slot ranges inside a flat
/// [`Cache`], whose required width is tracked in `cache_size`.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Resources {
    /// Registered amplitudes keyed by name.
    amplitudes: HashMap<String, AmplitudeID>,
    /// Activation flag per amplitude, indexed by the amplitude ID's index.
    pub active: Vec<bool>,
    /// Names of registered parameters; a name's position is its `ParameterID::Parameter` index.
    pub parameters: IndexSet<String>,
    /// Registered constant values; a value's position is its `ParameterID::Constant` index.
    pub constants: Vec<Float>,
    /// One cache per event, (re)created by `reserve_cache`.
    pub caches: Vec<Cache>,
    // Name -> first cache slot for named cache entries. Each entry kind has its
    // own map, so the same name used for different kinds gets separate slots.
    scalar_cache_names: HashMap<String, usize>,
    complex_scalar_cache_names: HashMap<String, usize>,
    vector_cache_names: HashMap<String, usize>,
    complex_vector_cache_names: HashMap<String, usize>,
    matrix_cache_names: HashMap<String, usize>,
    complex_matrix_cache_names: HashMap<String, usize>,
    /// Total number of `Float` slots each [`Cache`] needs for the entries registered so far.
    cache_size: usize,
}
67
/// Flat storage of `Float` slots holding one event's cached values.
///
/// Typed IDs ([`ScalarID`], [`ComplexScalarID`], [`VectorID`], ...) record
/// which slots belong to each cached quantity. Complex values occupy separate
/// real and imaginary slots.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Cache(Vec<Float>);
72impl Cache {
73    fn new(cache_size: usize) -> Self {
74        Self(vec![0.0; cache_size])
75    }
76    pub fn store_scalar(&mut self, sid: ScalarID, value: Float) {
78        self.0[sid.0] = value;
79    }
80    pub fn store_complex_scalar(&mut self, csid: ComplexScalarID, value: Complex<Float>) {
82        self.0[csid.0] = value.re;
83        self.0[csid.1] = value.im;
84    }
85    pub fn store_vector<const R: usize>(&mut self, vid: VectorID<R>, value: SVector<Float, R>) {
87        vid.0
88            .into_iter()
89            .enumerate()
90            .for_each(|(vi, i)| self.0[i] = value[vi]);
91    }
92    pub fn store_complex_vector<const R: usize>(
94        &mut self,
95        cvid: ComplexVectorID<R>,
96        value: SVector<Complex<Float>, R>,
97    ) {
98        cvid.0
99            .into_iter()
100            .enumerate()
101            .for_each(|(vi, i)| self.0[i] = value[vi].re);
102        cvid.1
103            .into_iter()
104            .enumerate()
105            .for_each(|(vi, i)| self.0[i] = value[vi].im);
106    }
107    pub fn store_matrix<const R: usize, const C: usize>(
109        &mut self,
110        mid: MatrixID<R, C>,
111        value: SMatrix<Float, R, C>,
112    ) {
113        mid.0.into_iter().enumerate().for_each(|(vi, row)| {
114            row.into_iter()
115                .enumerate()
116                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)])
117        });
118    }
119    pub fn store_complex_matrix<const R: usize, const C: usize>(
121        &mut self,
122        cmid: ComplexMatrixID<R, C>,
123        value: SMatrix<Complex<Float>, R, C>,
124    ) {
125        cmid.0.into_iter().enumerate().for_each(|(vi, row)| {
126            row.into_iter()
127                .enumerate()
128                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].re)
129        });
130        cmid.1.into_iter().enumerate().for_each(|(vi, row)| {
131            row.into_iter()
132                .enumerate()
133                .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].im)
134        });
135    }
136    pub fn get_scalar(&self, sid: ScalarID) -> Float {
138        self.0[sid.0]
139    }
140    pub fn get_complex_scalar(&self, csid: ComplexScalarID) -> Complex<Float> {
142        Complex::new(self.0[csid.0], self.0[csid.1])
143    }
144    pub fn get_vector<const R: usize>(&self, vid: VectorID<R>) -> SVector<Float, R> {
146        SVector::from_fn(|i, _| self.0[vid.0[i]])
147    }
148    pub fn get_complex_vector<const R: usize>(
150        &self,
151        cvid: ComplexVectorID<R>,
152    ) -> SVector<Complex<Float>, R> {
153        SVector::from_fn(|i, _| Complex::new(self.0[cvid.0[i]], self.0[cvid.1[i]]))
154    }
155    pub fn get_matrix<const R: usize, const C: usize>(
157        &self,
158        mid: MatrixID<R, C>,
159    ) -> SMatrix<Float, R, C> {
160        SMatrix::from_fn(|i, j| self.0[mid.0[i][j]])
161    }
162    pub fn get_complex_matrix<const R: usize, const C: usize>(
164        &self,
165        cmid: ComplexMatrixID<R, C>,
166    ) -> SMatrix<Complex<Float>, R, C> {
167        SMatrix::from_fn(|i, j| Complex::new(self.0[cmid.0[i][j]], self.0[cmid.1[i][j]]))
168    }
169}
170
/// Handle to a registered value, resolved with [`Parameters::get`].
#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize)]
pub enum ParameterID {
    /// Index into the named (non-constant) parameter values.
    Parameter(usize),
    /// Index into the constant values.
    Constant(usize),
    /// Placeholder for an unregistered parameter (the `Default`);
    /// [`Parameters::get`] panics on it.
    #[default]
    Uninit,
}
182
/// Cache slot holding a real scalar; created by `Resources::register_scalar`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct ScalarID(usize);
186
/// Cache slots `(real part, imaginary part)` holding a complex scalar;
/// created by `Resources::register_complex_scalar`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
pub struct ComplexScalarID(usize, usize);
190
/// Cache slots holding the `R` components of a real vector; `.0[i]` is the slot
/// of component `i`. Created by `Resources::register_vector`.
// `serde_as(as = "[_; R]")` presumably lets serde handle the const-generic array length.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct VectorID<const R: usize>(#[serde_as(as = "[_; R]")] [usize; R]);

// All slot indices zero. Presumably written by hand because `derive(Default)`
// would need `[usize; R]: Default`, which std does not provide for every `R`.
impl<const R: usize> Default for VectorID<R> {
    fn default() -> Self {
        Self([0; R])
    }
}
201
/// Cache slots for a complex `R`-vector: `.0[i]` holds the real part and `.1[i]`
/// the imaginary part of component `i`. Created by `Resources::register_complex_vector`.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct ComplexVectorID<const R: usize>(
    #[serde_as(as = "[_; R]")] [usize; R],
    #[serde_as(as = "[_; R]")] [usize; R],
);

// All slot indices zero in both the real and imaginary arrays.
impl<const R: usize> Default for ComplexVectorID<R> {
    fn default() -> Self {
        Self([0; R], [0; R])
    }
}
215
/// Cache slots for a real `R`x`C` matrix; `.0[i][j]` is the slot of element `(i, j)`.
/// Created by `Resources::register_matrix`.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct MatrixID<const R: usize, const C: usize>(
    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
);

// All slot indices zero.
impl<const R: usize, const C: usize> Default for MatrixID<R, C> {
    fn default() -> Self {
        Self([[0; C]; R])
    }
}
228
/// Cache slots for a complex `R`x`C` matrix: `.0[i][j]` holds the real part and
/// `.1[i][j]` the imaginary part of element `(i, j)`.
/// Created by `Resources::register_complex_matrix`.
#[serde_as]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct ComplexMatrixID<const R: usize, const C: usize>(
    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
    #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
);

// All slot indices zero in both the real and imaginary arrays.
impl<const R: usize, const C: usize> Default for ComplexMatrixID<R, C> {
    fn default() -> Self {
        Self([[0; C]; R], [[0; C]; R])
    }
}
242
243impl Resources {
244    pub fn activate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
246        self.active[self
247            .amplitudes
248            .get(name.as_ref())
249            .ok_or(LadduError::AmplitudeNotFoundError {
250                name: name.as_ref().to_string(),
251            })?
252            .1] = true;
253        Ok(())
254    }
255    pub fn activate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
257        for name in names {
258            self.activate(name)?
259        }
260        Ok(())
261    }
262    pub fn activate_all(&mut self) {
264        self.active = vec![true; self.active.len()];
265    }
266    pub fn deactivate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
268        self.active[self
269            .amplitudes
270            .get(name.as_ref())
271            .ok_or(LadduError::AmplitudeNotFoundError {
272                name: name.as_ref().to_string(),
273            })?
274            .1] = false;
275        Ok(())
276    }
277    pub fn deactivate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
279        for name in names {
280            self.deactivate(name)?;
281        }
282        Ok(())
283    }
284    pub fn deactivate_all(&mut self) {
286        self.active = vec![false; self.active.len()];
287    }
288    pub fn isolate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
290        self.deactivate_all();
291        self.activate(name)
292    }
293    pub fn isolate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
295        self.deactivate_all();
296        self.activate_many(names)
297    }
298    pub fn register_amplitude(&mut self, name: &str) -> Result<AmplitudeID, LadduError> {
309        if self.amplitudes.contains_key(name) {
310            return Err(LadduError::RegistrationError {
311                name: name.to_string(),
312            });
313        }
314        let next_id = AmplitudeID(name.to_string(), self.amplitudes.len());
315        self.amplitudes.insert(name.to_string(), next_id.clone());
316        self.active.push(true);
317        Ok(next_id)
318    }
319    pub fn register_parameter(&mut self, pl: &ParameterLike) -> ParameterID {
324        match pl {
325            ParameterLike::Parameter(name) => {
326                let (index, _) = self.parameters.insert_full(name.to_string());
327                ParameterID::Parameter(index)
328            }
329            ParameterLike::Constant(value) => {
330                self.constants.push(*value);
331                ParameterID::Constant(self.constants.len() - 1)
332            }
333            ParameterLike::Uninit => panic!("Parameter was not initialized!"),
334        }
335    }
336    pub(crate) fn reserve_cache(&mut self, num_events: usize) {
337        self.caches = vec![Cache::new(self.cache_size); num_events]
338    }
339    pub fn register_scalar(&mut self, name: Option<&str>) -> ScalarID {
345        let first_index = if let Some(name) = name {
346            *self
347                .scalar_cache_names
348                .entry(name.to_string())
349                .or_insert_with(|| {
350                    self.cache_size += 1;
351                    self.cache_size - 1
352                })
353        } else {
354            self.cache_size += 1;
355            self.cache_size - 1
356        };
357        ScalarID(first_index)
358    }
359    pub fn register_complex_scalar(&mut self, name: Option<&str>) -> ComplexScalarID {
365        let first_index = if let Some(name) = name {
366            *self
367                .complex_scalar_cache_names
368                .entry(name.to_string())
369                .or_insert_with(|| {
370                    self.cache_size += 2;
371                    self.cache_size - 2
372                })
373        } else {
374            self.cache_size += 2;
375            self.cache_size - 2
376        };
377        ComplexScalarID(first_index, first_index + 1)
378    }
379    pub fn register_vector<const R: usize>(&mut self, name: Option<&str>) -> VectorID<R> {
385        let first_index = if let Some(name) = name {
386            *self
387                .vector_cache_names
388                .entry(name.to_string())
389                .or_insert_with(|| {
390                    self.cache_size += R;
391                    self.cache_size - R
392                })
393        } else {
394            self.cache_size += R;
395            self.cache_size - R
396        };
397        VectorID(array::from_fn(|i| first_index + i))
398    }
399    pub fn register_complex_vector<const R: usize>(
405        &mut self,
406        name: Option<&str>,
407    ) -> ComplexVectorID<R> {
408        let first_index = if let Some(name) = name {
409            *self
410                .complex_vector_cache_names
411                .entry(name.to_string())
412                .or_insert_with(|| {
413                    self.cache_size += R * 2;
414                    self.cache_size - (R * 2)
415                })
416        } else {
417            self.cache_size += R * 2;
418            self.cache_size - (R * 2)
419        };
420        ComplexVectorID(
421            array::from_fn(|i| first_index + i),
422            array::from_fn(|i| (first_index + R) + i),
423        )
424    }
425    pub fn register_matrix<const R: usize, const C: usize>(
431        &mut self,
432        name: Option<&str>,
433    ) -> MatrixID<R, C> {
434        let first_index = if let Some(name) = name {
435            *self
436                .matrix_cache_names
437                .entry(name.to_string())
438                .or_insert_with(|| {
439                    self.cache_size += R * C;
440                    self.cache_size - (R * C)
441                })
442        } else {
443            self.cache_size += R * C;
444            self.cache_size - (R * C)
445        };
446        MatrixID(array::from_fn(|i| {
447            array::from_fn(|j| first_index + i * C + j)
448        }))
449    }
450    pub fn register_complex_matrix<const R: usize, const C: usize>(
456        &mut self,
457        name: Option<&str>,
458    ) -> ComplexMatrixID<R, C> {
459        let first_index = if let Some(name) = name {
460            *self
461                .complex_matrix_cache_names
462                .entry(name.to_string())
463                .or_insert_with(|| {
464                    self.cache_size += 2 * R * C;
465                    self.cache_size - (2 * R * C)
466                })
467        } else {
468            self.cache_size += 2 * R * C;
469            self.cache_size - (2 * R * C)
470        };
471        ComplexMatrixID(
472            array::from_fn(|i| array::from_fn(|j| first_index + i * C + j)),
473            array::from_fn(|i| array::from_fn(|j| (first_index + R * C) + i * C + j)),
474        )
475    }
476}
477
/// Unit tests for parameter lookup, amplitude activation, and cache round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{Matrix2, Vector2};
    use num::Complex;

    #[test]
    fn test_parameters() {
        let parameters = vec![1.0, 2.0, 3.0];
        let constants = vec![4.0, 5.0, 6.0];
        let params = Parameters::new(&parameters, &constants);

        assert_eq!(params.get(ParameterID::Parameter(0)), 1.0);
        assert_eq!(params.get(ParameterID::Parameter(1)), 2.0);
        assert_eq!(params.get(ParameterID::Parameter(2)), 3.0);
        assert_eq!(params.get(ParameterID::Constant(0)), 4.0);
        assert_eq!(params.get(ParameterID::Constant(1)), 5.0);
        assert_eq!(params.get(ParameterID::Constant(2)), 6.0);
        assert_eq!(params.len(), 3);
    }

    #[test]
    #[should_panic(expected = "Parameter has not been registered!")]
    fn test_uninit_parameter() {
        let parameters = vec![1.0];
        let constants = vec![1.0];
        let params = Parameters::new(&parameters, &constants);
        params.get(ParameterID::Uninit);
    }

    #[test]
    fn test_resources_amplitude_management() {
        let mut resources = Resources::default();

        let amp1 = resources.register_amplitude("amp1").unwrap();
        let amp2 = resources.register_amplitude("amp2").unwrap();

        assert!(resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);

        resources.deactivate("amp1").unwrap();
        assert!(!resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);

        resources.activate("amp1").unwrap();
        assert!(resources.active[amp1.1]);

        resources.deactivate_all();
        assert!(!resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);

        resources.activate_all();
        assert!(resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);

        resources.isolate("amp1").unwrap();
        assert!(resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);
    }

    #[test]
    fn test_resources_many_operations() {
        let mut resources = Resources::default();

        let amp1 = resources.register_amplitude("amp1").unwrap();
        let amp2 = resources.register_amplitude("amp2").unwrap();
        let amp3 = resources.register_amplitude("amp3").unwrap();

        resources.deactivate_many(&["amp1", "amp2"]).unwrap();
        assert!(!resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);
        assert!(resources.active[amp3.1]);

        resources.activate_many(&["amp2"]).unwrap();
        assert!(resources.active[amp2.1]);

        resources.isolate_many(&["amp1", "amp3"]).unwrap();
        assert!(resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);
        assert!(resources.active[amp3.1]);
    }

    #[test]
    fn test_unknown_amplitude_errors() {
        let mut resources = Resources::default();
        resources.register_amplitude("amp1").unwrap();

        assert!(resources.activate("missing").is_err());
        assert!(resources.deactivate("missing").is_err());
        assert!(resources.activate_many(&["amp1", "missing"]).is_err());
        assert!(resources.deactivate_many(&["missing"]).is_err());
        assert!(resources.isolate("missing").is_err());
    }

    #[test]
    fn test_resources_parameter_registration() {
        let mut resources = Resources::default();

        let param1 = resources.register_parameter(&ParameterLike::Parameter("param1".to_string()));
        let const1 = resources.register_parameter(&ParameterLike::Constant(1.0));

        match param1 {
            ParameterID::Parameter(idx) => assert_eq!(idx, 0),
            _ => panic!("Expected Parameter variant"),
        }

        match const1 {
            ParameterID::Constant(idx) => assert_eq!(idx, 0),
            _ => panic!("Expected Constant variant"),
        }
    }

    #[test]
    fn test_cache_scalar_operations() {
        let mut resources = Resources::default();

        let scalar1 = resources.register_scalar(Some("test_scalar"));
        let scalar2 = resources.register_scalar(None);
        let scalar3 = resources.register_scalar(Some("test_scalar"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        cache.store_scalar(scalar1, 1.0);
        cache.store_scalar(scalar2, 2.0);

        assert_eq!(cache.get_scalar(scalar1), 1.0);
        assert_eq!(cache.get_scalar(scalar2), 2.0);
        assert_eq!(cache.get_scalar(scalar3), 1.0);
    }

    #[test]
    fn test_cache_complex_operations() {
        let mut resources = Resources::default();

        let complex1 = resources.register_complex_scalar(Some("test_complex"));
        let complex2 = resources.register_complex_scalar(None);
        let complex3 = resources.register_complex_scalar(Some("test_complex"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Complex::new(1.0, 2.0);
        let value2 = Complex::new(3.0, 4.0);
        cache.store_complex_scalar(complex1, value1);
        cache.store_complex_scalar(complex2, value2);

        assert_eq!(cache.get_complex_scalar(complex1), value1);
        assert_eq!(cache.get_complex_scalar(complex2), value2);
        assert_eq!(cache.get_complex_scalar(complex3), value1);
    }

    #[test]
    fn test_cache_vector_operations() {
        let mut resources = Resources::default();

        let vector_id1: VectorID<2> = resources.register_vector(Some("test_vector"));
        let vector_id2: VectorID<2> = resources.register_vector(None);
        let vector_id3: VectorID<2> = resources.register_vector(Some("test_vector"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Vector2::new(1.0, 2.0);
        let value2 = Vector2::new(3.0, 4.0);
        cache.store_vector(vector_id1, value1);
        cache.store_vector(vector_id2, value2);

        assert_eq!(cache.get_vector(vector_id1), value1);
        assert_eq!(cache.get_vector(vector_id2), value2);
        assert_eq!(cache.get_vector(vector_id3), value1);
    }

    #[test]
    fn test_cache_complex_vector_operations() {
        let mut resources = Resources::default();

        let complex_vector_id1: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));
        let complex_vector_id2: ComplexVectorID<2> = resources.register_complex_vector(None);
        let complex_vector_id3: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
        let value2 = Vector2::new(Complex::new(5.0, 6.0), Complex::new(7.0, 8.0));
        cache.store_complex_vector(complex_vector_id1, value1);
        cache.store_complex_vector(complex_vector_id2, value2);

        assert_eq!(cache.get_complex_vector(complex_vector_id1), value1);
        assert_eq!(cache.get_complex_vector(complex_vector_id2), value2);
        assert_eq!(cache.get_complex_vector(complex_vector_id3), value1);
    }

    #[test]
    fn test_cache_matrix_operations() {
        let mut resources = Resources::default();

        let matrix_id1: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
        let matrix_id2: MatrixID<2, 2> = resources.register_matrix(None);
        let matrix_id3: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Matrix2::new(1.0, 2.0, 3.0, 4.0);
        let value2 = Matrix2::new(5.0, 6.0, 7.0, 8.0);
        cache.store_matrix(matrix_id1, value1);
        cache.store_matrix(matrix_id2, value2);

        assert_eq!(cache.get_matrix(matrix_id1), value1);
        assert_eq!(cache.get_matrix(matrix_id2), value2);
        assert_eq!(cache.get_matrix(matrix_id3), value1);
    }

    #[test]
    fn test_cache_complex_matrix_operations() {
        let mut resources = Resources::default();

        let complex_matrix_id1: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));
        let complex_matrix_id2: ComplexMatrixID<2, 2> = resources.register_complex_matrix(None);
        let complex_matrix_id3: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Matrix2::new(
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        );
        let value2 = Matrix2::new(
            Complex::new(9.0, 10.0),
            Complex::new(11.0, 12.0),
            Complex::new(13.0, 14.0),
            Complex::new(15.0, 16.0),
        );
        cache.store_complex_matrix(complex_matrix_id1, value1);
        cache.store_complex_matrix(complex_matrix_id2, value2);

        assert_eq!(cache.get_complex_matrix(complex_matrix_id1), value1);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id2), value2);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id3), value1);
    }

    #[test]
    #[should_panic(expected = "Parameter was not initialized!")]
    fn test_uninit_parameter_registration() {
        let mut resources = Resources::default();
        resources.register_parameter(&ParameterLike::Uninit);
    }

    #[test]
    fn test_duplicate_named_amplitude_registration_error() {
        let mut resources = Resources::default();
        assert!(resources.register_amplitude("test_amp").is_ok());
        assert!(resources.register_amplitude("test_amp").is_err());
    }
}