1use std::collections::HashMap;
2use std::marker::PhantomData;
3
4use rayon::prelude::*;
5use smallvec::SmallVec;
6
7use crate::error::FlowGateError;
8use crate::traits::{ParameterName, Transform};
9use crate::transform::TransformKind;
10
/// Memory ordering of a dense two-dimensional `f64` buffer.
///
/// With `RowMajor`, element `(row, col)` lives at offset `row * n_cols + col`;
/// with `ColMajor`, it lives at offset `col * n_rows + row`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MatrixLayout {
    /// Each row is contiguous in memory.
    RowMajor,
    /// Each column is contiguous in memory.
    ColMajor,
}
16
17#[derive(Debug, Clone, Copy)]
25pub struct MatrixView<'a> {
26 pub ptr: *const f64,
27 pub n_rows: usize,
28 pub n_cols: usize,
29 pub layout: MatrixLayout,
30 _lifetime: PhantomData<&'a f64>,
31}
32
33unsafe impl<'a> Send for MatrixView<'a> {}
39unsafe impl<'a> Sync for MatrixView<'a> {}
42
43impl<'a> MatrixView<'a> {
44 pub unsafe fn from_raw(
52 ptr: *const f64,
53 n_rows: usize,
54 n_cols: usize,
55 layout: MatrixLayout,
56 ) -> Self {
57 Self {
58 ptr,
59 n_rows,
60 n_cols,
61 layout,
62 _lifetime: PhantomData,
63 }
64 }
65
66 #[inline]
71 pub unsafe fn get_unchecked(&self, row: usize, col: usize) -> f64 {
72 match self.layout {
73 MatrixLayout::RowMajor => *self.ptr.add(row * self.n_cols + col),
74 MatrixLayout::ColMajor => *self.ptr.add(col * self.n_rows + row),
75 }
76 }
77
78 pub fn column(&self, col: usize) -> ColumnIter<'a, '_> {
79 ColumnIter {
80 view: self,
81 col,
82 row: 0,
83 }
84 }
85}
86
87pub struct ColumnIter<'a, 'v> {
88 view: &'v MatrixView<'a>,
89 col: usize,
90 row: usize,
91}
92
93impl<'a, 'v> Iterator for ColumnIter<'a, 'v> {
94 type Item = f64;
95
96 fn next(&mut self) -> Option<Self::Item> {
97 if self.row >= self.view.n_rows {
98 return None;
99 }
100 let row = self.row;
101 self.row += 1;
102 Some(unsafe { self.view.get_unchecked(row, self.col) })
104 }
105}
106
107pub struct EventMatrixView<'a> {
108 view: MatrixView<'a>,
109 pub n_events: usize,
110 pub n_params: usize,
111 param_names: Vec<ParameterName>,
112 param_index: HashMap<ParameterName, usize>,
113}
114
115impl<'a> EventMatrixView<'a> {
116 pub fn project_indices(
117 &self,
118 names: &[ParameterName],
119 ) -> Result<SmallVec<[usize; 8]>, FlowGateError> {
120 let mut indices = SmallVec::<[usize; 8]>::with_capacity(names.len());
121 for name in names {
122 let Some(&idx) = self.param_index.get(name) else {
123 return Err(FlowGateError::UnknownParameter(name.clone()));
124 };
125 indices.push(idx);
126 }
127 Ok(indices)
128 }
129
130 #[inline]
133 pub fn value_at(&self, event_idx: usize, param_idx: usize) -> Option<f64> {
134 if event_idx >= self.n_events || param_idx >= self.n_params {
135 return None;
136 }
137 Some(unsafe { self.view.get_unchecked(event_idx, param_idx) })
139 }
140
141 pub fn param_names(&self) -> &[ParameterName] {
142 &self.param_names
143 }
144}
145
146pub struct EventMatrix {
147 pub n_events: usize,
148 pub n_params: usize,
149 data: Vec<f64>,
150 param_names: Vec<ParameterName>,
151 param_index: HashMap<ParameterName, usize>,
152}
153
154impl EventMatrix {
155 pub fn new(
156 n_events: usize,
157 n_params: usize,
158 data: Vec<f64>,
159 param_names: Vec<ParameterName>,
160 ) -> Result<Self, FlowGateError> {
161 if data.len() != n_events.saturating_mul(n_params) {
162 return Err(FlowGateError::InvalidGate(format!(
163 "EventMatrix data length {} does not match n_events*n_params {}",
164 data.len(),
165 n_events.saturating_mul(n_params)
166 )));
167 }
168 if param_names.len() != n_params {
169 return Err(FlowGateError::InvalidGate(format!(
170 "EventMatrix param_names length {} does not match n_params {}",
171 param_names.len(),
172 n_params
173 )));
174 }
175 let mut param_index = HashMap::with_capacity(param_names.len());
176 for (i, name) in param_names.iter().enumerate() {
177 param_index.insert(name.clone(), i);
178 }
179 Ok(Self {
180 n_events,
181 n_params,
182 data,
183 param_names,
184 param_index,
185 })
186 }
187
188 pub fn from_columns(
189 columns: Vec<Vec<f64>>,
190 param_names: Vec<ParameterName>,
191 ) -> Result<Self, FlowGateError> {
192 let n_params = columns.len();
193 let n_events = columns.first().map_or(0, Vec::len);
194 if columns.iter().any(|c| c.len() != n_events) {
195 return Err(FlowGateError::InvalidGate(
196 "All EventMatrix columns must have identical length".to_string(),
197 ));
198 }
199 let mut data = Vec::with_capacity(n_events.saturating_mul(n_params));
200 for col in columns {
201 data.extend_from_slice(&col);
202 }
203 Self::new(n_events, n_params, data, param_names)
204 }
205
206 pub fn from_view<'a>(
207 view: MatrixView<'a>,
208 param_names: Vec<ParameterName>,
209 ) -> Result<EventMatrixView<'a>, FlowGateError> {
210 if param_names.len() != view.n_cols {
211 return Err(FlowGateError::DimensionMismatch(
212 param_names.len(),
213 view.n_cols,
214 ));
215 }
216 let mut param_index = HashMap::with_capacity(param_names.len());
217 for (i, name) in param_names.iter().enumerate() {
218 param_index.insert(name.clone(), i);
219 }
220 Ok(EventMatrixView {
221 view,
222 n_events: view.n_rows,
223 n_params: view.n_cols,
224 param_names,
225 param_index,
226 })
227 }
228
229 pub fn data(&self) -> &[f64] {
230 &self.data
231 }
232
233 pub fn param_names(&self) -> &[ParameterName] {
234 &self.param_names
235 }
236
237 pub fn column(&self, column_index: usize) -> Option<&[f64]> {
238 if column_index >= self.n_params {
239 return None;
240 }
241 let start = column_index * self.n_events;
242 let end = start + self.n_events;
243 Some(&self.data[start..end])
244 }
245
246 pub fn project(&self, names: &[ParameterName]) -> Result<ProjectedMatrix<'_>, FlowGateError> {
247 let mut columns = SmallVec::<[&[f64]; 4]>::with_capacity(names.len());
248 for name in names {
249 let Some(&idx) = self.param_index.get(name) else {
250 return Err(FlowGateError::UnknownParameter(name.clone()));
251 };
252 let start = idx * self.n_events;
253 let end = start + self.n_events;
254 columns.push(&self.data[start..end]);
255 }
256 Ok(ProjectedMatrix {
257 n_events: self.n_events,
258 n_cols: names.len(),
259 columns,
260 })
261 }
262
263 pub fn apply_transforms_inplace(&mut self, transforms: &[(usize, TransformKind)]) {
266 let transform_map: HashMap<usize, TransformKind> = transforms.iter().copied().collect();
267 self.data
268 .par_chunks_mut(self.n_events.max(1))
269 .enumerate()
270 .for_each(|(col_idx, col)| {
271 if let Some(transform) = transform_map.get(&col_idx) {
272 for value in col {
273 *value = transform.apply(*value);
274 }
275 }
276 });
277 }
278
279 pub fn events(&self) -> impl Iterator<Item = SmallVec<[f64; 8]>> + '_ {
280 (0..self.n_events).map(|event_idx| {
281 let mut row = SmallVec::<[f64; 8]>::with_capacity(self.n_params);
282 for col_idx in 0..self.n_params {
283 let offset = col_idx * self.n_events + event_idx;
284 row.push(self.data[offset]);
285 }
286 row
287 })
288 }
289}
290
291pub struct ProjectedMatrix<'a> {
292 pub(crate) n_events: usize,
293 pub(crate) n_cols: usize,
294 pub(crate) columns: SmallVec<[&'a [f64]; 4]>,
295}
296
297impl<'a> ProjectedMatrix<'a> {
298 pub fn n_events(&self) -> usize {
299 self.n_events
300 }
301
302 pub fn n_cols(&self) -> usize {
303 self.n_cols
304 }
305
306 pub fn columns(&self) -> &[&'a [f64]] {
307 &self.columns
308 }
309
310 pub fn events(&'a self) -> impl Iterator<Item = SmallVec<[f64; 4]>> + 'a {
311 (0..self.n_events).map(|event_idx| {
312 let mut values = SmallVec::<[f64; 4]>::with_capacity(self.n_cols);
313 for col in &self.columns {
314 values.push(col[event_idx]);
315 }
316 values
317 })
318 }
319}