ndarray-linalg 0.16.0

Linear algebra package for rust-ndarray using LAPACK
//! # Least Squares
//! Compute a least-squares solution to the equation Ax = b.
//! Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
//! Finding the least squares solutions is implemented as traits, meaning
//! that to solve `A x = b` for a matrix `A` and a RHS `b`, we call
//! `let result = A.least_squares(&b)?;`. On success this yields a value of
//! type `LeastSquaresResult`; the solution of the least squares problem
//! is in `result.solution`.
//! There are three traits: `LeastSquaresSvd` with the method `least_squares`,
//! which operates on immutable references; `LeastSquaresSvdInto` with the method
//! `least_squares_into`, which takes ownership over both the array `A` and the
//! RHS `b`; and `LeastSquaresSvdInPlace` with the method `least_squares_in_place`,
//! which operates on mutable references for `A` and `b` and destroys these when
//! solving the least squares problem. `LeastSquaresSvdInto` and
//! `LeastSquaresSvdInPlace` avoid an extra allocation for `A` and `b` which
//! `LeastSquaresSvd` has to perform to preserve the values in `A` and `b`.
//! All methods use the Lapacke family of methods `*gelsd` which solves the least
//! squares problem using the SVD with a divide-and-conquer strategy.
//! The traits are implemented for value types `f32`, `f64`, `c32` and `c64`
//! and vector or matrix right-hand-sides (`ArrayBase<S, Ix1>` or `ArrayBase<S, Ix2>`).
//! ## Example
//! ```rust
//! use approx::AbsDiffEq; // for abs_diff_eq
//! use ndarray::{array, Array1, Array2};
//! use ndarray_linalg::{LeastSquaresSvd, LeastSquaresSvdInto, LeastSquaresSvdInPlace};
//! let a: Array2<f64> = array![
//!     [1., 1., 1.],
//!     [2., 3., 4.],
//!     [3., 5., 2.],
//!     [4., 2., 5.],
//!     [5., 4., 3.]
//! ];
//! // solving for a single right-hand side
//! let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
//! let expected: Array1<f64> = array![2., 1., 1.];
//! let result = a.least_squares(&b).unwrap();
//! assert!(result.solution.abs_diff_eq(&expected, 1e-12));
//! // solving for two right-hand sides at once
//! let b_2: Array2<f64> =
//!     array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
//! let expected_2: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
//! let result_2 = a.least_squares(&b_2).unwrap();
//! assert!(result_2.solution.abs_diff_eq(&expected_2, 1e-12));
//! // using `least_squares_in_place` which overwrites its arguments
//! let mut a_3 = a.clone();
//! let mut b_3 = b.clone();
//! let result_3 = a_3.least_squares_in_place(&mut b_3).unwrap();
//! // using `least_squares_into` which consumes its arguments
//! let result_4 = a.least_squares_into(b).unwrap();
//! // `a` and `b` have been moved, no longer valid
//! ```

use lax::*;
use ndarray::*;

use crate::error::*;
use crate::layout::*;
use crate::types::*;

/// Result of a least-squares computation.
///
/// Takes two type parameters: `E`, the element type of the matrix
/// (one of `f32`, `f64`, `c32` or `c64`), and `I`, the dimension of
/// `b` in the equation `Ax = b` (one of `Ix1` or `Ix2`).
///
/// With `A` an `m x n` matrix: if `I` is `Ix1`, the right-hand side (RHS)
/// is a vector with `m` entries and the solution is a vector with `n`
/// entries. If `I` is `Ix2`, the RHS is an `m x k` matrix (which can be
/// seen as solving `Ax = b` k times for different b) and the solution is
/// an `n x k` matrix.
#[derive(Debug, Clone)]
pub struct LeastSquaresResult<E: Scalar, I: Dimension> {
    /// The singular values of the matrix A in `Ax = b`
    pub singular_values: Array1<E::Real>,
    /// The solution vector or matrix `x` which is the best
    /// solution to `Ax = b`, i.e. minimizing the 2-norm `||b - Ax||`
    pub solution: Array<E, I>,
    /// The rank of the matrix A in `Ax = b`
    pub rank: i32,
    /// The residual sum of squares. Only present (`Some`) if `A` (`m x n`)
    /// has at least as many rows as columns (`m >= n`) and full column rank
    /// (`rank(A) == n`); `None` otherwise.
    /// If b is a vector, this is a 0-dimensional array (single value).
    /// If b is a (m x k) matrix, this is a 1-dimensional array
    /// (one entry per column of b).
    pub residual_sum_of_squares: Option<Array<E::Real, I::Smaller>>,
/// Solve least squares for immutable references.
///
/// `D` is the storage of the right-hand side, `E` the element type and
/// `I` the dimension of the right-hand side (and of the solution).
pub trait LeastSquaresSvd<D, E, I>
    D: Data<Elem = E>,
    E: Scalar + Lapack,
    I: Dimension,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares(&rhs)`. `A` and `rhs`
    /// are unchanged.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares(&self, rhs: &ArrayBase<D, I>) -> Result<LeastSquaresResult<E, I>>;

/// Solve least squares for owned matrices.
///
/// `D` is the storage of the right-hand side, `E` the element type and
/// `I` the dimension of the right-hand side (and of the solution).
pub trait LeastSquaresSvdInto<D, E, I>
    D: Data<Elem = E>,
    E: Scalar + Lapack,
    I: Dimension,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares_into(rhs)`, consuming both `A`
    /// and `rhs`. This uses the memory location of `A` and
    /// `rhs`, which avoids some extra memory allocations.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares_into(self, rhs: ArrayBase<D, I>) -> Result<LeastSquaresResult<E, I>>;

/// Solve least squares for mutable references, overwriting
/// the input fields in the process.
///
/// `D` is the storage of the right-hand side, `E` the element type and
/// `I` the dimension of the right-hand side (and of the solution).
pub trait LeastSquaresSvdInPlace<D, E, I>
    D: Data<Elem = E>,
    E: Scalar + Lapack,
    I: Dimension,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares_in_place(&mut rhs)`, overwriting both
    /// `A` and `rhs`. This uses the memory location of `A` and
    /// `rhs`, which avoids some extra memory allocations.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares_in_place(
        &mut self,
        rhs: &mut ArrayBase<D, I>,
    ) -> Result<LeastSquaresResult<E, I>>;

/// Solve least squares for immutable references and a single
/// column vector as a right-hand side.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix1> for ArrayBase<D1, Ix2>
    E: Scalar + Lapack,
    D1: Data<Elem = E>,
    D2: Data<Elem = E>,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares(&rhs)`, where `rhs` is a
    /// single column vector. `A` and `rhs` are unchanged.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares(&self, rhs: &ArrayBase<D2, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
        // Work on owned copies so the caller's `A` and `rhs` are preserved.
        // NOTE(review): the remainder of this body is not visible in this
        // view; presumably the copies are handed to the in-place solver.
        let a = self.to_owned();
        let b = rhs.to_owned();

/// Solve least squares for immutable references and a matrix
/// (= multiple vectors) as a right-hand side.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix2> for ArrayBase<D1, Ix2>
    E: Scalar + Lapack,
    D1: Data<Elem = E>,
    D2: Data<Elem = E>,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares(&rhs)`, where `rhs` is a
    /// matrix. `A` and `rhs` are unchanged.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares(&self, rhs: &ArrayBase<D2, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
        // Work on owned copies so the caller's `A` and `rhs` are preserved.
        // NOTE(review): the remainder of this body is not visible in this
        // view; presumably the copies are handed to the in-place solver.
        let a = self.to_owned();
        let b = rhs.to_owned();

/// Solve least squares for owned values and a single
/// column vector as a right-hand side. The matrix and the RHS
/// vector are consumed.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// mutable representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
    E: Scalar + Lapack,
    D1: DataMut<Elem = E>,
    D2: DataMut<Elem = E>,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares_into(rhs)`, where `rhs` is a
    /// single column vector. `A` and `rhs` are consumed.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares_into(
        mut self,
        mut rhs: ArrayBase<D2, Ix1>,
    ) -> Result<LeastSquaresResult<E, Ix1>> {
        // The arguments are owned here, so they can be overwritten freely:
        // delegate to the in-place variant.
        self.least_squares_in_place(&mut rhs)

/// Solve least squares for owned values and a matrix
/// as a right-hand side. The matrix and the RHS matrix
/// are consumed.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// mutable representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
    E: Scalar + Lapack,
    D1: DataMut<Elem = E>,
    D2: DataMut<Elem = E>,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares_into(rhs)`, where `rhs` is a
    /// matrix. `A` and `rhs` are consumed.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised.
    fn least_squares_into(
        mut self,
        mut rhs: ArrayBase<D2, Ix2>,
    ) -> Result<LeastSquaresResult<E, Ix2>> {
        // The arguments are owned here, so they can be overwritten freely:
        // delegate to the in-place variant.
        self.least_squares_in_place(&mut rhs)

/// Solve least squares for mutable references and a vector
/// as a right-hand side. Both values are overwritten in the
/// call.
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix1> for ArrayBase<D1, Ix2>
    E: Scalar + Lapack,
    D1: DataMut<Elem = E>,
    D2: DataMut<Elem = E>,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares(rhs)`, where `rhs` is a
    /// vector. `A` and `rhs` are overwritten in the call.
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// be both either row- or column-major format, otherwise a
    /// `IncompatibleShape` error is raised.
    fn least_squares_in_place(
        &mut self,
        rhs: &mut ArrayBase<D2, Ix1>,
    ) -> Result<LeastSquaresResult<E, Ix1>> {
        if self.shape()[0] != rhs.shape()[0] {
            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
        let (m, n) = (self.shape()[0], self.shape()[1]);
        if n > m {
            // we need a new rhs b/c it will be overwritten with the solution
            // for which we need `n` entries
            let mut new_rhs = Array1::<E>::zeros((n,));
            compute_least_squares_srhs(self, &mut new_rhs)
        } else {
            compute_least_squares_srhs(self, rhs)

/// Runs the least-squares driver for a single right-hand-side vector and
/// packages the outcome as a `LeastSquaresResult`.
///
/// `a` is the `m x n` system matrix and `rhs` the right-hand side; the
/// solution is read back from the first `n` entries of `rhs` after the call.
// NOTE(review): the destructuring pattern and the arguments of
// `E::least_squares` are not visible in this view; `singular_values` and
// `rank` used below are presumably bound by that pattern — confirm.
fn compute_least_squares_srhs<E, D1, D2>(
    a: &mut ArrayBase<D1, Ix2>,
    rhs: &mut ArrayBase<D2, Ix1>,
) -> Result<LeastSquaresResult<E, Ix1>>
    E: Scalar + Lapack,
    D1: DataMut<Elem = E>,
    D2: DataMut<Elem = E>,
    let LeastSquaresOwned::<E> {
    } = E::least_squares(

    let (m, n) = (a.shape()[0], a.shape()[1]);
    // The first `n` entries of `rhs` are taken as the solution vector.
    let solution = rhs.slice(s![0..n]).to_owned();
    // The residual is derived from the entries of `rhs` past index `n`.
    let residual_sum_of_squares = compute_residual_scalar(m, n, rank, rhs);
    Ok(LeastSquaresResult {
        singular_values: Array::from_shape_vec((singular_values.len(),), singular_values)?,

/// Computes the residual sum of squares for a single right-hand side.
///
/// `m` and `n` are the number of rows and columns of the system matrix,
/// `rank` its computed rank and `b` the right-hand-side buffer after the
/// solve. Yields `None` when the system is underdetermined (`m < n`) or
/// the matrix does not have full column rank (`rank != n`).
fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
    m: usize,
    n: usize,
    rank: i32,
    b: &ArrayBase<D, Ix1>,
) -> Option<Array<E::Real, Ix0>> {
    if m < n || n != rank as usize {
        return None;
    // 0-dimensional array holding the single residual value.
    let mut arr: Array<E::Real, Ix0> = Array::zeros(());
    // Sum of `|x^2|` over the entries of `b` from index `n` onwards.
    arr[()] = b.slice(s![n..]).mapv(|x| x.powi(2).abs()).sum();

/// Solve least squares for mutable references and a matrix
/// as a right-hand side. Both values are overwritten in the
/// call.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// mutable representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
    E: Scalar + Lapack,
    D1: DataMut<Elem = E>,
    D2: DataMut<Elem = E>,
    /// Solve a least squares problem of the form `Ax = rhs`
    /// by calling `A.least_squares_in_place(&mut rhs)`, where `rhs` is a
    /// matrix. `A` and `rhs` are overwritten in the call.
    ///
    /// `A` and `rhs` must have the same layout, i.e. they must
    /// both be in either row- or column-major format, otherwise an
    /// `IncompatibleShape` error is raised. The same error is returned
    /// when `A` and `rhs` have a different number of rows.
    fn least_squares_in_place(
        &mut self,
        rhs: &mut ArrayBase<D2, Ix2>,
    ) -> Result<LeastSquaresResult<E, Ix2>> {
        // The number of rows of `A` and `rhs` must agree.
        if self.shape()[0] != rhs.shape()[0] {
            return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
        let (m, n) = (self.shape()[0], self.shape()[1]);
        if n > m {
            // we need a new rhs b/c it will be overwritten with the solution
            // for which we need `n` entries
            let k = rhs.shape()[1];
            // Allocate the enlarged buffer in the same memory order as `A`.
            let mut new_rhs = match self.layout()? {
                MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
                MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
            // Copy the original right-hand side into the first `m` rows.
            new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
            compute_least_squares_nrhs(self, &mut new_rhs)
        } else {
            compute_least_squares_nrhs(self, rhs)

/// Runs the least-squares driver for a matrix right-hand side (several
/// columns at once) and packages the outcome as a `LeastSquaresResult`.
///
/// `a` is the `m x n` system matrix and `rhs` the right-hand-side matrix;
/// the solution is read back from the first `n` rows of `rhs` after the call.
// NOTE(review): the destructuring pattern and the arguments of
// `E::least_squares_nrhs` are not visible in this view; `singular_values`
// and `rank` used below are presumably bound by that pattern, and the two
// layouts are presumably passed to the call — confirm.
fn compute_least_squares_nrhs<E, D1, D2>(
    a: &mut ArrayBase<D1, Ix2>,
    rhs: &mut ArrayBase<D2, Ix2>,
) -> Result<LeastSquaresResult<E, Ix2>>
    E: Scalar + Lapack,
    D1: DataMut<Elem = E>,
    D2: DataMut<Elem = E>,
    let a_layout = a.layout()?;
    let rhs_layout = rhs.layout()?;
    let LeastSquaresOwned::<E> {
    } = E::least_squares_nrhs(

    // The first `n` rows of `rhs` (all columns) are taken as the solution.
    let solution: Array2<E> = rhs.slice(s![..a.shape()[1], ..]).to_owned();
    let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?;
    let (m, n) = (a.shape()[0], a.shape()[1]);
    // The residuals are derived from the rows of `rhs` past index `n`.
    let residual_sum_of_squares = compute_residual_array1(m, n, rank, rhs);
    Ok(LeastSquaresResult {

/// Computes the residual sums of squares for a matrix right-hand side.
///
/// `m` and `n` are the number of rows and columns of the system matrix,
/// `rank` its computed rank and `b` the right-hand-side buffer after the
/// solve. Yields `None` when the system is underdetermined (`m < n`) or
/// the matrix does not have full column rank (`rank != n`).
fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
    m: usize,
    n: usize,
    rank: i32,
    b: &ArrayBase<D, Ix2>,
) -> Option<Array1<E::Real>> {
    if m < n || n != rank as usize {
        return None;
        // `|x^2|` for every entry in the rows of `b` from index `n` onwards.
        // NOTE(review): the reduction to one value per column is not visible
        // in this view — confirm.
        b.slice(s![n.., ..])
            .mapv(|x| x.powi(2).abs())

mod tests {
    use crate::{error::LinalgError, *};
    use approx::AbsDiffEq;
    use ndarray::*;

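    // Test that the different least squares traits work as intended on the
    // different array types.
    //               | least_squares | ls_into | ls_in_place |
    // --------------+---------------+---------+-------------+
    // Array         | yes           | yes     | yes         |
    // ArcArray      | yes           | no      | no          |
    // CowArray      | yes           | yes     | yes         |
    // ArrayView     | yes           | no      | no          |
    // ArrayViewMut  | yes           | no      | yes         |
    //
    // NOTE(review): `into_on_arc` below calls `least_squares_into` on
    // `ArcArray` inputs, which contradicts the "no" in the ArcArray row;
    // confirm which of the two is intended.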

    // Shared checks for the 3x2 example system used by the tests below:
    // the reported rank must be 2 and the reported residual sum of squares
    // must match the one recomputed from `a`, `b` and the solution.
    fn assert_result<D1: Data<Elem = f64>, D2: Data<Elem = f64>>(
        a: &ArrayBase<D1, Ix2>,
        b: &ArrayBase<D2, Ix1>,
        res: &LeastSquaresResult<f64, Ix1>,
    ) {
        assert_eq!(res.rank, 2);
        // Recompute the residual sum of squares from the fitted values.
        let b_hat = a.dot(&res.solution);
        let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum();
        assert!(res.residual_sum_of_squares.as_ref().unwrap()[()].abs_diff_eq(&rssq, 1e-12));
            // NOTE(review): the start of this assertion is not visible in
            // this view; presumably it compares `res.solution` — confirm.
            .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));

    // `least_squares` on shared (`ArcArray`) matrix and right-hand side.
    fn on_arc() {
        let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
        let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
        let res = a.least_squares(&b).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares` on `CowArray` matrix and right-hand side.
    fn on_cow() {
        let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
        let b = CowArray::from(array![1., 2., 3.]);
        let res = a.least_squares(&b).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares` on immutable views of an owned matrix and vector.
    fn on_view() {
        let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
        let b: Array1<f64> = array![1., 2., 3.];
        let av = a.view();
        let bv = b.view();
        let res = av.least_squares(&bv).unwrap();
        assert_result(&av, &bv, &res);

    // `least_squares` (the non-destructive variant) on mutable views.
    fn on_view_mut() {
        let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
        let mut b: Array1<f64> = array![1., 2., 3.];
        let av = a.view_mut();
        let bv = b.view_mut();
        let res = av.least_squares(&bv).unwrap();
        assert_result(&av, &bv, &res);

    // `least_squares` with mixed storage: `CowArray` matrix, view as RHS.
    fn on_cow_view() {
        let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
        let b: Array1<f64> = array![1., 2., 3.];
        let bv = b.view();
        let res = a.least_squares(&bv).unwrap();
        assert_result(&a, &bv, &res);

    // `least_squares_into` on owned arrays; clones are consumed so the
    // originals remain available for checking the result.
    fn into_on_owned() {
        let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
        let b: Array1<f64> = array![1., 2., 3.];
        let ac = a.clone();
        let bc = b.clone();
        let res = ac.least_squares_into(bc).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_into` on shared (`ArcArray`) arrays; clones are
    // consumed and the originals are used for checking the result.
    fn into_on_arc() {
        let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
        let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
        let a2 = a.clone();
        let b2 = b.clone();
        let res = a2.least_squares_into(b2).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_into` on `CowArray` arrays; clones are consumed and
    // the originals are used for checking the result.
    fn into_on_cow() {
        let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
        let b = CowArray::from(array![1., 2., 3.]);
        let a2 = a.clone();
        let b2 = b.clone();
        let res = a2.least_squares_into(b2).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_into` with mixed storage: owned matrix, `CowArray` RHS.
    fn into_on_owned_cow() {
        let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
        let b = CowArray::from(array![1., 2., 3.]);
        let ac = a.clone();
        let b2 = b.clone();
        let res = ac.least_squares_into(b2).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_in_place` on owned arrays; the solver runs on clones
    // and the untouched originals are used for checking the result.
    fn in_place_on_owned() {
        let a = array![[1., 2.], [4., 5.], [3., 4.]];
        let b = array![1., 2., 3.];
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let res = a2.least_squares_in_place(&mut b2).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_in_place` on `CowArray` arrays; the solver runs on
    // clones and the originals are used for checking the result.
    fn in_place_on_cow() {
        let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
        let b = CowArray::from(array![1., 2., 3.]);
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let res = a2.least_squares_in_place(&mut b2).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_in_place` through mutable views of cloned arrays; the
    // originals are used for checking the result.
    fn in_place_on_mut_view() {
        let a = array![[1., 2.], [4., 5.], [3., 4.]];
        let b = array![1., 2., 3.];
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let av = &mut a2.view_mut();
        let bv = &mut b2.view_mut();
        let res = av.least_squares_in_place(bv).unwrap();
        assert_result(&a, &b, &res);

    // `least_squares_in_place` with mixed storage: owned matrix, `CowArray`
    // RHS.
    fn in_place_on_owned_cow() {
        let a = array![[1., 2.], [4., 5.], [3., 4.]];
        let b = CowArray::from(array![1., 2., 3.]);
        let mut a2 = a.clone();
        let mut b2 = b.clone();
        let res = a2.least_squares_in_place(&mut b2).unwrap();
        assert_result(&a, &b, &res);

    // Testing error cases
    // A 3-row matrix combined with a 2-entry right-hand side must be
    // rejected with an `IncompatibleShape` shape error.
    fn incompatible_shape_error_on_mismatching_num_rows() {
        let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
        let b: Array1<f64> = array![1., 2.];
        match a.least_squares(&b) {
            // Only a shape error of kind `IncompatibleShape` is accepted.
            Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {}
            _ => panic!("Should be raise IncompatibleShape"),