laddu_amplitudes/scalar/
variable.rs1use laddu_core::{
2 amplitude::{
3 display_key, Amplitude, AmplitudeID, AmplitudeSemanticKey, Expression,
4 ExpressionDependence, IntoTags, Tags,
5 },
6 data::{DatasetMetadata, Event},
7 resources::{Cache, Parameters, Resources},
8 traits::Variable,
9 LadduResult, ScalarID,
10};
11use nalgebra::DVector;
12use num::complex::Complex64;
13use serde::{Deserialize, Serialize};
14
15#[derive(Clone, Serialize, Deserialize)]
17pub struct VariableScalar {
18 tags: Tags,
19 variable: Box<dyn Variable>,
20 value_id: ScalarID,
21}
22
23impl VariableScalar {
24 pub fn new<V: Variable + 'static>(
26 tags: impl IntoTags,
27 variable: &V,
28 ) -> LadduResult<Expression> {
29 Self {
30 tags: tags.into_tags(),
31 variable: dyn_clone::clone_box(variable),
32 value_id: ScalarID::default(),
33 }
34 .into_expression()
35 }
36}
37
38pub trait VariableExpressionExt: Variable + 'static {
40 fn as_expression(&self, tags: impl IntoTags) -> LadduResult<Expression>
42 where
43 Self: Sized,
44 {
45 VariableScalar::new(tags, self)
46 }
47}
48
49impl<T: Variable + 'static> VariableExpressionExt for T {}
50
51#[typetag::serde]
52impl Amplitude for VariableScalar {
53 fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
54 self.value_id = resources.register_scalar(None);
55 resources.register_amplitude(self.tags.clone())
56 }
57
58 fn semantic_key(&self) -> Option<AmplitudeSemanticKey> {
59 Some(
60 AmplitudeSemanticKey::new("VariableScalar")
61 .with_field("variable", display_key(&self.variable)),
62 )
63 }
64
65 fn dependence_hint(&self) -> ExpressionDependence {
66 ExpressionDependence::CacheOnly
67 }
68
69 fn real_valued_hint(&self) -> bool {
70 true
71 }
72
73 fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
74 self.variable.bind(metadata)
75 }
76
77 fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
78 cache.store_scalar(self.value_id, self.variable.value(event));
79 }
80
81 fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
82 cache.get_scalar(self.value_id).into()
83 }
84
85 fn compute_gradient(
86 &self,
87 _parameters: &Parameters,
88 _cache: &Cache,
89 _gradient: &mut DVector<Complex64>,
90 ) {
91 }
92}