Skip to main content

apple_accelerate/
lapack.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3
4fn checked_square_length(dimension: usize, actual: usize) -> Result<()> {
5    let expected = dimension
6        .checked_mul(dimension)
7        .ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
8    if actual == expected {
9        Ok(())
10    } else {
11        Err(Error::InvalidLength { expected, actual })
12    }
13}
14
15fn i32_len(value: usize) -> Result<i32> {
16    i32::try_from(value).map_err(|_| Error::OperationFailed("dimension exceeds i32"))
17}
18
19fn lapack_result(info: i32) -> Result<()> {
20    if info == 0 {
21        Ok(())
22    } else {
23        Err(Error::LapackInfo(info))
24    }
25}
26
/// LU factorization of a square matrix, stored compactly in column-major layout.
///
/// Holds the packed factors and pivot indices written by `sgetrf_`, together with the
/// matrix dimension the factorization was computed for. The buffers are kept exactly as
/// reported; no re-basing or unpacking is applied.
#[derive(Debug, Clone, PartialEq)]
pub struct LuDecompositionF32 {
    factors: Vec<f32>,
    pivots: Vec<i32>,
    dimension: usize,
}

impl LuDecompositionF32 {
    /// Number of rows (equivalently, columns) of the factored matrix.
    #[must_use]
    pub const fn dimension(&self) -> usize {
        self.dimension
    }

    /// Packed LU factors as written by `sgetrf_`, in column-major order.
    #[must_use]
    pub fn factors(&self) -> &[f32] {
        self.factors.as_slice()
    }

    /// Pivot indices as written by `sgetrf_`, stored without any index re-basing.
    #[must_use]
    pub fn pivots(&self) -> &[i32] {
        self.pivots.as_slice()
    }
}
54
55/// Compute an LU factorization of a column-major square matrix.
56pub fn lu_decompose_f32(matrix: &[f32], dimension: usize) -> Result<LuDecompositionF32> {
57    checked_square_length(dimension, matrix.len())?;
58    let dimension_i32 = i32_len(dimension)?;
59
60    let mut factors = matrix.to_vec();
61    let mut pivots = vec![0_i32; dimension];
62    // SAFETY: Buffers are valid for `dimension * dimension` matrix entries and `dimension` pivots.
63    let info = unsafe {
64        bridge::acc_lapack_sgetrf(factors.as_mut_ptr(), dimension_i32, pivots.as_mut_ptr())
65    };
66    lapack_result(info)?;
67
68    Ok(LuDecompositionF32 {
69        factors,
70        pivots,
71        dimension,
72    })
73}
74
75/// Solve `A * X = B` for a column-major square matrix `A` and one-or-more right-hand sides `B`.
76pub fn solve_linear_system_f32(matrix: &[f32], dimension: usize, rhs: &[f32]) -> Result<Vec<f32>> {
77    checked_square_length(dimension, matrix.len())?;
78    if rhs.is_empty() {
79        return Err(Error::InvalidLength {
80            expected: dimension,
81            actual: 0,
82        });
83    }
84    if rhs.len() % dimension != 0 {
85        return Err(Error::InvalidLength {
86            expected: dimension,
87            actual: rhs.len(),
88        });
89    }
90
91    let dimension_i32 = i32_len(dimension)?;
92    let rhs_count = i32_len(rhs.len() / dimension)?;
93    let mut factors = matrix.to_vec();
94    let mut solution = rhs.to_vec();
95    let mut pivots = vec![0_i32; dimension];
96
97    // SAFETY: Buffers are valid for LAPACK's in-place solve routine.
98    let info = unsafe {
99        bridge::acc_lapack_sgesv(
100            factors.as_mut_ptr(),
101            dimension_i32,
102            solution.as_mut_ptr(),
103            rhs_count,
104            pivots.as_mut_ptr(),
105        )
106    };
107    lapack_result(info)?;
108    Ok(solution)
109}