rstsr_blas_traits/lapack_solve/getrf.rs

1use crate::prelude_dev::*;
2use rstsr_core::prelude_dev::*;
3
4pub trait GETRFDriverAPI<T> {
5    unsafe fn driver_getrf(
6        order: FlagOrder,
7        m: usize,
8        n: usize,
9        a: *mut T,
10        lda: usize,
11        ipiv: *mut blas_int,
12    ) -> blas_int;
13}
14
15#[derive(Builder)]
16#[builder(pattern = "owned", no_std, build_fn(error = "Error"))]
17pub struct GETRF_<'a, B, T>
18where
19    T: BlasFloat,
20    B: DeviceAPI<T>,
21{
22    #[builder(setter(into))]
23    pub a: TensorReference<'a, T, B, Ix2>,
24}
25
26impl<'a, B, T> GETRF_<'a, B, T>
27where
28    T: BlasFloat,
29    B: BlasDriverBaseAPI<T> + GETRFDriverAPI<T>,
30{
31    pub fn internal_run(self) -> Result<(TensorMutable2<'a, T, B>, Tensor<blas_int, B, Ix1>)> {
32        let Self { a } = self;
33
34        let device = a.device().clone();
35        let mut a = overwritable_convert(a)?;
36        let order = if a.f_prefer() && !a.c_prefer() { ColMajor } else { RowMajor };
37
38        let [m, n] = *a.view().shape();
39        let lda = a.view().ld(order).unwrap();
40        let mut ipiv = unsafe { empty_f(([n].c(), &device))?.into_dim::<Ix1>() };
41        let ptr_a = a.view_mut().as_mut_ptr();
42        let ptr_ipiv = ipiv.as_mut_ptr();
43
44        // run driver
45        let info = unsafe { B::driver_getrf(order, m, n, ptr_a, lda, ptr_ipiv) };
46        let info = info as i32;
47        if info != 0 {
48            rstsr_errcode!(info, "Lapack GETRF")?;
49        }
50
51        // rust is 0-indexed
52        ipiv -= 1;
53
54        Ok((a.clone_to_mut(), ipiv))
55    }
56
57    pub fn run(self) -> Result<(TensorMutable2<'a, T, B>, Tensor<blas_int, B, Ix1>)> {
58        self.internal_run()
59    }
60}
61
62pub type GETRF<'a, B, T> = GETRF_Builder<'a, B, T>;
63pub type SGETRF<'a, B> = GETRF<'a, B, f32>;
64pub type DGETRF<'a, B> = GETRF<'a, B, f64>;
65pub type CGETRF<'a, B> = GETRF<'a, B, Complex<f32>>;
66pub type ZGETRF<'a, B> = GETRF<'a, B, Complex<f64>>;