Skip to main content

laddu_amplitudes/scalar/variable.rs

1use laddu_core::{
2    amplitude::{
3        display_key, Amplitude, AmplitudeID, AmplitudeSemanticKey, Expression,
4        ExpressionDependence, IntoTags, Tags,
5    },
6    data::{DatasetMetadata, Event},
7    resources::{Cache, Parameters, Resources},
8    traits::Variable,
9    LadduResult, ScalarID,
10};
11use nalgebra::DVector;
12use num::complex::Complex64;
13use serde::{Deserialize, Serialize};
14
15/// A real-valued [`Amplitude`] which evaluates an event [`Variable`].
16#[derive(Clone, Serialize, Deserialize)]
17pub struct VariableScalar {
18    tags: Tags,
19    variable: Box<dyn Variable>,
20    value_id: ScalarID,
21}
22
23impl VariableScalar {
24    /// Create a new [`VariableScalar`] that evaluates `variable` on each event.
25    pub fn new<V: Variable + 'static>(
26        tags: impl IntoTags,
27        variable: &V,
28    ) -> LadduResult<Expression> {
29        Self {
30            tags: tags.into_tags(),
31            variable: dyn_clone::clone_box(variable),
32            value_id: ScalarID::default(),
33        }
34        .into_expression()
35    }
36}
37
38/// Extension methods for building expressions from event [`Variable`]s.
39pub trait VariableExpressionExt: Variable + 'static {
40    /// Convert this variable into a real-valued [`Expression`].
41    fn as_expression(&self, tags: impl IntoTags) -> LadduResult<Expression>
42    where
43        Self: Sized,
44    {
45        VariableScalar::new(tags, self)
46    }
47}
48
49impl<T: Variable + 'static> VariableExpressionExt for T {}
50
51#[typetag::serde]
52impl Amplitude for VariableScalar {
53    fn register(&mut self, resources: &mut Resources) -> LadduResult<AmplitudeID> {
54        self.value_id = resources.register_scalar(None);
55        resources.register_amplitude(self.tags.clone())
56    }
57
58    fn semantic_key(&self) -> Option<AmplitudeSemanticKey> {
59        Some(
60            AmplitudeSemanticKey::new("VariableScalar")
61                .with_field("variable", display_key(&self.variable)),
62        )
63    }
64
65    fn dependence_hint(&self) -> ExpressionDependence {
66        ExpressionDependence::CacheOnly
67    }
68
69    fn real_valued_hint(&self) -> bool {
70        true
71    }
72
73    fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
74        self.variable.bind(metadata)
75    }
76
77    fn precompute(&self, event: &Event<'_>, cache: &mut Cache) {
78        cache.store_scalar(self.value_id, self.variable.value(event));
79    }
80
81    fn compute(&self, _parameters: &Parameters, cache: &Cache) -> Complex64 {
82        cache.get_scalar(self.value_id).into()
83    }
84
85    fn compute_gradient(
86        &self,
87        _parameters: &Parameters,
88        _cache: &Cache,
89        _gradient: &mut DVector<Complex64>,
90    ) {
91    }
92}