Skip to main content

apple_accelerate/
lapack.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3
4fn checked_square_length(dimension: usize, actual: usize) -> Result<()> {
5    let expected = dimension
6        .checked_mul(dimension)
7        .ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
8    if actual == expected {
9        Ok(())
10    } else {
11        Err(Error::InvalidLength { expected, actual })
12    }
13}
14
15fn i32_len(value: usize) -> Result<i32> {
16    i32::try_from(value).map_err(|_| Error::OperationFailed("dimension exceeds i32"))
17}
18
19fn lapack_result(info: i32) -> Result<()> {
20    if info == 0 {
21        Ok(())
22    } else {
23        Err(Error::LapackInfo(info))
24    }
25}
26
/// Compact LU factorization in column-major layout.
///
/// Values of this type are produced by `lu_decompose_f32`. `factors` is the packed
/// output written by the LAPACK bridge. It presumably holds `L` strictly below
/// the diagonal (unit diagonal implied) and `U` on and above it, as in reference
/// `sgetrf`; confirm against the bridge. `pivots` holds one entry per row.
#[derive(Debug, Clone, PartialEq)]
pub struct LuDecompositionF32 {
    /// Packed factor matrix with `dimension * dimension` entries, column-major.
    factors: Vec<f32>,
    /// Pivot indices as written by the LAPACK bridge, `dimension` entries.
    pivots: Vec<i32>,
    /// Order `n` of the factorized `n x n` matrix.
    dimension: usize,
}

impl LuDecompositionF32 {
    /// Borrow the packed factor matrix (`dimension * dimension` entries, column-major).
    #[must_use]
    pub fn factors(&self) -> &[f32] {
        self.factors.as_slice()
    }

    /// Borrow the pivot indices exactly as the LAPACK bridge wrote them.
    ///
    /// NOTE(review): reference LAPACK `sgetrf` reports 1-based row indices; confirm
    /// the bridge does not rebase them before relying on that.
    #[must_use]
    pub fn pivots(&self) -> &[i32] {
        self.pivots.as_slice()
    }

    /// Order of the square matrix that was factorized.
    #[must_use]
    pub const fn dimension(&self) -> usize {
        self.dimension
    }
}
51
52/// Compute an LU factorization of a column-major square matrix.
53pub fn lu_decompose_f32(matrix: &[f32], dimension: usize) -> Result<LuDecompositionF32> {
54    checked_square_length(dimension, matrix.len())?;
55    let dimension_i32 = i32_len(dimension)?;
56
57    let mut factors = matrix.to_vec();
58    let mut pivots = vec![0_i32; dimension];
59    // SAFETY: Buffers are valid for `dimension * dimension` matrix entries and `dimension` pivots.
60    let info = unsafe {
61        bridge::acc_lapack_sgetrf(factors.as_mut_ptr(), dimension_i32, pivots.as_mut_ptr())
62    };
63    lapack_result(info)?;
64
65    Ok(LuDecompositionF32 {
66        factors,
67        pivots,
68        dimension,
69    })
70}
71
72/// Solve `A * X = B` for a column-major square matrix `A` and one-or-more right-hand sides `B`.
73pub fn solve_linear_system_f32(matrix: &[f32], dimension: usize, rhs: &[f32]) -> Result<Vec<f32>> {
74    checked_square_length(dimension, matrix.len())?;
75    if rhs.is_empty() {
76        return Err(Error::InvalidLength {
77            expected: dimension,
78            actual: 0,
79        });
80    }
81    if rhs.len() % dimension != 0 {
82        return Err(Error::InvalidLength {
83            expected: dimension,
84            actual: rhs.len(),
85        });
86    }
87
88    let dimension_i32 = i32_len(dimension)?;
89    let rhs_count = i32_len(rhs.len() / dimension)?;
90    let mut factors = matrix.to_vec();
91    let mut solution = rhs.to_vec();
92    let mut pivots = vec![0_i32; dimension];
93
94    // SAFETY: Buffers are valid for LAPACK's in-place solve routine.
95    let info = unsafe {
96        bridge::acc_lapack_sgesv(
97            factors.as_mut_ptr(),
98            dimension_i32,
99            solution.as_mut_ptr(),
100            rhs_count,
101            pivots.as_mut_ptr(),
102        )
103    };
104    lapack_result(info)?;
105    Ok(solution)
106}