rstsr_blas_traits/lapack_solve/
getrf.rs1use crate::prelude_dev::*;
2use rstsr_core::prelude_dev::*;
3
4pub trait GETRFDriverAPI<T> {
5 unsafe fn driver_getrf(
6 order: FlagOrder,
7 m: usize,
8 n: usize,
9 a: *mut T,
10 lda: usize,
11 ipiv: *mut blas_int,
12 ) -> blas_int;
13}
14
15#[derive(Builder)]
16#[builder(pattern = "owned", no_std, build_fn(error = "Error"))]
17pub struct GETRF_<'a, B, T>
18where
19 T: BlasFloat,
20 B: DeviceAPI<T>,
21{
22 #[builder(setter(into))]
23 pub a: TensorReference<'a, T, B, Ix2>,
24}
25
26impl<'a, B, T> GETRF_<'a, B, T>
27where
28 T: BlasFloat,
29 B: BlasDriverBaseAPI<T> + GETRFDriverAPI<T>,
30{
31 pub fn internal_run(self) -> Result<(TensorMutable2<'a, T, B>, Tensor<blas_int, B, Ix1>)> {
32 let Self { a } = self;
33
34 let device = a.device().clone();
35 let mut a = overwritable_convert(a)?;
36 let order = if a.f_prefer() && !a.c_prefer() { ColMajor } else { RowMajor };
37
38 let [m, n] = *a.view().shape();
39 let lda = a.view().ld(order).unwrap();
40 let mut ipiv = unsafe { empty_f(([n].c(), &device))?.into_dim::<Ix1>() };
41 let ptr_a = a.view_mut().as_mut_ptr();
42 let ptr_ipiv = ipiv.as_mut_ptr();
43
44 let info = unsafe { B::driver_getrf(order, m, n, ptr_a, lda, ptr_ipiv) };
46 let info = info as i32;
47 if info != 0 {
48 rstsr_errcode!(info, "Lapack GETRF")?;
49 }
50
51 ipiv -= 1;
53
54 Ok((a.clone_to_mut(), ipiv))
55 }
56
57 pub fn run(self) -> Result<(TensorMutable2<'a, T, B>, Tensor<blas_int, B, Ix1>)> {
58 self.internal_run()
59 }
60}
61
62pub type GETRF<'a, B, T> = GETRF_Builder<'a, B, T>;
63pub type SGETRF<'a, B> = GETRF<'a, B, f32>;
64pub type DGETRF<'a, B> = GETRF<'a, B, f64>;
65pub type CGETRF<'a, B> = GETRF<'a, B, Complex<f32>>;
66pub type ZGETRF<'a, B> = GETRF<'a, B, Complex<f64>>;