// mdarray_linalg_faer/lu/context.rs

// LU decomposition with partial pivoting:
//     P * A = L * U
// where, with k = min(m, n):
//     - A is m × n   (input matrix)
//     - P is m × m   (permutation matrix)
//     - L is m × k   (lower triangular/trapezoidal with ones on the diagonal)
//     - U is k × n   (upper triangular/trapezoidal matrix)
8
9use dyn_stack::{MemBuffer, MemStack};
10
11use super::simple::lu_faer;
12use faer_traits::ComplexField;
13use mdarray::{DSlice, DTensor, Layout, tensor};
14use mdarray_linalg::lu::{InvError, InvResult, LU};
15use num_complex::ComplexFloat;
16
17use crate::{Faer, into_faer_mut, into_mdarray};
18
19impl<T> LU<T> for Faer
20where
21    T: ComplexFloat
22        + ComplexField
23        + Default
24        + std::convert::From<<T as num_complex::ComplexFloat>::Real>
25        + 'static,
26{
27    /// Computes LU decomposition with new allocated matrices: L, U, P (permutation matrix)
28    fn lu<L: Layout>(
29        &self,
30        a: &mut DSlice<T, 2, L>,
31    ) -> (DTensor<T, 2>, DTensor<T, 2>, DTensor<T, 2>) {
32        let (m, n) = *a.shape();
33        let min_mn = m.min(n);
34        let mut l_mda = tensor![[T::default(); min_mn]; m ];
35        let mut u_mda = tensor![[T::default(); n ]; min_mn];
36        let mut p_mda = tensor![[T::default(); m]; m];
37
38        lu_faer(a, &mut l_mda, &mut u_mda, &mut p_mda);
39
40        (l_mda, u_mda, p_mda)
41    }
42
43    /// Computes LU decomposition overwriting existing matrices
44    fn lu_overwrite<L: Layout, Ll: Layout, Lu: Layout, Lp: Layout>(
45        &self,
46        a: &mut DSlice<T, 2, L>,
47        l: &mut DSlice<T, 2, Ll>,
48        u: &mut DSlice<T, 2, Lu>,
49        p: &mut DSlice<T, 2, Lp>,
50    ) {
51        lu_faer::<T, L, Ll, Lu, Lp>(a, l, u, p);
52    }
53
54    /// Computes inverse with new allocated matrix
55    fn inv<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> InvResult<T> {
56        let (m, n) = *a.shape();
57
58        if m != n {
59            return Err(InvError::NotSquare {
60                rows: m as i32,
61                cols: n as i32,
62            });
63        }
64
65        let par = faer::get_global_parallelism();
66        let mut a_faer = into_faer_mut(a);
67
68        let mut row_perm_fwd = vec![0usize; m];
69        let mut row_perm_bwd = vec![0usize; m];
70
71        faer::linalg::lu::partial_pivoting::factor::lu_in_place(
72            a_faer.as_mut(),
73            &mut row_perm_fwd,
74            &mut row_perm_bwd,
75            par,
76            MemStack::new(&mut MemBuffer::new(
77                faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(
78                    m,
79                    n,
80                    par,
81                    faer::prelude::default(),
82                ),
83            )),
84            faer::prelude::default(),
85        );
86
87        let l_mat = a_faer.as_ref();
88        let u_mat = a_faer.as_ref();
89
90        let perm = unsafe {
91            faer::perm::Perm::new_unchecked(
92                row_perm_fwd.into_boxed_slice(),
93                row_perm_bwd.into_boxed_slice(),
94            )
95        };
96
97        let mut inv_mat = faer::Mat::<T>::zeros(m, n);
98
99        faer::linalg::lu::partial_pivoting::inverse::inverse(
100            inv_mat.as_mut(),
101            l_mat,
102            u_mat,
103            perm.as_ref(),
104            par,
105            MemStack::new(&mut MemBuffer::new(
106                faer::linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(m, par),
107            )),
108        );
109        Ok(into_mdarray(inv_mat))
110    }
111
112    /// Computes inverse overwriting the input matrix
113    fn inv_overwrite<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
114        let (m, n) = *a.shape();
115
116        if m != n {
117            return Err(InvError::NotSquare {
118                rows: m as i32,
119                cols: n as i32,
120            });
121        }
122
123        let par = faer::get_global_parallelism();
124        let mut a_faer = into_faer_mut(a);
125
126        let mut row_perm_fwd = vec![0usize; m];
127        let mut row_perm_bwd = vec![0usize; m];
128
129        faer::linalg::lu::partial_pivoting::factor::lu_in_place(
130            a_faer.as_mut(),
131            &mut row_perm_fwd,
132            &mut row_perm_bwd,
133            par,
134            MemStack::new(&mut MemBuffer::new(
135                faer::linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(
136                    m,
137                    n,
138                    par,
139                    faer::prelude::default(),
140                ),
141            )),
142            faer::prelude::default(),
143        );
144
145        let l_mat = a_faer.as_ref();
146        let u_mat = a_faer.as_ref();
147
148        let perm = unsafe {
149            faer::perm::Perm::new_unchecked(
150                row_perm_fwd.into_boxed_slice(),
151                row_perm_bwd.into_boxed_slice(),
152            )
153        };
154
155        let mut inv_mat = faer::Mat::<T>::zeros(m, n);
156
157        faer::linalg::lu::partial_pivoting::inverse::inverse(
158            inv_mat.as_mut(),
159            l_mat,
160            u_mat,
161            perm.as_ref(),
162            par,
163            MemStack::new(&mut MemBuffer::new(
164                faer::linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(m, par),
165            )),
166        );
167
168        for i in 0..m {
169            for j in 0..n {
170                a_faer[(i, j)] = inv_mat[(i, j)];
171            }
172        }
173
174        Ok(())
175    }
176
177    /// Computes the determinant of a square matrix. Panics if the matrix is non-square.
178    fn det<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> T {
179        let (m, n) = *a.shape();
180        assert_eq!(m, n, "determinant is only defined for square matrices");
181        let a_faer = into_faer_mut(a);
182        a_faer.determinant()
183    }
184
185    /// Computes the Cholesky decomposition, returning a lower-triangular matrix
186    fn choleski<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> InvResult<T> {
187        todo!("choleski will be implemented later")
188    }
189
190    /// Computes the Cholesky decomposition in-place, overwriting the input matrix
191    fn choleski_overwrite<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> Result<(), InvError> {
192        todo!("choleski_overwrite will be implemented later")
193    }
194}