apple_accelerate/
lapack.rs1use crate::bridge;
2use crate::error::{Error, Result};
3
4fn checked_square_length(dimension: usize, actual: usize) -> Result<()> {
5 let expected = dimension
6 .checked_mul(dimension)
7 .ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
8 if actual == expected {
9 Ok(())
10 } else {
11 Err(Error::InvalidLength { expected, actual })
12 }
13}
14
15fn i32_len(value: usize) -> Result<i32> {
16 i32::try_from(value).map_err(|_| Error::OperationFailed("dimension exceeds i32"))
17}
18
19fn lapack_result(info: i32) -> Result<()> {
20 if info == 0 {
21 Ok(())
22 } else {
23 Err(Error::LapackInfo(info))
24 }
25}
26
/// LU factorization of a square `f32` matrix, as produced by [`lu_decompose_f32`].
///
/// The fields are private. The accessors expose read-only views of the factored buffer, the
/// pivot indices and the matrix order.
#[derive(Clone, Debug, PartialEq)]
pub struct LuDecompositionF32 {
    // Factored matrix buffer as written by the LAPACK bridge; `dimension * dimension` elements
    // when built by `lu_decompose_f32`. NOTE(review): element order follows the bridge — confirm.
    factors: Vec<f32>,
    // One pivot index per row as reported by the bridge. NOTE(review): LAPACK's `ipiv` is
    // conventionally 1-based — confirm against the bridge.
    pivots: Vec<i32>,
    // Order `n` of the `n` × `n` matrix that was factored.
    dimension: usize,
}

impl LuDecompositionF32 {
    /// Borrows the factored matrix buffer.
    #[must_use]
    pub fn factors(&self) -> &[f32] {
        self.factors.as_slice()
    }

    /// Borrows the pivot indices reported by the factorization routine.
    #[must_use]
    pub fn pivots(&self) -> &[i32] {
        self.pivots.as_slice()
    }

    /// Returns the order `n` of the factored `n` × `n` matrix.
    #[must_use]
    pub const fn dimension(&self) -> usize {
        self.dimension
    }
}
51
52pub fn lu_decompose_f32(matrix: &[f32], dimension: usize) -> Result<LuDecompositionF32> {
54 checked_square_length(dimension, matrix.len())?;
55 let dimension_i32 = i32_len(dimension)?;
56
57 let mut factors = matrix.to_vec();
58 let mut pivots = vec![0_i32; dimension];
59 let info = unsafe {
61 bridge::acc_lapack_sgetrf(factors.as_mut_ptr(), dimension_i32, pivots.as_mut_ptr())
62 };
63 lapack_result(info)?;
64
65 Ok(LuDecompositionF32 {
66 factors,
67 pivots,
68 dimension,
69 })
70}
71
72pub fn solve_linear_system_f32(matrix: &[f32], dimension: usize, rhs: &[f32]) -> Result<Vec<f32>> {
74 checked_square_length(dimension, matrix.len())?;
75 if rhs.is_empty() {
76 return Err(Error::InvalidLength {
77 expected: dimension,
78 actual: 0,
79 });
80 }
81 if rhs.len() % dimension != 0 {
82 return Err(Error::InvalidLength {
83 expected: dimension,
84 actual: rhs.len(),
85 });
86 }
87
88 let dimension_i32 = i32_len(dimension)?;
89 let rhs_count = i32_len(rhs.len() / dimension)?;
90 let mut factors = matrix.to_vec();
91 let mut solution = rhs.to_vec();
92 let mut pivots = vec![0_i32; dimension];
93
94 let info = unsafe {
96 bridge::acc_lapack_sgesv(
97 factors.as_mut_ptr(),
98 dimension_i32,
99 solution.as_mut_ptr(),
100 rhs_count,
101 pivots.as_mut_ptr(),
102 )
103 };
104 lapack_result(info)?;
105 Ok(solution)
106}