apple_accelerate/
lapack.rs1use crate::bridge;
2use crate::error::{Error, Result};
3
4fn checked_square_length(dimension: usize, actual: usize) -> Result<()> {
5 let expected = dimension
6 .checked_mul(dimension)
7 .ok_or(Error::OperationFailed("matrix dimensions overflowed"))?;
8 if actual == expected {
9 Ok(())
10 } else {
11 Err(Error::InvalidLength { expected, actual })
12 }
13}
14
15fn i32_len(value: usize) -> Result<i32> {
16 i32::try_from(value).map_err(|_| Error::OperationFailed("dimension exceeds i32"))
17}
18
19fn lapack_result(info: i32) -> Result<()> {
20 if info == 0 {
21 Ok(())
22 } else {
23 Err(Error::LapackInfo(info))
24 }
25}
26
/// Output of [`lu_decompose_f32`]: the factor buffer and pivot indices written by the
/// bridged `sgetrf` call, together with the matrix dimension they belong to.
#[derive(Debug, Clone, PartialEq)]
pub struct LuDecompositionF32 {
    /// Copy of the input matrix after it was overwritten in place by the factorization.
    factors: Vec<f32>,
    /// Pivot indices filled in by the factorization, one per row/column.
    pivots: Vec<i32>,
    /// Number of rows (and columns) of the factored square matrix.
    dimension: usize,
}

impl LuDecompositionF32 {
    /// Returns the factor buffer (`dimension * dimension` elements).
    #[must_use]
    pub fn factors(&self) -> &[f32] {
        self.factors.as_slice()
    }

    /// Returns the pivot indices (`dimension` elements).
    // NOTE(review): LAPACK pivot indices are conventionally 1-based — confirm against the bridge.
    #[must_use]
    pub fn pivots(&self) -> &[i32] {
        self.pivots.as_slice()
    }

    /// Returns the number of rows (and columns) of the factored matrix.
    #[must_use]
    pub const fn dimension(&self) -> usize {
        self.dimension
    }
}
54
55pub fn lu_decompose_f32(matrix: &[f32], dimension: usize) -> Result<LuDecompositionF32> {
57 checked_square_length(dimension, matrix.len())?;
58 let dimension_i32 = i32_len(dimension)?;
59
60 let mut factors = matrix.to_vec();
61 let mut pivots = vec![0_i32; dimension];
62 let info = unsafe {
64 bridge::acc_lapack_sgetrf(factors.as_mut_ptr(), dimension_i32, pivots.as_mut_ptr())
65 };
66 lapack_result(info)?;
67
68 Ok(LuDecompositionF32 {
69 factors,
70 pivots,
71 dimension,
72 })
73}
74
75pub fn solve_linear_system_f32(matrix: &[f32], dimension: usize, rhs: &[f32]) -> Result<Vec<f32>> {
77 checked_square_length(dimension, matrix.len())?;
78 if rhs.is_empty() {
79 return Err(Error::InvalidLength {
80 expected: dimension,
81 actual: 0,
82 });
83 }
84 if rhs.len() % dimension != 0 {
85 return Err(Error::InvalidLength {
86 expected: dimension,
87 actual: rhs.len(),
88 });
89 }
90
91 let dimension_i32 = i32_len(dimension)?;
92 let rhs_count = i32_len(rhs.len() / dimension)?;
93 let mut factors = matrix.to_vec();
94 let mut solution = rhs.to_vec();
95 let mut pivots = vec![0_i32; dimension];
96
97 let info = unsafe {
99 bridge::acc_lapack_sgesv(
100 factors.as_mut_ptr(),
101 dimension_i32,
102 solution.as_mut_ptr(),
103 rhs_count,
104 pivots.as_mut_ptr(),
105 )
106 };
107 lapack_result(info)?;
108 Ok(solution)
109}