// flow_gate_core/event.rs

use std::collections::HashMap;
use std::marker::PhantomData;

use rayon::prelude::*;
use smallvec::SmallVec;

use crate::error::FlowGateError;
use crate::traits::{ParameterName, Transform};
use crate::transform::TransformKind;

/// Memory order of the `f64` elements behind a [`MatrixView`].
///
/// With `RowMajor`, element `(row, col)` is read from offset `row * n_cols + col`;
/// with `ColMajor`, it is read from offset `col * n_rows + row`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixLayout {
    /// Consecutive elements belong to the same row.
    RowMajor,
    /// Consecutive elements belong to the same column.
    ColMajor,
}

/// Non-owning read-only matrix view used by FFI bridges.
///
/// Build one with [`MatrixView::from_slice`] (safe, length-checked) or
/// [`MatrixView::from_raw`] (unsafe, for foreign buffers).
///
/// # Safety Contract
/// The caller must guarantee that:
/// - `ptr` points to at least `n_rows * n_cols` valid, properly aligned `f64` values
/// - The underlying data remains valid for the lifetime `'a`
/// - The underlying data is not being mutated concurrently while this view is in use
///
/// NOTE(review): the fields are `pub`, so safe code can reassign `ptr`, `n_rows`, or
/// `n_cols` after construction and silently break this contract, which the safe
/// [`MatrixView::column`] iterator relies on. Making them private (with getters) would
/// close that hole but is a breaking API change.
#[derive(Debug, Clone, Copy)]
pub struct MatrixView<'a> {
    /// Start of the element buffer.
    pub ptr: *const f64,
    /// Number of rows in the view.
    pub n_rows: usize,
    /// Number of columns in the view.
    pub n_cols: usize,
    /// How `(row, col)` maps to a buffer offset.
    pub layout: MatrixLayout,
    _lifetime: PhantomData<&'a f64>,
}

// SAFETY: MatrixView only provides read-only access to its buffer. The safety contract
// requires the data to stay valid for `'a` and not be mutated concurrently, which makes
// sharing a view equivalent to sharing a `&'a [f64]` (itself `Send + Sync`). When the
// view is built with `from_slice`, Rust's borrowing rules uphold these invariants;
// callers of `from_raw` (e.g. FFI bridges) must guarantee them independently.
unsafe impl<'a> Send for MatrixView<'a> {}
// SAFETY: Concurrent reads of immutable data are race-free; see the `Send` rationale above.
unsafe impl<'a> Sync for MatrixView<'a> {}

impl<'a> MatrixView<'a> {
    /// Creates a view over a borrowed slice, validating its length.
    ///
    /// Returns `None` when `n_rows * n_cols` overflows `usize` or differs from
    /// `data.len()`.
    pub fn from_slice(
        data: &'a [f64],
        n_rows: usize,
        n_cols: usize,
        layout: MatrixLayout,
    ) -> Option<Self> {
        let len = n_rows.checked_mul(n_cols)?;
        if data.len() != len {
            return None;
        }
        // SAFETY: `data` is a shared borrow alive for `'a` holding exactly
        // `n_rows * n_cols` aligned `f64`s, and the borrow checker forbids mutating it
        // through any other path while the view exists.
        Some(unsafe { Self::from_raw(data.as_ptr(), n_rows, n_cols, layout) })
    }

    /// Creates a non-owning read-only view over a raw `f64` buffer.
    ///
    /// # Safety
    /// - `ptr` must point to at least `n_rows * n_cols` valid, properly aligned
    ///   `f64` values.
    /// - The underlying data must remain valid for the lifetime `'a` of this view.
    /// - The underlying data must not be concurrently mutated while this view is in use.
    pub unsafe fn from_raw(
        ptr: *const f64,
        n_rows: usize,
        n_cols: usize,
        layout: MatrixLayout,
    ) -> Self {
        Self {
            ptr,
            n_rows,
            n_cols,
            layout,
            _lifetime: PhantomData,
        }
    }

    /// Returns the value at `(row, col)` without bounds checking.
    ///
    /// # Safety
    /// Caller must guarantee that `row < n_rows` and `col < n_cols`, and that the
    /// construction contract of this view still holds.
    #[inline]
    pub unsafe fn get_unchecked(&self, row: usize, col: usize) -> f64 {
        let offset = match self.layout {
            MatrixLayout::RowMajor => row * self.n_cols + col,
            MatrixLayout::ColMajor => col * self.n_rows + row,
        };
        // SAFETY: the caller guarantees `(row, col)` is in bounds, so `offset` is below
        // `n_rows * n_cols`, which the construction contract makes readable.
        unsafe { *self.ptr.add(offset) }
    }

    /// Returns an iterator over the values of column `col`, from row 0 downwards.
    ///
    /// If `col >= n_cols` the iterator yields nothing, rather than reading past the
    /// end of the buffer.
    pub fn column(&self, col: usize) -> ColumnIter<'a, '_> {
        ColumnIter {
            view: self,
            col,
            row: 0,
        }
    }
}

/// Iterator over one column of a [`MatrixView`], produced by [`MatrixView::column`].
pub struct ColumnIter<'a, 'v> {
    view: &'v MatrixView<'a>,
    col: usize,
    row: usize,
}

impl<'a, 'v> ColumnIter<'a, 'v> {
    /// Number of values still to be yielded (zero for an out-of-range column).
    fn remaining(&self) -> usize {
        if self.col < self.view.n_cols {
            self.view.n_rows.saturating_sub(self.row)
        } else {
            0
        }
    }
}

impl<'a, 'v> Iterator for ColumnIter<'a, 'v> {
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        // Checking the column as well as the row is what keeps `get_unchecked` sound:
        // `column()` accepts any `col`.
        if self.remaining() == 0 {
            return None;
        }
        let row = self.row;
        self.row += 1;
        // SAFETY: `remaining() > 0` implies `row < n_rows` and `col < n_cols`; the view
        // is borrowed immutably for `'v`, so its dimensions cannot change meanwhile.
        Some(unsafe { self.view.get_unchecked(row, self.col) })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining();
        (left, Some(left))
    }
}

impl<'a, 'v> ExactSizeIterator for ColumnIter<'a, 'v> {}

107pub struct EventMatrixView<'a> {
108    view: MatrixView<'a>,
109    pub n_events: usize,
110    pub n_params: usize,
111    param_names: Vec<ParameterName>,
112    param_index: HashMap<ParameterName, usize>,
113}
114
115impl<'a> EventMatrixView<'a> {
116    pub fn project_indices(
117        &self,
118        names: &[ParameterName],
119    ) -> Result<SmallVec<[usize; 8]>, FlowGateError> {
120        let mut indices = SmallVec::<[usize; 8]>::with_capacity(names.len());
121        for name in names {
122            let Some(&idx) = self.param_index.get(name) else {
123                return Err(FlowGateError::UnknownParameter(name.clone()));
124            };
125            indices.push(idx);
126        }
127        Ok(indices)
128    }
129
130    /// Returns the value at the given event and parameter indices.
131    /// Returns `None` if either index is out of bounds.
132    #[inline]
133    pub fn value_at(&self, event_idx: usize, param_idx: usize) -> Option<f64> {
134        if event_idx >= self.n_events || param_idx >= self.n_params {
135            return None;
136        }
137        // SAFETY: bounds checked above.
138        Some(unsafe { self.view.get_unchecked(event_idx, param_idx) })
139    }
140
141    pub fn param_names(&self) -> &[ParameterName] {
142        &self.param_names
143    }
144}
145
146pub struct EventMatrix {
147    pub n_events: usize,
148    pub n_params: usize,
149    data: Vec<f64>,
150    param_names: Vec<ParameterName>,
151    param_index: HashMap<ParameterName, usize>,
152}
153
154impl EventMatrix {
155    pub fn new(
156        n_events: usize,
157        n_params: usize,
158        data: Vec<f64>,
159        param_names: Vec<ParameterName>,
160    ) -> Result<Self, FlowGateError> {
161        if data.len() != n_events.saturating_mul(n_params) {
162            return Err(FlowGateError::InvalidGate(format!(
163                "EventMatrix data length {} does not match n_events*n_params {}",
164                data.len(),
165                n_events.saturating_mul(n_params)
166            )));
167        }
168        if param_names.len() != n_params {
169            return Err(FlowGateError::InvalidGate(format!(
170                "EventMatrix param_names length {} does not match n_params {}",
171                param_names.len(),
172                n_params
173            )));
174        }
175        let mut param_index = HashMap::with_capacity(param_names.len());
176        for (i, name) in param_names.iter().enumerate() {
177            param_index.insert(name.clone(), i);
178        }
179        Ok(Self {
180            n_events,
181            n_params,
182            data,
183            param_names,
184            param_index,
185        })
186    }
187
188    pub fn from_columns(
189        columns: Vec<Vec<f64>>,
190        param_names: Vec<ParameterName>,
191    ) -> Result<Self, FlowGateError> {
192        let n_params = columns.len();
193        let n_events = columns.first().map_or(0, Vec::len);
194        if columns.iter().any(|c| c.len() != n_events) {
195            return Err(FlowGateError::InvalidGate(
196                "All EventMatrix columns must have identical length".to_string(),
197            ));
198        }
199        let mut data = Vec::with_capacity(n_events.saturating_mul(n_params));
200        for col in columns {
201            data.extend_from_slice(&col);
202        }
203        Self::new(n_events, n_params, data, param_names)
204    }
205
206    pub fn from_view<'a>(
207        view: MatrixView<'a>,
208        param_names: Vec<ParameterName>,
209    ) -> Result<EventMatrixView<'a>, FlowGateError> {
210        if param_names.len() != view.n_cols {
211            return Err(FlowGateError::DimensionMismatch(
212                param_names.len(),
213                view.n_cols,
214            ));
215        }
216        let mut param_index = HashMap::with_capacity(param_names.len());
217        for (i, name) in param_names.iter().enumerate() {
218            param_index.insert(name.clone(), i);
219        }
220        Ok(EventMatrixView {
221            view,
222            n_events: view.n_rows,
223            n_params: view.n_cols,
224            param_names,
225            param_index,
226        })
227    }
228
229    pub fn data(&self) -> &[f64] {
230        &self.data
231    }
232
233    pub fn param_names(&self) -> &[ParameterName] {
234        &self.param_names
235    }
236
237    pub fn column(&self, column_index: usize) -> Option<&[f64]> {
238        if column_index >= self.n_params {
239            return None;
240        }
241        let start = column_index * self.n_events;
242        let end = start + self.n_events;
243        Some(&self.data[start..end])
244    }
245
246    pub fn project(&self, names: &[ParameterName]) -> Result<ProjectedMatrix<'_>, FlowGateError> {
247        let mut columns = SmallVec::<[&[f64]; 4]>::with_capacity(names.len());
248        for name in names {
249            let Some(&idx) = self.param_index.get(name) else {
250                return Err(FlowGateError::UnknownParameter(name.clone()));
251            };
252            let start = idx * self.n_events;
253            let end = start + self.n_events;
254            columns.push(&self.data[start..end]);
255        }
256        Ok(ProjectedMatrix {
257            n_events: self.n_events,
258            n_cols: names.len(),
259            columns,
260        })
261    }
262
263    /// Deviation approved by user: transform dispatch uses `TransformKind` instead of `&dyn Transform`
264    /// because `Transform: Clone` is not object-safe.
265    pub fn apply_transforms_inplace(&mut self, transforms: &[(usize, TransformKind)]) {
266        let transform_map: HashMap<usize, TransformKind> = transforms.iter().copied().collect();
267        self.data
268            .par_chunks_mut(self.n_events.max(1))
269            .enumerate()
270            .for_each(|(col_idx, col)| {
271                if let Some(transform) = transform_map.get(&col_idx) {
272                    for value in col {
273                        *value = transform.apply(*value);
274                    }
275                }
276            });
277    }
278
279    pub fn events(&self) -> impl Iterator<Item = SmallVec<[f64; 8]>> + '_ {
280        (0..self.n_events).map(|event_idx| {
281            let mut row = SmallVec::<[f64; 8]>::with_capacity(self.n_params);
282            for col_idx in 0..self.n_params {
283                let offset = col_idx * self.n_events + event_idx;
284                row.push(self.data[offset]);
285            }
286            row
287        })
288    }
289}
290
291pub struct ProjectedMatrix<'a> {
292    pub(crate) n_events: usize,
293    pub(crate) n_cols: usize,
294    pub(crate) columns: SmallVec<[&'a [f64]; 4]>,
295}
296
297impl<'a> ProjectedMatrix<'a> {
298    pub fn n_events(&self) -> usize {
299        self.n_events
300    }
301
302    pub fn n_cols(&self) -> usize {
303        self.n_cols
304    }
305
306    pub fn columns(&self) -> &[&'a [f64]] {
307        &self.columns
308    }
309
310    pub fn events(&'a self) -> impl Iterator<Item = SmallVec<[f64; 4]>> + 'a {
311        (0..self.n_events).map(|event_idx| {
312            let mut values = SmallVec::<[f64; 4]>::with_capacity(self.n_cols);
313            for col in &self.columns {
314                values.push(col[event_idx]);
315            }
316            values
317        })
318    }
319}