// blas_array2/blas1/nrm2.rs

1use crate::ffi::{self, blas_int};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5use num_traits::Zero;
6
7/* #region BLAS func */
8
9pub trait NRM2Num: BLASFloat {
10    unsafe fn nrm2(n: *const blas_int, x: *const Self, incx: *const blas_int) -> Self::RealFloat;
11}
12
/// Implements [`NRM2Num`] for `$type` by forwarding to the FFI symbol `ffi::$func`.
macro_rules! impl_func {
    ($type: ty, $func: ident) => {
        // `NRM2Num: BLASFloat` already demands the `BLASFloat` bound, so no extra
        // `where` clause is needed here.
        impl NRM2Num for $type {
            unsafe fn nrm2(
                n: *const blas_int,
                x: *const Self,
                incx: *const blas_int,
            ) -> <$type as BLASFloat>::RealFloat {
                // SAFETY: the pointer contract is the caller's, as documented on `NRM2Num::nrm2`;
                // the arguments are forwarded unchanged.
                ffi::$func(n, x, incx)
            }
        }
    };
}
29
30impl_func!(f32, snrm2_);
31impl_func!(f64, dnrm2_);
32impl_func!(c32, scnrm2_);
33impl_func!(c64, dznrm2_);
34
35/* #endregion */
36
37/* #region BLAS driver */
38
39pub struct NRM2_Driver<'x, F>
40where
41    F: NRM2Num,
42{
43    n: blas_int,
44    x: ArrayView1<'x, F>,
45    incx: blas_int,
46}
47
48impl<'x, F> NRM2_Driver<'x, F>
49where
50    F: NRM2Num,
51{
52    pub fn run_blas(self) -> Result<F::RealFloat, BLASError> {
53        let Self { n, x, incx } = self;
54        let x_ptr = x.as_ptr();
55        if n == 0 {
56            return Ok(F::RealFloat::zero());
57        } else {
58            return unsafe { Ok(F::nrm2(&n, x_ptr, &incx)) };
59        }
60    }
61}
62
63/* #endregion */
64
65/* #region BLAS builder */
66
67#[derive(Builder)]
68#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
69pub struct NRM2_<'x, F>
70where
71    F: NRM2Num,
72{
73    pub x: ArrayView1<'x, F>,
74}
75
76impl<'x, F> NRM2_<'x, F>
77where
78    F: NRM2Num,
79{
80    pub fn driver(self) -> Result<NRM2_Driver<'x, F>, BLASError> {
81        let Self { x } = self;
82        let incx = x.stride_of(Axis(0));
83        let n = x.len_of(Axis(0));
84        let driver = NRM2_Driver { n: n.try_into()?, x, incx: incx.try_into()? };
85        return Ok(driver);
86    }
87}
88
/* #endregion */

/* #region BLAS wrapper */
90
91pub type NRM2<'x, F> = NRM2_Builder<'x, F>;
92pub type SNRM2<'x> = NRM2<'x, f32>;
93pub type DNRM2<'x> = NRM2<'x, f64>;
94pub type SCNRM2<'x> = NRM2<'x, c32>;
95pub type DZNRM2<'x> = NRM2<'x, c64>;
96
97impl<'x, F> NRM2<'x, F>
98where
99    F: NRM2Num,
100{
101    pub fn run(self) -> Result<F::RealFloat, BLASError> {
102        self.build()?.driver()?.run_blas()
103    }
104}
105
106/* #endregion */