1use anyhow::Result;
6use faer::{Mat, MatRef};
7
/// Namespace for dense `f32` matrix helpers (inversion, batched
/// matrix-vector products, and compensation built from the two).
///
/// Carries no state; all functionality is exposed as associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatrixOps;
10
11impl MatrixOps {
12 pub fn invert_matrix(matrix: MatRef<'_, f32>) -> Result<Mat<f32>> {
15 #[cfg(feature = "blas")]
16 {
17 use faer_ext::IntoNdarray;
18 use ndarray_linalg::Inverse;
19
20 let a_ndarray = matrix.into_ndarray().to_owned();
21 let inv_ndarray = a_ndarray
22 .inv()
23 .map_err(|e| anyhow::anyhow!("BLAS inverse failed: {}", e))?;
24 Ok(Mat::from_fn(
25 inv_ndarray.nrows(),
26 inv_ndarray.ncols(),
27 |i, j| inv_ndarray[[i, j]],
28 ))
29 }
30
31 #[cfg(not(feature = "blas"))]
32 {
33 use faer::linalg::solvers::{DenseSolveCore, PartialPivLu};
34
35 let lu = PartialPivLu::new(matrix);
36 Ok(lu.inverse())
37 }
38 }
39
40 pub fn batch_matvec(
44 matrix: MatRef<'_, f32>,
45 channel_data: &[Vec<f32>],
46 ) -> Result<Vec<Vec<f32>>> {
47 let n_channels = channel_data.len();
48 let n_events = channel_data.first().map(|v| v.len()).unwrap_or(0);
49
50 if n_events == 0 {
51 return Ok(vec![]);
52 }
53
54 if matrix.nrows() != n_channels || matrix.ncols() != n_channels {
55 return Err(anyhow::anyhow!(
56 "Matrix dimensions ({}, {}) don't match channel count ({})",
57 matrix.nrows(),
58 matrix.ncols(),
59 n_channels
60 ));
61 }
62
63 let data_matrix =
65 Mat::from_fn(n_channels, n_events, |i, j| channel_data[i][j]);
66
67 let mut result = Mat::zeros(n_channels, n_events);
69 faer::linalg::matmul::matmul(
70 result.as_mut(),
71 faer::Accum::Replace,
72 matrix,
73 data_matrix.as_ref(),
74 1.0_f32,
75 faer::Par::rayon(0),
76 );
77
78 let mut out = Vec::with_capacity(n_channels);
80 for i in 0..n_channels {
81 let channel_result: Vec<f32> =
82 (0..n_events).map(|j| result[(i, j)]).collect();
83 out.push(channel_result);
84 }
85
86 Ok(out)
87 }
88
89 pub fn compensate_parameters(
91 comp_matrix: MatRef<'_, f32>,
92 channel_data: &[Vec<f32>],
93 ) -> Result<Vec<Vec<f32>>> {
94 let comp_inv = Self::invert_matrix(comp_matrix)?;
95 Self::batch_matvec(comp_inv.as_ref(), channel_data)
96 }
97}