1use std::fmt::{Debug, Display};
2
3use auto_ops::impl_op_ex;
4use dyn_clone::DynClone;
5#[cfg(feature = "mpi")]
6use mpi::{datatype::PartitionMut, topology::SimpleCommunicator, traits::*};
7#[cfg(feature = "rayon")]
8use rayon::prelude::*;
9
10#[cfg(feature = "mpi")]
11use crate::mpi::LadduMPI;
12use crate::{
13 data::{Dataset, DatasetMetadata, EventLike},
14 LadduResult,
15};
16
17#[typetag::serde(tag = "type")]
19pub trait Variable: DynClone + Send + Sync + Debug + Display {
20 fn bind(&mut self, _metadata: &DatasetMetadata) -> LadduResult<()> {
24 Ok(())
25 }
26
27 fn value(&self, event: &dyn EventLike) -> f64;
29
30 fn value_on_local(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
37 let mut variable = dyn_clone::clone_box(self);
38 variable.bind(dataset.metadata())?;
39 #[cfg(feature = "rayon")]
40 let local_values: Vec<f64> = (0..dataset.n_events_local())
41 .into_par_iter()
42 .map(|event_index| {
43 let event = dataset.event_view(event_index);
44 variable.value(&event)
45 })
46 .collect();
47 #[cfg(not(feature = "rayon"))]
48 let local_values: Vec<f64> = (0..dataset.n_events_local())
49 .map(|event_index| {
50 let event = dataset.event_view(event_index);
51 variable.value(&event)
52 })
53 .collect();
54 Ok(local_values)
55 }
56
57 #[cfg(feature = "mpi")]
65 fn value_on_mpi(&self, dataset: &Dataset, world: &SimpleCommunicator) -> LadduResult<Vec<f64>> {
66 let local_weights = self.value_on_local(dataset)?;
67 let n_events = dataset.n_events();
68 let mut buffer: Vec<f64> = vec![0.0; n_events];
69 let (counts, displs) = world.get_counts_displs(n_events);
70 {
71 let mut partitioned_buffer = PartitionMut::new(&mut buffer, counts, displs);
72 world.all_gather_varcount_into(&local_weights, &mut partitioned_buffer);
73 }
74 Ok(buffer)
75 }
76
77 fn value_on(&self, dataset: &Dataset) -> LadduResult<Vec<f64>> {
80 #[cfg(feature = "mpi")]
81 {
82 if let Some(world) = crate::mpi::get_world() {
83 return self.value_on_mpi(dataset, &world);
84 }
85 }
86 self.value_on_local(dataset)
87 }
88
89 fn eq(&self, val: f64) -> VariableExpression
91 where
92 Self: std::marker::Sized + 'static,
93 {
94 VariableExpression::Eq(dyn_clone::clone_box(self), val)
95 }
96
97 fn lt(&self, val: f64) -> VariableExpression
99 where
100 Self: std::marker::Sized + 'static,
101 {
102 VariableExpression::Lt(dyn_clone::clone_box(self), val)
103 }
104
105 fn gt(&self, val: f64) -> VariableExpression
107 where
108 Self: std::marker::Sized + 'static,
109 {
110 VariableExpression::Gt(dyn_clone::clone_box(self), val)
111 }
112
113 fn ge(&self, val: f64) -> VariableExpression
115 where
116 Self: std::marker::Sized + 'static,
117 {
118 self.gt(val).or(&self.eq(val))
119 }
120
121 fn le(&self, val: f64) -> VariableExpression
123 where
124 Self: std::marker::Sized + 'static,
125 {
126 self.lt(val).or(&self.eq(val))
127 }
128}
129dyn_clone::clone_trait_object!(Variable);
130
131#[derive(Clone, Debug)]
133pub enum VariableExpression {
134 Eq(Box<dyn Variable>, f64),
136 Lt(Box<dyn Variable>, f64),
138 Gt(Box<dyn Variable>, f64),
140 And(Box<VariableExpression>, Box<VariableExpression>),
142 Or(Box<VariableExpression>, Box<VariableExpression>),
144 Not(Box<VariableExpression>),
146}
147
148impl VariableExpression {
149 pub fn and(&self, rhs: &VariableExpression) -> VariableExpression {
151 VariableExpression::And(Box::new(self.clone()), Box::new(rhs.clone()))
152 }
153
154 pub fn or(&self, rhs: &VariableExpression) -> VariableExpression {
156 VariableExpression::Or(Box::new(self.clone()), Box::new(rhs.clone()))
157 }
158
159 pub(crate) fn compile(
162 &self,
163 metadata: &DatasetMetadata,
164 ) -> LadduResult<CompiledVariableExpression> {
165 let mut compiled = compile_expression(self.clone());
166 compiled.bind(metadata)?;
167 Ok(compiled)
168 }
169}
170impl Display for VariableExpression {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 match self {
173 VariableExpression::Eq(var, val) => {
174 write!(f, "({} == {})", var, val)
175 }
176 VariableExpression::Lt(var, val) => {
177 write!(f, "({} < {})", var, val)
178 }
179 VariableExpression::Gt(var, val) => {
180 write!(f, "({} > {})", var, val)
181 }
182 VariableExpression::And(lhs, rhs) => {
183 write!(f, "({} & {})", lhs, rhs)
184 }
185 VariableExpression::Or(lhs, rhs) => {
186 write!(f, "({} | {})", lhs, rhs)
187 }
188 VariableExpression::Not(inner) => {
189 write!(f, "!({})", inner)
190 }
191 }
192 }
193}
194
195pub fn not(expr: &VariableExpression) -> VariableExpression {
197 VariableExpression::Not(Box::new(expr.clone()))
198}
199
200#[rustfmt::skip]
201impl_op_ex!(& |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.and(rhs) });
202#[rustfmt::skip]
203impl_op_ex!(| |lhs: &VariableExpression, rhs: &VariableExpression| -> VariableExpression{ lhs.or(rhs) });
204#[rustfmt::skip]
205impl_op_ex!(! |exp: &VariableExpression| -> VariableExpression{ not(exp) });
206
/// One instruction of the stack program run by `CompiledVariableExpression::evaluate`.
///
/// The `usize` in the push variants is an index into the compiled expression's variable list;
/// the `f64` is the constant the variable's value is compared against.
#[derive(Debug)]
enum Opcode {
    /// Push `variable == constant`.
    PushEq(usize, f64),
    /// Push `variable < constant`.
    PushLt(usize, f64),
    /// Push `variable > constant`.
    PushGt(usize, f64),
    /// Pop two booleans and push their conjunction.
    And,
    /// Pop two booleans and push their disjunction.
    Or,
    /// Pop one boolean and push its negation.
    Not,
}
216
217pub(crate) struct CompiledVariableExpression {
218 bytecode: Vec<Opcode>,
219 variables: Vec<Box<dyn Variable>>,
220}
221
222impl CompiledVariableExpression {
223 pub fn bind(&mut self, metadata: &DatasetMetadata) -> LadduResult<()> {
224 for variable in &mut self.variables {
225 variable.bind(metadata)?;
226 }
227 Ok(())
228 }
229
230 pub fn evaluate(&self, event: &dyn EventLike) -> bool {
232 let mut stack = Vec::with_capacity(self.bytecode.len());
233
234 for op in &self.bytecode {
235 match op {
236 Opcode::PushEq(i, val) => stack.push(self.variables[*i].value(event) == *val),
237 Opcode::PushLt(i, val) => stack.push(self.variables[*i].value(event) < *val),
238 Opcode::PushGt(i, val) => stack.push(self.variables[*i].value(event) > *val),
239 Opcode::Not => {
240 let a = stack.pop().unwrap();
241 stack.push(!a);
242 }
243 Opcode::And => {
244 let b = stack.pop().unwrap();
245 let a = stack.pop().unwrap();
246 stack.push(a && b);
247 }
248 Opcode::Or => {
249 let b = stack.pop().unwrap();
250 let a = stack.pop().unwrap();
251 stack.push(a || b);
252 }
253 }
254 }
255
256 stack.pop().unwrap()
257 }
258}
259
260pub(crate) fn compile_expression(expr: VariableExpression) -> CompiledVariableExpression {
261 let mut bytecode = Vec::new();
262 let mut variables: Vec<Box<dyn Variable>> = Vec::new();
263
264 fn compile(
265 expr: VariableExpression,
266 bytecode: &mut Vec<Opcode>,
267 variables: &mut Vec<Box<dyn Variable>>,
268 ) {
269 match expr {
270 VariableExpression::Eq(var, val) => {
271 variables.push(var);
272 bytecode.push(Opcode::PushEq(variables.len() - 1, val));
273 }
274 VariableExpression::Lt(var, val) => {
275 variables.push(var);
276 bytecode.push(Opcode::PushLt(variables.len() - 1, val));
277 }
278 VariableExpression::Gt(var, val) => {
279 variables.push(var);
280 bytecode.push(Opcode::PushGt(variables.len() - 1, val));
281 }
282 VariableExpression::And(lhs, rhs) => {
283 compile(*lhs, bytecode, variables);
284 compile(*rhs, bytecode, variables);
285 bytecode.push(Opcode::And);
286 }
287 VariableExpression::Or(lhs, rhs) => {
288 compile(*lhs, bytecode, variables);
289 compile(*rhs, bytecode, variables);
290 bytecode.push(Opcode::Or);
291 }
292 VariableExpression::Not(inner) => {
293 compile(*inner, bytecode, variables);
294 bytecode.push(Opcode::Not);
295 }
296 }
297 }
298
299 compile(expr, &mut bytecode, &mut variables);
300
301 CompiledVariableExpression {
302 bytecode,
303 variables,
304 }
305}