Skip to main content

laddu_core/variables/
variable.rs

1use std::fmt::{Debug, Display};
2
3use auto_ops::impl_op_ex;
4use dyn_clone::DynClone;
5#[cfg(feature = "mpi")]
6use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
7#[cfg(feature = "rayon")]
8use rayon::prelude::*;
9
10#[cfg(feature = "mpi")]
11use crate::mpi::LadduMPI;
12use crate::{
13    data::{Dataset, DatasetMetadata, EventLike},
14    LadduResult,
15};
16
17/// Standard methods for extracting some value from an event view.
18#[typetag::serde(tag = "type")]
19pub trait Variable: DynClone + Send + Sync + Debug + Display {
20    /// Bind the variable to dataset metadata so that any referenced names can be resolved to
21    /// concrete indices. Implementations that do not require metadata may keep the default
22    /// no-op.
23    fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
24        Ok(())
25    }
26
27    /// This method extracts a single value (like a mass) from an event access view.
28    fn value(&self, event: &dyn EventLike) -> f64;
29
30    /// This method distributes [`Variable::value`] over each event in a [`Dataset`] (non-MPI version).
31    ///
32    /// # Notes
33    ///
34    /// This method is not intended to be called in analyses but rather in writing methods
35    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
36    fn value_on_local(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
37        let mut variable = dyn_clone::clone_box(self);
38        variable.bind(dataset.metadata())?;
39        #[cfg(feature = "rayon")]
40        let local_values: Vec<f64> = (0..dataset.n_events_local())
41            .into_par_iter()
42            .map(|event_index| {
43                let event = dataset.event_view(event_index);
44                variable.value(&event)
45            })
46            .collect();
47        #[cfg(not(feature = "rayon"))]
48        let local_values: Vec<f64> = (0..dataset.n_events_local())
49            .map(|event_index| {
50                let event = dataset.event_view(event_index);
51                variable.value(&event)
52            })
53            .collect();
54        Ok(local_values)
55    }
56
57    /// This method distributes the [`Variable::value`] method over each [`Event`](crate::data::Event) in a
58    /// [`Dataset`] (MPI-compatible version).
59    ///
60    /// # Notes
61    ///
62    /// This method is not intended to be called in analyses but rather in writing methods
63    /// that have `mpi`-feature-gated versions. Most users should just call [`Variable::value_on`] instead.
64    #[cfg(feature = "mpi")]
65    fn value_on_mpi(&self, dataset: &Dataset, world: &SimpleCommunicator) -> LadduResult<Vec<f64>> {
66        let local_weights = self.value_on_local(dataset)?;
67        let n_events = dataset.n_events();
68        let mut buffer: Vec<f64> = vec![0.0; n_events];
69        let (counts, displs) = world.get_counts_displs(n_events);
70        {
71            let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
72            world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
73        }
74        Ok(buffer)
75    }
76
77    /// This method distributes the [`Variable::value`] method over each [`Event`](crate::data::Event) in a
78    /// [`Dataset`].
79    fn value_on(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
80        #[cfg(feature = "mpi")]
81        {
82            if let Some(world) = crate::mpi::get_world() {
83                return self.value_on_mpi(dataset, &world);
84            }
85        }
86        self.value_on_local(dataset)
87    }
88
89    /// Create an [`VariableExpression`] that evaluates to `self == val`
90    fn eq(&self, val: f64) -> VariableExpression
91    where
92        Self: std::marker::Sized + 'static,
93    {
94        VariableExpression::Eq(dyn_clone::clone_box(self), val)
95    }
96
97    /// Create an [`VariableExpression`] that evaluates to `self < val`
98    fn lt(&self, val: f64) -> VariableExpression
99    where
100        Self: std::marker::Sized + 'static,
101    {
102        VariableExpression::Lt(dyn_clone::clone_box(self), val)
103    }
104
105    /// Create an [`VariableExpression`] that evaluates to `self > val`
106    fn gt(&self, val: f64) -> VariableExpression
107    where
108        Self: std::marker::Sized + 'static,
109    {
110        VariableExpression::Gt(dyn_clone::clone_box(self), val)
111    }
112
113    /// Create an [`VariableExpression`] that evaluates to `self >= val`
114    fn ge(&self, val: f64) -> VariableExpression
115    where
116        Self: std::marker::Sized + 'static,
117    {
118        self.gt(val).or(&self.eq(val))
119    }
120
121    /// Create an [`VariableExpression`] that evaluates to `self <= val`
122    fn le(&self, val: f64) -> VariableExpression
123    where
124        Self: std::marker::Sized + 'static,
125    {
126        self.lt(val).or(&self.eq(val))
127    }
128}
129dyn_clone::clone_trait_object!(Variable);
130
131/// Expressions which can be used to compare [`Variable`]s to [`f64`]s.
132#[derive(Clone, Debug)]
133pub enum VariableExpression {
134    /// Expression which is true when the variable is equal to the float.
135    Eq(Box<dyn Variable>, f64),
136    /// Expression which is true when the variable is less than the float.
137    Lt(Box<dyn Variable>, f64),
138    /// Expression which is true when the variable is greater than the float.
139    Gt(Box<dyn Variable>, f64),
140    /// Expression which is true when both inner expressions are true.
141    And(Box<VariableExpression>, Box<VariableExpression>),
142    /// Expression which is true when either inner expression is true.
143    Or(Box<VariableExpression>, Box<VariableExpression>),
144    /// Expression which is true when the inner expression is false.
145    Not(Box<VariableExpression>),
146}
147
148impl VariableExpression {
149    /// Construct an [`VariableExpression::And`] from the current expression and another.
150    pub fn and(&self, rhs: &VariableExpression) -> VariableExpression {
151        VariableExpression::And(Box::new(self.clone()), Box::new(rhs.clone()))
152    }
153
154    /// Construct an [`VariableExpression::Or`] from the current expression and another.
155    pub fn or(&self, rhs: &VariableExpression) -> VariableExpression {
156        VariableExpression::Or(Box::new(self.clone()), Box::new(rhs.clone()))
157    }
158
159    /// Comple the [`VariableExpression`] into a [`CompiledExpression`] bound to the supplied
160    /// metadata so that all variable references are resolved.
161    pub(crate) fn compile(
162        &self,
163        metadata: &DatasetMetadata,
164    ) -> LadduResult<CompiledVariableExpression> {
165        let mut compiled = compile_expression(self.clone());
166        compiled.bind(metadata)?;
167        Ok(compiled)
168    }
169}
170impl Display for VariableExpression {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        match self {
173            VariableExpression::Eq(var, val) => {
174                write!(f, "({} == {})", var, val)
175            }
176            VariableExpression::Lt(var, val) => {
177                write!(f, "({} < {})", var, val)
178            }
179            VariableExpression::Gt(var, val) => {
180                write!(f, "({} > {})", var, val)
181            }
182            VariableExpression::And(lhs, rhs) => {
183                write!(f, "({} & {})", lhs, rhs)
184            }
185            VariableExpression::Or(lhs, rhs) => {
186                write!(f, "({} | {})", lhs, rhs)
187            }
188            VariableExpression::Not(inner) => {
189                write!(f, "!({})", inner)
190            }
191        }
192    }
193}
194
195/// A method which negates the given expression.
196pub fn not(expr: &VariableExpression) -> VariableExpression {
197    VariableExpression::Not(Box::new(expr.clone()))
198}
199
// Operator sugar for `VariableExpression`: `&` forwards to `and`, `|` forwards to `or`, and
// unary `!` forwards to the free function `not`.
// NOTE(review): `impl_op_ex!` is expected to generate the impls for owned as well as borrowed
// operands — confirm against the `auto_ops` documentation.
#[rustfmt::skip]
impl_op_ex!(& |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.and(rhs) });
#[rustfmt::skip]
impl_op_ex!(| |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.or(rhs) });
#[rustfmt::skip]
impl_op_ex!(! |exp: &VariableExpression| -> VariableExpression{ not(exp) });
206
/// A single instruction of the stack program that a [`VariableExpression`] is compiled into.
///
/// The `usize` carried by the `Push*` variants is an index into the `variables` list of the
/// owning [`CompiledVariableExpression`]; the `f64` is the constant the variable's value is
/// compared against.
#[derive(Debug)]
enum Opcode {
    /// Push the result of `variable == constant` onto the stack.
    PushEq(usize, f64),
    /// Push the result of `variable < constant` onto the stack.
    PushLt(usize, f64),
    /// Push the result of `variable > constant` onto the stack.
    PushGt(usize, f64),
    /// Pop two booleans and push their conjunction.
    And,
    /// Pop two booleans and push their disjunction.
    Or,
    /// Pop one boolean and push its negation.
    Not,
}
216
/// A [`VariableExpression`] flattened into a linear stack program.
///
/// Instances are produced by [`compile_expression`].
pub(crate) struct CompiledVariableExpression {
    /// Instructions, stored in the order in which they are executed.
    bytecode: Vec<Opcode>,
    /// Variables that the `Push*` opcodes refer to by index.
    variables: Vec<Box<dyn Variable>>,
}
221
222impl CompiledVariableExpression {
223    pub fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
224        for variable in &mut self.variables {
225            variable.bind(metadata)?;
226        }
227        Ok(())
228    }
229
230    /// Evaluate the [`CompiledExpression`] on a given named event view.
231    pub fn evaluate(&self, event: &dyn EventLike) -> bool {
232        let mut stack = Vec::with_capacity(self.bytecode.len());
233
234        for op in &self.bytecode {
235            match op {
236                Opcode::PushEq(i, val) => stack.push(self.variables[*i].value(event) == *val),
237                Opcode::PushLt(i, val) => stack.push(self.variables[*i].value(event) < *val),
238                Opcode::PushGt(i, val) => stack.push(self.variables[*i].value(event) > *val),
239                Opcode::Not => {
240                    let a = stack.pop().unwrap();
241                    stack.push(!a);
242                }
243                Opcode::And => {
244                    let b = stack.pop().unwrap();
245                    let a = stack.pop().unwrap();
246                    stack.push(a && b);
247                }
248                Opcode::Or => {
249                    let b = stack.pop().unwrap();
250                    let a = stack.pop().unwrap();
251                    stack.push(a || b);
252                }
253            }
254        }
255
256        stack.pop().unwrap()
257    }
258}
259
260pub(crate) fn compile_expression(expr: VariableExpression) -> CompiledVariableExpression {
261    let mut bytecode = Vec::new();
262    let mut variables: Vec<Box<dyn Variable>> = Vec::new();
263
264    fn compile(
265        expr: VariableExpression,
266        bytecode: &mut Vec<Opcode>,
267        variables: &mut Vec<Box<dyn Variable>>,
268    ) {
269        match expr {
270            VariableExpression::Eq(var, val) => {
271                variables.push(var);
272                bytecode.push(Opcode::PushEq(variables.len() - 1, val));
273            }
274            VariableExpression::Lt(var, val) => {
275                variables.push(var);
276                bytecode.push(Opcode::PushLt(variables.len() - 1, val));
277            }
278            VariableExpression::Gt(var, val) => {
279                variables.push(var);
280                bytecode.push(Opcode::PushGt(variables.len() - 1, val));
281            }
282            VariableExpression::And(lhs, rhs) => {
283                compile(*lhs, bytecode, variables);
284                compile(*rhs, bytecode, variables);
285                bytecode.push(Opcode::And);
286            }
287            VariableExpression::Or(lhs, rhs) => {
288                compile(*lhs, bytecode, variables);
289                compile(*rhs, bytecode, variables);
290                bytecode.push(Opcode::Or);
291            }
292            VariableExpression::Not(inner) => {
293                compile(*inner, bytecode, variables);
294                bytecode.push(Opcode::Not);
295            }
296        }
297    }
298
299    compile(expr, &mut bytecode, &mut variables);
300
301    CompiledVariableExpression {
302        bytecode,
303        variables,
304    }
305}