Skip to main content

flow_fcs/
matrix.rs

1//! Matrix operations for flow cytometry compensation
2//!
3//! Provides CPU-based matrix operations for compensation calculations.
4
5use anyhow::Result;
6use faer::{Mat, MatRef};
7
/// Matrix operations for compensation.
///
/// Stateless unit struct: all functionality is exposed through associated
/// functions, so it carries no data and is trivially copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatrixOps;
10
11impl MatrixOps {
12    /// Invert matrix on CPU.
13    /// Uses faer (pure Rust) by default, or ndarray-linalg with system BLAS when `blas` feature is enabled.
14    pub fn invert_matrix(matrix: MatRef<'_, f32>) -> Result<Mat<f32>> {
15        #[cfg(feature = "blas")]
16        {
17            use faer_ext::IntoNdarray;
18            use ndarray_linalg::Inverse;
19
20            let a_ndarray = matrix.into_ndarray().to_owned();
21            let inv_ndarray = a_ndarray
22                .inv()
23                .map_err(|e| anyhow::anyhow!("BLAS inverse failed: {}", e))?;
24            Ok(Mat::from_fn(
25                inv_ndarray.nrows(),
26                inv_ndarray.ncols(),
27                |i, j| inv_ndarray[[i, j]],
28            ))
29        }
30
31        #[cfg(not(feature = "blas"))]
32        {
33            use faer::linalg::solvers::{DenseSolveCore, PartialPivLu};
34
35            let lu = PartialPivLu::new(matrix);
36            Ok(lu.inverse())
37        }
38    }
39
40    /// Batch matrix-vector multiplication on CPU
41    /// Input: matrix [n×n], channel_data [n_channels × n_events]
42    /// Output: compensated_data [n_channels × n_events]
43    pub fn batch_matvec(
44        matrix: MatRef<'_, f32>,
45        channel_data: &[Vec<f32>],
46    ) -> Result<Vec<Vec<f32>>> {
47        let n_channels = channel_data.len();
48        let n_events = channel_data.first().map(|v| v.len()).unwrap_or(0);
49
50        if n_events == 0 {
51            return Ok(vec![]);
52        }
53
54        if matrix.nrows() != n_channels || matrix.ncols() != n_channels {
55            return Err(anyhow::anyhow!(
56                "Matrix dimensions ({}, {}) don't match channel count ({})",
57                matrix.nrows(),
58                matrix.ncols(),
59                n_channels
60            ));
61        }
62
63        // Build data matrix: [n_channels × n_events]
64        let data_matrix =
65            Mat::from_fn(n_channels, n_events, |i, j| channel_data[i][j]);
66
67        // Result: matrix @ data_matrix -> [n_channels × n_events]
68        let mut result = Mat::zeros(n_channels, n_events);
69        faer::linalg::matmul::matmul(
70            result.as_mut(),
71            faer::Accum::Replace,
72            matrix,
73            data_matrix.as_ref(),
74            1.0_f32,
75            faer::Par::rayon(0),
76        );
77
78        // Convert back to Vec<Vec<f32>>
79        let mut out = Vec::with_capacity(n_channels);
80        for i in 0..n_channels {
81            let channel_result: Vec<f32> =
82                (0..n_events).map(|j| result[(i, j)]).collect();
83            out.push(channel_result);
84        }
85
86        Ok(out)
87    }
88
89    /// Compensate parameters on CPU
90    pub fn compensate_parameters(
91        comp_matrix: MatRef<'_, f32>,
92        channel_data: &[Vec<f32>],
93    ) -> Result<Vec<Vec<f32>>> {
94        let comp_inv = Self::invert_matrix(comp_matrix)?;
95        Self::batch_matvec(comp_inv.as_ref(), channel_data)
96    }
97}