Skip to main content

flow_gate_core/gate/
registry.rs

1use std::collections::{HashMap, HashSet};
2
3use indexmap::IndexMap;
4use rayon::prelude::*;
5use smallvec::SmallVec;
6
7use crate::error::FlowGateError;
8use crate::event::{EventMatrix, EventMatrixView};
9use crate::gate::{
10    BooleanGate, BooleanOp, EllipsoidDimension, EllipsoidGate, PolygonDimension, PolygonGate,
11    RectangleDimension, RectangleGate,
12};
13use crate::traits::{ApplyGate, BitVec, Gate, GateId, ParameterName, Transform};
14use crate::transform::TransformKind;
15
// All four variants are required by the Gating-ML 2.0 spec.
// The enum is used behind references in all hot paths, so size is not a concern.
// NOTE(review): `Rectangle` is already boxed, so this allow may be stale —
// re-check with `cargo clippy` before removing it.
/// A gate of any supported kind, as stored in a [`GateRegistry`].
///
/// `Rectangle`, `Polygon` and `Ellipsoid` are spatial gates: the registry
/// projects their dimensions out of the event matrix and tests each event's
/// coordinates with [`Gate::contains`]. `Boolean` gates are instead computed
/// from the already-evaluated results of the gates they reference.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
pub enum GateKind {
    /// Rectangle gate; stored boxed, unlike the other variants.
    Rectangle(Box<RectangleGate>),
    /// Polygon gate over exactly two dimensions (`x_dim` / `y_dim`).
    Polygon(PolygonGate),
    /// Ellipsoid gate.
    Ellipsoid(EllipsoidGate),
    /// Boolean combination (`And` / `Or` / `Not`) of other gates' results.
    Boolean(BooleanGate),
}
26
27impl GateKind {
28    fn dependency_ids(&self) -> Vec<GateId> {
29        let mut deps = Vec::new();
30        if let Some(parent) = self.parent_id() {
31            deps.push(parent.clone());
32        }
33        if let Self::Boolean(g) = self {
34            for op in g.operands() {
35                deps.push(op.gate_id.clone());
36            }
37        }
38        deps
39    }
40
41    fn transforms(&self) -> SmallVec<[Option<TransformKind>; 8]> {
42        match self {
43            Self::Rectangle(g) => g
44                .rectangle_dimensions()
45                .iter()
46                .map(|d: &RectangleDimension| d.transform)
47                .collect(),
48            Self::Polygon(g) => {
49                let dims: [&PolygonDimension; 2] = [g.x_dim(), g.y_dim()];
50                dims.iter().map(|d| d.transform).collect()
51            }
52            Self::Ellipsoid(g) => g
53                .dimensions_def()
54                .iter()
55                .map(|d: &EllipsoidDimension| d.transform)
56                .collect(),
57            Self::Boolean(_) => SmallVec::new(),
58        }
59    }
60}
61
62impl Gate for GateKind {
63    fn dimensions(&self) -> &[ParameterName] {
64        match self {
65            Self::Rectangle(g) => g.dimensions(),
66            Self::Polygon(g) => g.dimensions(),
67            Self::Ellipsoid(g) => g.dimensions(),
68            Self::Boolean(g) => g.dimensions(),
69        }
70    }
71
72    fn contains(&self, coords: &[f64]) -> bool {
73        match self {
74            Self::Rectangle(g) => g.contains(coords),
75            Self::Polygon(g) => g.contains(coords),
76            Self::Ellipsoid(g) => g.contains(coords),
77            Self::Boolean(g) => g.contains(coords),
78        }
79    }
80
81    fn gate_id(&self) -> &GateId {
82        match self {
83            Self::Rectangle(g) => g.gate_id(),
84            Self::Polygon(g) => g.gate_id(),
85            Self::Ellipsoid(g) => g.gate_id(),
86            Self::Boolean(g) => g.gate_id(),
87        }
88    }
89
90    fn parent_id(&self) -> Option<&GateId> {
91        match self {
92            Self::Rectangle(g) => g.parent_id(),
93            Self::Polygon(g) => g.parent_id(),
94            Self::Ellipsoid(g) => g.parent_id(),
95            Self::Boolean(g) => g.parent_id(),
96        }
97    }
98}
99
100impl ApplyGate for GateKind {
101    fn classify(
102        &self,
103        matrix: &EventMatrix,
104        gate_map: &GateRegistry,
105    ) -> Result<BitVec, FlowGateError> {
106        let results = gate_map.classify_all(matrix)?;
107        Ok(results
108            .get(self.gate_id())
109            .cloned()
110            .unwrap_or_else(|| BitVec::repeat(false, matrix.n_events)))
111    }
112}
113
/// A collection of gates together with a cached evaluation order.
///
/// The order is (re)computed by [`GateRegistry::new`] and
/// [`GateRegistry::insert`]. The derived `Default` is an empty registry
/// (no gates, empty order).
#[derive(Debug, Default, Clone)]
pub struct GateRegistry {
    // Gates keyed by ID. `IndexMap` keeps insertion order, which the
    // topological sort uses as its visiting order.
    gates: IndexMap<GateId, GateKind>,
    // Gate IDs in dependency order (parents and Boolean operands before the
    // gates that reference them); walked by `classify_all` / `classify_all_view`.
    topo_order: Vec<GateId>,
}
119
120impl GateRegistry {
121    pub fn new(gates: IndexMap<GateId, GateKind>) -> Result<Self, FlowGateError> {
122        let topo_order = compute_topological_order(&gates)?;
123        Ok(Self { gates, topo_order })
124    }
125
126    pub fn iter(&self) -> impl Iterator<Item = (&GateId, &GateKind)> {
127        self.gates.iter()
128    }
129
130    pub fn get(&self, gate_id: &GateId) -> Option<&GateKind> {
131        self.gates.get(gate_id)
132    }
133
134    pub fn insert(&mut self, gate_id: GateId, gate: GateKind) -> Result<(), FlowGateError> {
135        self.gates.insert(gate_id, gate);
136        self.topo_order = compute_topological_order(&self.gates)?;
137        Ok(())
138    }
139
140    pub fn topological_order(&self) -> &[GateId] {
141        &self.topo_order
142    }
143
144    pub fn classify_all(
145        &self,
146        matrix: &EventMatrix,
147    ) -> Result<HashMap<GateId, BitVec>, FlowGateError> {
148        let n_events = matrix.n_events;
149        let mut results: HashMap<GateId, BitVec> = HashMap::with_capacity(self.gates.len());
150
151        for gate_id in &self.topo_order {
152            let gate = self.gates.get(gate_id).ok_or_else(|| {
153                FlowGateError::UnknownGateReference(gate_id.clone(), gate_id.clone())
154            })?;
155
156            let membership = match gate {
157                GateKind::Boolean(boolean_gate) => {
158                    evaluate_boolean_gate(boolean_gate, &results, n_events)?
159                }
160                _ => classify_spatial_gate(gate, matrix)?,
161            };
162
163            let membership = if let Some(parent_id) = gate.parent_id() {
164                let parent_bits = results
165                    .get(parent_id)
166                    .ok_or_else(|| FlowGateError::MissingParentGate(parent_id.clone()))?;
167                let mut child = membership;
168                for idx in 0..n_events {
169                    let keep = child[idx] & parent_bits[idx];
170                    child.set(idx, keep);
171                }
172                child
173            } else {
174                membership
175            };
176
177            results.insert(gate_id.clone(), membership);
178        }
179
180        Ok(results)
181    }
182
183    pub fn classify_all_view(
184        &self,
185        matrix: &EventMatrixView<'_>,
186    ) -> Result<HashMap<GateId, BitVec>, FlowGateError> {
187        let n_events = matrix.n_events;
188        let mut results: HashMap<GateId, BitVec> = HashMap::with_capacity(self.gates.len());
189
190        for gate_id in &self.topo_order {
191            let gate = self.gates.get(gate_id).ok_or_else(|| {
192                FlowGateError::UnknownGateReference(gate_id.clone(), gate_id.clone())
193            })?;
194
195            let membership = match gate {
196                GateKind::Boolean(boolean_gate) => {
197                    evaluate_boolean_gate(boolean_gate, &results, n_events)?
198                }
199                _ => classify_spatial_gate_view(gate, matrix)?,
200            };
201
202            let membership = if let Some(parent_id) = gate.parent_id() {
203                let parent_bits = results
204                    .get(parent_id)
205                    .ok_or_else(|| FlowGateError::MissingParentGate(parent_id.clone()))?;
206                let mut child = membership;
207                for idx in 0..n_events {
208                    let keep = child[idx] & parent_bits[idx];
209                    child.set(idx, keep);
210                }
211                child
212            } else {
213                membership
214            };
215
216            results.insert(gate_id.clone(), membership);
217        }
218
219        Ok(results)
220    }
221}
222
223fn classify_spatial_gate(gate: &GateKind, matrix: &EventMatrix) -> Result<BitVec, FlowGateError> {
224    let n_events = matrix.n_events;
225    let projected = matrix.project(gate.dimensions())?;
226    let columns = projected.columns();
227    let transforms = gate.transforms();
228
229    let bools: Vec<bool> = (0..n_events)
230        .into_par_iter()
231        .map(|event_idx| {
232            let mut coords = SmallVec::<[f64; 8]>::with_capacity(columns.len());
233            for (dim_idx, col) in columns.iter().enumerate() {
234                let raw = col[event_idx];
235                let value = transforms
236                    .get(dim_idx)
237                    .copied()
238                    .flatten()
239                    .map_or(raw, |t| t.apply(raw));
240                coords.push(value);
241            }
242            gate.contains(&coords)
243        })
244        .collect();
245
246    Ok(bools.into_iter().collect())
247}
248
249fn classify_spatial_gate_view(
250    gate: &GateKind,
251    matrix: &EventMatrixView<'_>,
252) -> Result<BitVec, FlowGateError> {
253    let n_events = matrix.n_events;
254    let dim_indices = matrix.project_indices(gate.dimensions())?;
255    let transforms = gate.transforms();
256
257    let bools: Vec<bool> = (0..n_events)
258        .into_par_iter()
259        .map(|event_idx| {
260            let mut coords = SmallVec::<[f64; 8]>::with_capacity(dim_indices.len());
261            for (dim_idx, &param_idx) in dim_indices.iter().enumerate() {
262                let raw = matrix.value_at(event_idx, param_idx).unwrap_or(f64::NAN);
263                let value = transforms
264                    .get(dim_idx)
265                    .copied()
266                    .flatten()
267                    .map_or(raw, |t| t.apply(raw));
268                coords.push(value);
269            }
270            gate.contains(&coords)
271        })
272        .collect();
273
274    Ok(bools.into_iter().collect())
275}
276
277fn evaluate_boolean_gate(
278    gate: &BooleanGate,
279    results: &HashMap<GateId, BitVec>,
280    n_events: usize,
281) -> Result<BitVec, FlowGateError> {
282    let mut operand_bits = Vec::with_capacity(gate.operands().len());
283    for operand in gate.operands() {
284        let source = results.get(&operand.gate_id).ok_or_else(|| {
285            FlowGateError::UnknownGateReference(gate.gate_id().clone(), operand.gate_id.clone())
286        })?;
287        let mut bits = source.clone();
288        if operand.complement {
289            for idx in 0..n_events {
290                let prev = bits[idx];
291                bits.set(idx, !prev);
292            }
293        }
294        operand_bits.push(bits);
295    }
296
297    let out = match gate.op() {
298        BooleanOp::And => {
299            let mut acc = BitVec::repeat(true, n_events);
300            for bits in &operand_bits {
301                for idx in 0..n_events {
302                    let prev = acc[idx];
303                    acc.set(idx, prev & bits[idx]);
304                }
305            }
306            acc
307        }
308        BooleanOp::Or => {
309            let mut acc = BitVec::repeat(false, n_events);
310            for bits in &operand_bits {
311                for idx in 0..n_events {
312                    let prev = acc[idx];
313                    acc.set(idx, prev | bits[idx]);
314                }
315            }
316            acc
317        }
318        BooleanOp::Not => {
319            if operand_bits.len() != 1 {
320                return Err(FlowGateError::BooleanNotArity(
321                    gate.gate_id().clone(),
322                    operand_bits.len(),
323                ));
324            }
325            let mut acc = operand_bits.remove(0);
326            for idx in 0..n_events {
327                let prev = acc[idx];
328                acc.set(idx, !prev);
329            }
330            acc
331        }
332    };
333    Ok(out)
334}
335
336fn compute_topological_order(
337    gates: &IndexMap<GateId, GateKind>,
338) -> Result<Vec<GateId>, FlowGateError> {
339    #[derive(Clone, Copy, PartialEq, Eq)]
340    enum Mark {
341        Temp,
342        Perm,
343    }
344
345    fn visit(
346        node: &GateId,
347        gates: &IndexMap<GateId, GateKind>,
348        marks: &mut HashMap<GateId, Mark>,
349        order: &mut Vec<GateId>,
350        stack: &mut HashSet<GateId>,
351    ) -> Result<(), FlowGateError> {
352        if marks.get(node) == Some(&Mark::Perm) {
353            return Ok(());
354        }
355        if marks.get(node) == Some(&Mark::Temp) || stack.contains(node) {
356            return Err(FlowGateError::CyclicGateReference(node.clone()));
357        }
358
359        marks.insert(node.clone(), Mark::Temp);
360        stack.insert(node.clone());
361
362        let gate = gates
363            .get(node)
364            .ok_or_else(|| FlowGateError::UnknownGateReference(node.clone(), node.clone()))?;
365        for dep in gate.dependency_ids() {
366            if !gates.contains_key(&dep) {
367                return Err(FlowGateError::UnknownGateReference(node.clone(), dep));
368            }
369            visit(&dep, gates, marks, order, stack)?;
370        }
371
372        marks.insert(node.clone(), Mark::Perm);
373        stack.remove(node);
374        order.push(node.clone());
375        Ok(())
376    }
377
378    let mut marks: HashMap<GateId, Mark> = HashMap::new();
379    let mut stack: HashSet<GateId> = HashSet::new();
380    let mut order = Vec::with_capacity(gates.len());
381    for gate_id in gates.keys() {
382        visit(gate_id, gates, &mut marks, &mut order, &mut stack)?;
383    }
384    Ok(order)
385}