1use std::{array, collections::HashMap};
2
3use indexmap::IndexSet;
4use nalgebra::{SMatrix, SVector};
5use num::Complex;
6use serde::{Deserialize, Serialize};
7use serde_with::serde_as;
8
9use crate::{
10 amplitudes::{AmplitudeID, ParameterLike},
11 Float, LadduError,
12};
13
14#[derive(Debug)]
17pub struct Parameters<'a> {
18 pub(crate) parameters: &'a [Float],
19 pub(crate) constants: &'a [Float],
20}
21
22impl<'a> Parameters<'a> {
23 pub fn new(parameters: &'a [Float], constants: &'a [Float]) -> Self {
25 Self {
26 parameters,
27 constants,
28 }
29 }
30
31 pub fn get(&self, pid: ParameterID) -> Float {
33 match pid {
34 ParameterID::Parameter(index) => self.parameters[index],
35 ParameterID::Constant(index) => self.constants[index],
36 ParameterID::Uninit => panic!("Parameter has not been registered!"),
37 }
38 }
39
40 #[allow(clippy::len_without_is_empty)]
42 pub fn len(&self) -> usize {
43 self.parameters.len()
44 }
45}
46
47#[derive(Default, Debug, Clone, Serialize, Deserialize)]
49pub struct Resources {
50 amplitudes: HashMap<String, AmplitudeID>,
51 pub active: Vec<bool>,
53 pub parameters: IndexSet<String>,
55 pub constants: Vec<Float>,
57 pub caches: Vec<Cache>,
59 scalar_cache_names: HashMap<String, usize>,
60 complex_scalar_cache_names: HashMap<String, usize>,
61 vector_cache_names: HashMap<String, usize>,
62 complex_vector_cache_names: HashMap<String, usize>,
63 matrix_cache_names: HashMap<String, usize>,
64 complex_matrix_cache_names: HashMap<String, usize>,
65 cache_size: usize,
66}
67
68#[derive(Clone, Debug, Serialize, Deserialize)]
71pub struct Cache(Vec<Float>);
72impl Cache {
73 fn new(cache_size: usize) -> Self {
74 Self(vec![0.0; cache_size])
75 }
76 pub fn store_scalar(&mut self, sid: ScalarID, value: Float) {
78 self.0[sid.0] = value;
79 }
80 pub fn store_complex_scalar(&mut self, csid: ComplexScalarID, value: Complex<Float>) {
82 self.0[csid.0] = value.re;
83 self.0[csid.1] = value.im;
84 }
85 pub fn store_vector<const R: usize>(&mut self, vid: VectorID<R>, value: SVector<Float, R>) {
87 vid.0
88 .into_iter()
89 .enumerate()
90 .for_each(|(vi, i)| self.0[i] = value[vi]);
91 }
92 pub fn store_complex_vector<const R: usize>(
94 &mut self,
95 cvid: ComplexVectorID<R>,
96 value: SVector<Complex<Float>, R>,
97 ) {
98 cvid.0
99 .into_iter()
100 .enumerate()
101 .for_each(|(vi, i)| self.0[i] = value[vi].re);
102 cvid.1
103 .into_iter()
104 .enumerate()
105 .for_each(|(vi, i)| self.0[i] = value[vi].im);
106 }
107 pub fn store_matrix<const R: usize, const C: usize>(
109 &mut self,
110 mid: MatrixID<R, C>,
111 value: SMatrix<Float, R, C>,
112 ) {
113 mid.0.into_iter().enumerate().for_each(|(vi, row)| {
114 row.into_iter()
115 .enumerate()
116 .for_each(|(vj, k)| self.0[k] = value[(vi, vj)])
117 });
118 }
119 pub fn store_complex_matrix<const R: usize, const C: usize>(
121 &mut self,
122 cmid: ComplexMatrixID<R, C>,
123 value: SMatrix<Complex<Float>, R, C>,
124 ) {
125 cmid.0.into_iter().enumerate().for_each(|(vi, row)| {
126 row.into_iter()
127 .enumerate()
128 .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].re)
129 });
130 cmid.1.into_iter().enumerate().for_each(|(vi, row)| {
131 row.into_iter()
132 .enumerate()
133 .for_each(|(vj, k)| self.0[k] = value[(vi, vj)].im)
134 });
135 }
136 pub fn get_scalar(&self, sid: ScalarID) -> Float {
138 self.0[sid.0]
139 }
140 pub fn get_complex_scalar(&self, csid: ComplexScalarID) -> Complex<Float> {
142 Complex::new(self.0[csid.0], self.0[csid.1])
143 }
144 pub fn get_vector<const R: usize>(&self, vid: VectorID<R>) -> SVector<Float, R> {
146 SVector::from_fn(|i, _| self.0[vid.0[i]])
147 }
148 pub fn get_complex_vector<const R: usize>(
150 &self,
151 cvid: ComplexVectorID<R>,
152 ) -> SVector<Complex<Float>, R> {
153 SVector::from_fn(|i, _| Complex::new(self.0[cvid.0[i]], self.0[cvid.1[i]]))
154 }
155 pub fn get_matrix<const R: usize, const C: usize>(
157 &self,
158 mid: MatrixID<R, C>,
159 ) -> SMatrix<Float, R, C> {
160 SMatrix::from_fn(|i, j| self.0[mid.0[i][j]])
161 }
162 pub fn get_complex_matrix<const R: usize, const C: usize>(
164 &self,
165 cmid: ComplexMatrixID<R, C>,
166 ) -> SMatrix<Complex<Float>, R, C> {
167 SMatrix::from_fn(|i, j| Complex::new(self.0[cmid.0[i][j]], self.0[cmid.1[i][j]]))
168 }
169}
170
171#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize)]
173pub enum ParameterID {
174 Parameter(usize),
176 Constant(usize),
178 #[default]
180 Uninit,
181}
182
183#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
185pub struct ScalarID(usize);
186
187#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
189pub struct ComplexScalarID(usize, usize);
190
191#[serde_as]
193#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
194pub struct VectorID<const R: usize>(#[serde_as(as = "[_; R]")] [usize; R]);
195
196impl<const R: usize> Default for VectorID<R> {
197 fn default() -> Self {
198 Self([0; R])
199 }
200}
201
202#[serde_as]
204#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
205pub struct ComplexVectorID<const R: usize>(
206 #[serde_as(as = "[_; R]")] [usize; R],
207 #[serde_as(as = "[_; R]")] [usize; R],
208);
209
210impl<const R: usize> Default for ComplexVectorID<R> {
211 fn default() -> Self {
212 Self([0; R], [0; R])
213 }
214}
215
216#[serde_as]
218#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
219pub struct MatrixID<const R: usize, const C: usize>(
220 #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
221);
222
223impl<const R: usize, const C: usize> Default for MatrixID<R, C> {
224 fn default() -> Self {
225 Self([[0; C]; R])
226 }
227}
228
229#[serde_as]
231#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
232pub struct ComplexMatrixID<const R: usize, const C: usize>(
233 #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
234 #[serde_as(as = "[[_; C]; R]")] [[usize; C]; R],
235);
236
237impl<const R: usize, const C: usize> Default for ComplexMatrixID<R, C> {
238 fn default() -> Self {
239 Self([[0; C]; R], [[0; C]; R])
240 }
241}
242
243impl Resources {
244 pub fn activate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
246 self.active[self
247 .amplitudes
248 .get(name.as_ref())
249 .ok_or(LadduError::AmplitudeNotFoundError {
250 name: name.as_ref().to_string(),
251 })?
252 .1] = true;
253 Ok(())
254 }
255 pub fn activate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
257 for name in names {
258 self.activate(name)?
259 }
260 Ok(())
261 }
262 pub fn activate_all(&mut self) {
264 self.active = vec![true; self.active.len()];
265 }
266 pub fn deactivate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
268 self.active[self
269 .amplitudes
270 .get(name.as_ref())
271 .ok_or(LadduError::AmplitudeNotFoundError {
272 name: name.as_ref().to_string(),
273 })?
274 .1] = false;
275 Ok(())
276 }
277 pub fn deactivate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
279 for name in names {
280 self.deactivate(name)?;
281 }
282 Ok(())
283 }
284 pub fn deactivate_all(&mut self) {
286 self.active = vec![false; self.active.len()];
287 }
288 pub fn isolate<T: AsRef<str>>(&mut self, name: T) -> Result<(), LadduError> {
290 self.deactivate_all();
291 self.activate(name)
292 }
293 pub fn isolate_many<T: AsRef<str>>(&mut self, names: &[T]) -> Result<(), LadduError> {
295 self.deactivate_all();
296 self.activate_many(names)
297 }
298 pub fn register_amplitude(&mut self, name: &str) -> Result<AmplitudeID, LadduError> {
309 if self.amplitudes.contains_key(name) {
310 return Err(LadduError::RegistrationError {
311 name: name.to_string(),
312 });
313 }
314 let next_id = AmplitudeID(name.to_string(), self.amplitudes.len());
315 self.amplitudes.insert(name.to_string(), next_id.clone());
316 self.active.push(true);
317 Ok(next_id)
318 }
319 pub fn register_parameter(&mut self, pl: &ParameterLike) -> ParameterID {
324 match pl {
325 ParameterLike::Parameter(name) => {
326 let (index, _) = self.parameters.insert_full(name.to_string());
327 ParameterID::Parameter(index)
328 }
329 ParameterLike::Constant(value) => {
330 self.constants.push(*value);
331 ParameterID::Constant(self.constants.len() - 1)
332 }
333 ParameterLike::Uninit => panic!("Parameter was not initialized!"),
334 }
335 }
336 pub(crate) fn reserve_cache(&mut self, num_events: usize) {
337 self.caches = vec![Cache::new(self.cache_size); num_events]
338 }
339 pub fn register_scalar(&mut self, name: Option<&str>) -> ScalarID {
345 let first_index = if let Some(name) = name {
346 *self
347 .scalar_cache_names
348 .entry(name.to_string())
349 .or_insert_with(|| {
350 self.cache_size += 1;
351 self.cache_size - 1
352 })
353 } else {
354 self.cache_size += 1;
355 self.cache_size - 1
356 };
357 ScalarID(first_index)
358 }
359 pub fn register_complex_scalar(&mut self, name: Option<&str>) -> ComplexScalarID {
365 let first_index = if let Some(name) = name {
366 *self
367 .complex_scalar_cache_names
368 .entry(name.to_string())
369 .or_insert_with(|| {
370 self.cache_size += 2;
371 self.cache_size - 2
372 })
373 } else {
374 self.cache_size += 2;
375 self.cache_size - 2
376 };
377 ComplexScalarID(first_index, first_index + 1)
378 }
379 pub fn register_vector<const R: usize>(&mut self, name: Option<&str>) -> VectorID<R> {
385 let first_index = if let Some(name) = name {
386 *self
387 .vector_cache_names
388 .entry(name.to_string())
389 .or_insert_with(|| {
390 self.cache_size += R;
391 self.cache_size - R
392 })
393 } else {
394 self.cache_size += R;
395 self.cache_size - R
396 };
397 VectorID(array::from_fn(|i| first_index + i))
398 }
399 pub fn register_complex_vector<const R: usize>(
405 &mut self,
406 name: Option<&str>,
407 ) -> ComplexVectorID<R> {
408 let first_index = if let Some(name) = name {
409 *self
410 .complex_vector_cache_names
411 .entry(name.to_string())
412 .or_insert_with(|| {
413 self.cache_size += R * 2;
414 self.cache_size - (R * 2)
415 })
416 } else {
417 self.cache_size += R * 2;
418 self.cache_size - (R * 2)
419 };
420 ComplexVectorID(
421 array::from_fn(|i| first_index + i),
422 array::from_fn(|i| (first_index + R) + i),
423 )
424 }
425 pub fn register_matrix<const R: usize, const C: usize>(
431 &mut self,
432 name: Option<&str>,
433 ) -> MatrixID<R, C> {
434 let first_index = if let Some(name) = name {
435 *self
436 .matrix_cache_names
437 .entry(name.to_string())
438 .or_insert_with(|| {
439 self.cache_size += R * C;
440 self.cache_size - (R * C)
441 })
442 } else {
443 self.cache_size += R * C;
444 self.cache_size - (R * C)
445 };
446 MatrixID(array::from_fn(|i| {
447 array::from_fn(|j| first_index + i * C + j)
448 }))
449 }
450 pub fn register_complex_matrix<const R: usize, const C: usize>(
456 &mut self,
457 name: Option<&str>,
458 ) -> ComplexMatrixID<R, C> {
459 let first_index = if let Some(name) = name {
460 *self
461 .complex_matrix_cache_names
462 .entry(name.to_string())
463 .or_insert_with(|| {
464 self.cache_size += 2 * R * C;
465 self.cache_size - (2 * R * C)
466 })
467 } else {
468 self.cache_size += 2 * R * C;
469 self.cache_size - (2 * R * C)
470 };
471 ComplexMatrixID(
472 array::from_fn(|i| array::from_fn(|j| first_index + i * C + j)),
473 array::from_fn(|i| array::from_fn(|j| (first_index + R * C) + i * C + j)),
474 )
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{Matrix2, Vector2};
    use num::Complex;

    #[test]
    fn test_parameters() {
        let parameters = vec![1.0, 2.0, 3.0];
        let constants = vec![4.0, 5.0, 6.0];
        let params = Parameters::new(&parameters, &constants);

        assert_eq!(params.get(ParameterID::Parameter(0)), 1.0);
        assert_eq!(params.get(ParameterID::Parameter(1)), 2.0);
        assert_eq!(params.get(ParameterID::Parameter(2)), 3.0);
        assert_eq!(params.get(ParameterID::Constant(0)), 4.0);
        assert_eq!(params.get(ParameterID::Constant(1)), 5.0);
        assert_eq!(params.get(ParameterID::Constant(2)), 6.0);
        assert_eq!(params.len(), 3);
    }

    #[test]
    #[should_panic(expected = "Parameter has not been registered!")]
    fn test_uninit_parameter() {
        let parameters = vec![1.0];
        let constants = vec![1.0];
        let params = Parameters::new(&parameters, &constants);
        params.get(ParameterID::Uninit);
    }

    #[test]
    fn test_resources_amplitude_management() {
        let mut resources = Resources::default();

        let amp1 = resources.register_amplitude("amp1").unwrap();
        let amp2 = resources.register_amplitude("amp2").unwrap();

        assert!(resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);

        resources.deactivate("amp1").unwrap();
        assert!(!resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);

        resources.activate("amp1").unwrap();
        assert!(resources.active[amp1.1]);

        resources.deactivate_all();
        assert!(!resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);

        resources.activate_all();
        assert!(resources.active[amp1.1]);
        assert!(resources.active[amp2.1]);

        resources.isolate("amp1").unwrap();
        assert!(resources.active[amp1.1]);
        assert!(!resources.active[amp2.1]);
    }

    #[test]
    fn test_resources_many_and_unknown_amplitudes() {
        let mut resources = Resources::default();

        let a = resources.register_amplitude("a").unwrap();
        let b = resources.register_amplitude("b").unwrap();
        let c = resources.register_amplitude("c").unwrap();

        resources.deactivate_many(&["a", "b"]).unwrap();
        assert!(!resources.active[a.1]);
        assert!(!resources.active[b.1]);
        assert!(resources.active[c.1]);

        resources.activate_many(&["a"]).unwrap();
        assert!(resources.active[a.1]);
        assert!(!resources.active[b.1]);

        resources.isolate_many(&["b", "c"]).unwrap();
        assert!(!resources.active[a.1]);
        assert!(resources.active[b.1]);
        assert!(resources.active[c.1]);

        assert!(resources.activate("missing").is_err());
        assert!(resources.deactivate("missing").is_err());
        assert!(resources.activate_many(&["missing"]).is_err());
        assert!(resources.deactivate_many(&["missing"]).is_err());
        assert!(resources.isolate("missing").is_err());
        assert!(resources.isolate_many(&["missing"]).is_err());
    }

    #[test]
    fn test_resources_parameter_registration() {
        let mut resources = Resources::default();

        let param1 = resources.register_parameter(&ParameterLike::Parameter("param1".to_string()));
        let const1 = resources.register_parameter(&ParameterLike::Constant(1.0));

        match param1 {
            ParameterID::Parameter(idx) => assert_eq!(idx, 0),
            _ => panic!("Expected Parameter variant"),
        }

        match const1 {
            ParameterID::Constant(idx) => assert_eq!(idx, 0),
            _ => panic!("Expected Constant variant"),
        }
    }

    #[test]
    fn test_repeated_parameter_name_shares_index() {
        let mut resources = Resources::default();

        let p1 = resources.register_parameter(&ParameterLike::Parameter("x".to_string()));
        let p2 = resources.register_parameter(&ParameterLike::Parameter("y".to_string()));
        let p3 = resources.register_parameter(&ParameterLike::Parameter("x".to_string()));

        assert!(matches!(p1, ParameterID::Parameter(0)));
        assert!(matches!(p2, ParameterID::Parameter(1)));
        assert!(matches!(p3, ParameterID::Parameter(0)));
        assert_eq!(resources.parameters.len(), 2);
    }

    #[test]
    fn test_cache_scalar_operations() {
        let mut resources = Resources::default();

        let scalar1 = resources.register_scalar(Some("test_scalar"));
        let scalar2 = resources.register_scalar(None);
        let scalar3 = resources.register_scalar(Some("test_scalar"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        cache.store_scalar(scalar1, 1.0);
        cache.store_scalar(scalar2, 2.0);

        assert_eq!(cache.get_scalar(scalar1), 1.0);
        assert_eq!(cache.get_scalar(scalar2), 2.0);
        assert_eq!(cache.get_scalar(scalar3), 1.0);
    }

    #[test]
    fn test_cache_complex_operations() {
        let mut resources = Resources::default();

        let complex1 = resources.register_complex_scalar(Some("test_complex"));
        let complex2 = resources.register_complex_scalar(None);
        let complex3 = resources.register_complex_scalar(Some("test_complex"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Complex::new(1.0, 2.0);
        let value2 = Complex::new(3.0, 4.0);
        cache.store_complex_scalar(complex1, value1);
        cache.store_complex_scalar(complex2, value2);

        assert_eq!(cache.get_complex_scalar(complex1), value1);
        assert_eq!(cache.get_complex_scalar(complex2), value2);
        assert_eq!(cache.get_complex_scalar(complex3), value1);
    }

    #[test]
    fn test_cache_vector_operations() {
        let mut resources = Resources::default();

        let vector_id1: VectorID<2> = resources.register_vector(Some("test_vector"));
        let vector_id2: VectorID<2> = resources.register_vector(None);
        let vector_id3: VectorID<2> = resources.register_vector(Some("test_vector"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Vector2::new(1.0, 2.0);
        let value2 = Vector2::new(3.0, 4.0);
        cache.store_vector(vector_id1, value1);
        cache.store_vector(vector_id2, value2);

        assert_eq!(cache.get_vector(vector_id1), value1);
        assert_eq!(cache.get_vector(vector_id2), value2);
        assert_eq!(cache.get_vector(vector_id3), value1);
    }

    #[test]
    fn test_cache_complex_vector_operations() {
        let mut resources = Resources::default();

        let complex_vector_id1: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));
        let complex_vector_id2: ComplexVectorID<2> = resources.register_complex_vector(None);
        let complex_vector_id3: ComplexVectorID<2> =
            resources.register_complex_vector(Some("test_complex_vector"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Vector2::new(Complex::new(1.0, 2.0), Complex::new(3.0, 4.0));
        let value2 = Vector2::new(Complex::new(5.0, 6.0), Complex::new(7.0, 8.0));
        cache.store_complex_vector(complex_vector_id1, value1);
        cache.store_complex_vector(complex_vector_id2, value2);

        assert_eq!(cache.get_complex_vector(complex_vector_id1), value1);
        assert_eq!(cache.get_complex_vector(complex_vector_id2), value2);
        assert_eq!(cache.get_complex_vector(complex_vector_id3), value1);
    }

    #[test]
    fn test_cache_matrix_operations() {
        let mut resources = Resources::default();

        let matrix_id1: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));
        let matrix_id2: MatrixID<2, 2> = resources.register_matrix(None);
        let matrix_id3: MatrixID<2, 2> = resources.register_matrix(Some("test_matrix"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Matrix2::new(1.0, 2.0, 3.0, 4.0);
        let value2 = Matrix2::new(5.0, 6.0, 7.0, 8.0);
        cache.store_matrix(matrix_id1, value1);
        cache.store_matrix(matrix_id2, value2);

        assert_eq!(cache.get_matrix(matrix_id1), value1);
        assert_eq!(cache.get_matrix(matrix_id2), value2);
        assert_eq!(cache.get_matrix(matrix_id3), value1);
    }

    #[test]
    fn test_cache_complex_matrix_operations() {
        let mut resources = Resources::default();

        let complex_matrix_id1: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));
        let complex_matrix_id2: ComplexMatrixID<2, 2> = resources.register_complex_matrix(None);
        let complex_matrix_id3: ComplexMatrixID<2, 2> =
            resources.register_complex_matrix(Some("test_complex_matrix"));

        resources.reserve_cache(1);
        let cache = &mut resources.caches[0];

        let value1 = Matrix2::new(
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
            Complex::new(7.0, 8.0),
        );
        let value2 = Matrix2::new(
            Complex::new(9.0, 10.0),
            Complex::new(11.0, 12.0),
            Complex::new(13.0, 14.0),
            Complex::new(15.0, 16.0),
        );
        cache.store_complex_matrix(complex_matrix_id1, value1);
        cache.store_complex_matrix(complex_matrix_id2, value2);

        assert_eq!(cache.get_complex_matrix(complex_matrix_id1), value1);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id2), value2);
        assert_eq!(cache.get_complex_matrix(complex_matrix_id3), value1);
    }

    #[test]
    #[should_panic(expected = "Parameter was not initialized!")]
    fn test_uninit_parameter_registration() {
        let mut resources = Resources::default();
        resources.register_parameter(&ParameterLike::Uninit);
    }

    #[test]
    fn test_duplicate_named_amplitude_registration_error() {
        let mut resources = Resources::default();
        assert!(resources.register_amplitude("test_amp").is_ok());
        assert!(resources.register_amplitude("test_amp").is_err());
    }
}