1use std::collections::{HashMap, HashSet};
2
3use indexmap::IndexMap;
4use rayon::prelude::*;
5use smallvec::SmallVec;
6
7use crate::error::FlowGateError;
8use crate::event::{EventMatrix, EventMatrixView};
9use crate::gate::{
10 BooleanGate, BooleanOp, EllipsoidDimension, EllipsoidGate, PolygonDimension, PolygonGate,
11 RectangleDimension, RectangleGate,
12};
13use crate::traits::{ApplyGate, BitVec, Gate, GateId, ParameterName, Transform};
14use crate::transform::TransformKind;
15
16#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum GateKind {
21 Rectangle(Box<RectangleGate>),
22 Polygon(PolygonGate),
23 Ellipsoid(EllipsoidGate),
24 Boolean(BooleanGate),
25}
26
27impl GateKind {
28 fn dependency_ids(&self) -> Vec<GateId> {
29 let mut deps = Vec::new();
30 if let Some(parent) = self.parent_id() {
31 deps.push(parent.clone());
32 }
33 if let Self::Boolean(g) = self {
34 for op in g.operands() {
35 deps.push(op.gate_id.clone());
36 }
37 }
38 deps
39 }
40
41 fn transforms(&self) -> SmallVec<[Option<TransformKind>; 8]> {
42 match self {
43 Self::Rectangle(g) => g
44 .rectangle_dimensions()
45 .iter()
46 .map(|d: &RectangleDimension| d.transform)
47 .collect(),
48 Self::Polygon(g) => {
49 let dims: [&PolygonDimension; 2] = [g.x_dim(), g.y_dim()];
50 dims.iter().map(|d| d.transform).collect()
51 }
52 Self::Ellipsoid(g) => g
53 .dimensions_def()
54 .iter()
55 .map(|d: &EllipsoidDimension| d.transform)
56 .collect(),
57 Self::Boolean(_) => SmallVec::new(),
58 }
59 }
60}
61
62impl Gate for GateKind {
63 fn dimensions(&self) -> &[ParameterName] {
64 match self {
65 Self::Rectangle(g) => g.dimensions(),
66 Self::Polygon(g) => g.dimensions(),
67 Self::Ellipsoid(g) => g.dimensions(),
68 Self::Boolean(g) => g.dimensions(),
69 }
70 }
71
72 fn contains(&self, coords: &[f64]) -> bool {
73 match self {
74 Self::Rectangle(g) => g.contains(coords),
75 Self::Polygon(g) => g.contains(coords),
76 Self::Ellipsoid(g) => g.contains(coords),
77 Self::Boolean(g) => g.contains(coords),
78 }
79 }
80
81 fn gate_id(&self) -> &GateId {
82 match self {
83 Self::Rectangle(g) => g.gate_id(),
84 Self::Polygon(g) => g.gate_id(),
85 Self::Ellipsoid(g) => g.gate_id(),
86 Self::Boolean(g) => g.gate_id(),
87 }
88 }
89
90 fn parent_id(&self) -> Option<&GateId> {
91 match self {
92 Self::Rectangle(g) => g.parent_id(),
93 Self::Polygon(g) => g.parent_id(),
94 Self::Ellipsoid(g) => g.parent_id(),
95 Self::Boolean(g) => g.parent_id(),
96 }
97 }
98}
99
100impl ApplyGate for GateKind {
101 fn classify(
102 &self,
103 matrix: &EventMatrix,
104 gate_map: &GateRegistry,
105 ) -> Result<BitVec, FlowGateError> {
106 let results = gate_map.classify_all(matrix)?;
107 Ok(results
108 .get(self.gate_id())
109 .cloned()
110 .unwrap_or_else(|| BitVec::repeat(false, matrix.n_events)))
111 }
112}
113
/// A set of gates keyed by id, together with a precomputed evaluation order.
///
/// `topo_order` lists the keys of `gates` so that each gate's parent and boolean operands
/// come before it. The constructor and `insert` recompute it whenever the gate set changes.
#[derive(Debug, Default, Clone)]
pub struct GateRegistry {
    // Gate definitions, iterated in insertion order (IndexMap).
    gates: IndexMap<GateId, GateKind>,
    // Evaluation order, dependencies before dependents; empty for a `Default` registry.
    topo_order: Vec<GateId>,
}
119
120impl GateRegistry {
121 pub fn new(gates: IndexMap<GateId, GateKind>) -> Result<Self, FlowGateError> {
122 let topo_order = compute_topological_order(&gates)?;
123 Ok(Self { gates, topo_order })
124 }
125
126 pub fn iter(&self) -> impl Iterator<Item = (&GateId, &GateKind)> {
127 self.gates.iter()
128 }
129
130 pub fn get(&self, gate_id: &GateId) -> Option<&GateKind> {
131 self.gates.get(gate_id)
132 }
133
134 pub fn insert(&mut self, gate_id: GateId, gate: GateKind) -> Result<(), FlowGateError> {
135 self.gates.insert(gate_id, gate);
136 self.topo_order = compute_topological_order(&self.gates)?;
137 Ok(())
138 }
139
140 pub fn topological_order(&self) -> &[GateId] {
141 &self.topo_order
142 }
143
144 pub fn classify_all(
145 &self,
146 matrix: &EventMatrix,
147 ) -> Result<HashMap<GateId, BitVec>, FlowGateError> {
148 let n_events = matrix.n_events;
149 let mut results: HashMap<GateId, BitVec> = HashMap::with_capacity(self.gates.len());
150
151 for gate_id in &self.topo_order {
152 let gate = self.gates.get(gate_id).ok_or_else(|| {
153 FlowGateError::UnknownGateReference(gate_id.clone(), gate_id.clone())
154 })?;
155
156 let membership = match gate {
157 GateKind::Boolean(boolean_gate) => {
158 evaluate_boolean_gate(boolean_gate, &results, n_events)?
159 }
160 _ => classify_spatial_gate(gate, matrix)?,
161 };
162
163 let membership = if let Some(parent_id) = gate.parent_id() {
164 let parent_bits = results
165 .get(parent_id)
166 .ok_or_else(|| FlowGateError::MissingParentGate(parent_id.clone()))?;
167 let mut child = membership;
168 for idx in 0..n_events {
169 let keep = child[idx] & parent_bits[idx];
170 child.set(idx, keep);
171 }
172 child
173 } else {
174 membership
175 };
176
177 results.insert(gate_id.clone(), membership);
178 }
179
180 Ok(results)
181 }
182
183 pub fn classify_all_view(
184 &self,
185 matrix: &EventMatrixView<'_>,
186 ) -> Result<HashMap<GateId, BitVec>, FlowGateError> {
187 let n_events = matrix.n_events;
188 let mut results: HashMap<GateId, BitVec> = HashMap::with_capacity(self.gates.len());
189
190 for gate_id in &self.topo_order {
191 let gate = self.gates.get(gate_id).ok_or_else(|| {
192 FlowGateError::UnknownGateReference(gate_id.clone(), gate_id.clone())
193 })?;
194
195 let membership = match gate {
196 GateKind::Boolean(boolean_gate) => {
197 evaluate_boolean_gate(boolean_gate, &results, n_events)?
198 }
199 _ => classify_spatial_gate_view(gate, matrix)?,
200 };
201
202 let membership = if let Some(parent_id) = gate.parent_id() {
203 let parent_bits = results
204 .get(parent_id)
205 .ok_or_else(|| FlowGateError::MissingParentGate(parent_id.clone()))?;
206 let mut child = membership;
207 for idx in 0..n_events {
208 let keep = child[idx] & parent_bits[idx];
209 child.set(idx, keep);
210 }
211 child
212 } else {
213 membership
214 };
215
216 results.insert(gate_id.clone(), membership);
217 }
218
219 Ok(results)
220 }
221}
222
223fn classify_spatial_gate(gate: &GateKind, matrix: &EventMatrix) -> Result<BitVec, FlowGateError> {
224 let n_events = matrix.n_events;
225 let projected = matrix.project(gate.dimensions())?;
226 let columns = projected.columns();
227 let transforms = gate.transforms();
228
229 let bools: Vec<bool> = (0..n_events)
230 .into_par_iter()
231 .map(|event_idx| {
232 let mut coords = SmallVec::<[f64; 8]>::with_capacity(columns.len());
233 for (dim_idx, col) in columns.iter().enumerate() {
234 let raw = col[event_idx];
235 let value = transforms
236 .get(dim_idx)
237 .copied()
238 .flatten()
239 .map_or(raw, |t| t.apply(raw));
240 coords.push(value);
241 }
242 gate.contains(&coords)
243 })
244 .collect();
245
246 Ok(bools.into_iter().collect())
247}
248
249fn classify_spatial_gate_view(
250 gate: &GateKind,
251 matrix: &EventMatrixView<'_>,
252) -> Result<BitVec, FlowGateError> {
253 let n_events = matrix.n_events;
254 let dim_indices = matrix.project_indices(gate.dimensions())?;
255 let transforms = gate.transforms();
256
257 let bools: Vec<bool> = (0..n_events)
258 .into_par_iter()
259 .map(|event_idx| {
260 let mut coords = SmallVec::<[f64; 8]>::with_capacity(dim_indices.len());
261 for (dim_idx, ¶m_idx) in dim_indices.iter().enumerate() {
262 let raw = matrix.value_at(event_idx, param_idx).unwrap_or(f64::NAN);
263 let value = transforms
264 .get(dim_idx)
265 .copied()
266 .flatten()
267 .map_or(raw, |t| t.apply(raw));
268 coords.push(value);
269 }
270 gate.contains(&coords)
271 })
272 .collect();
273
274 Ok(bools.into_iter().collect())
275}
276
277fn evaluate_boolean_gate(
278 gate: &BooleanGate,
279 results: &HashMap<GateId, BitVec>,
280 n_events: usize,
281) -> Result<BitVec, FlowGateError> {
282 let mut operand_bits = Vec::with_capacity(gate.operands().len());
283 for operand in gate.operands() {
284 let source = results.get(&operand.gate_id).ok_or_else(|| {
285 FlowGateError::UnknownGateReference(gate.gate_id().clone(), operand.gate_id.clone())
286 })?;
287 let mut bits = source.clone();
288 if operand.complement {
289 for idx in 0..n_events {
290 let prev = bits[idx];
291 bits.set(idx, !prev);
292 }
293 }
294 operand_bits.push(bits);
295 }
296
297 let out = match gate.op() {
298 BooleanOp::And => {
299 let mut acc = BitVec::repeat(true, n_events);
300 for bits in &operand_bits {
301 for idx in 0..n_events {
302 let prev = acc[idx];
303 acc.set(idx, prev & bits[idx]);
304 }
305 }
306 acc
307 }
308 BooleanOp::Or => {
309 let mut acc = BitVec::repeat(false, n_events);
310 for bits in &operand_bits {
311 for idx in 0..n_events {
312 let prev = acc[idx];
313 acc.set(idx, prev | bits[idx]);
314 }
315 }
316 acc
317 }
318 BooleanOp::Not => {
319 if operand_bits.len() != 1 {
320 return Err(FlowGateError::BooleanNotArity(
321 gate.gate_id().clone(),
322 operand_bits.len(),
323 ));
324 }
325 let mut acc = operand_bits.remove(0);
326 for idx in 0..n_events {
327 let prev = acc[idx];
328 acc.set(idx, !prev);
329 }
330 acc
331 }
332 };
333 Ok(out)
334}
335
336fn compute_topological_order(
337 gates: &IndexMap<GateId, GateKind>,
338) -> Result<Vec<GateId>, FlowGateError> {
339 #[derive(Clone, Copy, PartialEq, Eq)]
340 enum Mark {
341 Temp,
342 Perm,
343 }
344
345 fn visit(
346 node: &GateId,
347 gates: &IndexMap<GateId, GateKind>,
348 marks: &mut HashMap<GateId, Mark>,
349 order: &mut Vec<GateId>,
350 stack: &mut HashSet<GateId>,
351 ) -> Result<(), FlowGateError> {
352 if marks.get(node) == Some(&Mark::Perm) {
353 return Ok(());
354 }
355 if marks.get(node) == Some(&Mark::Temp) || stack.contains(node) {
356 return Err(FlowGateError::CyclicGateReference(node.clone()));
357 }
358
359 marks.insert(node.clone(), Mark::Temp);
360 stack.insert(node.clone());
361
362 let gate = gates
363 .get(node)
364 .ok_or_else(|| FlowGateError::UnknownGateReference(node.clone(), node.clone()))?;
365 for dep in gate.dependency_ids() {
366 if !gates.contains_key(&dep) {
367 return Err(FlowGateError::UnknownGateReference(node.clone(), dep));
368 }
369 visit(&dep, gates, marks, order, stack)?;
370 }
371
372 marks.insert(node.clone(), Mark::Perm);
373 stack.remove(node);
374 order.push(node.clone());
375 Ok(())
376 }
377
378 let mut marks: HashMap<GateId, Mark> = HashMap::new();
379 let mut stack: HashSet<GateId> = HashSet::new();
380 let mut order = Vec::with_capacity(gates.len());
381 for gate_id in gates.keys() {
382 visit(gate_id, gates, &mut marks, &mut order, &mut stack)?;
383 }
384 Ok(order)
385}