use ndarray::{ArrayBase, DataMut, Ix2, NdFloat};
use crate::{index::*, LinalgError, Result};
/// A 2x2 plane (Givens) rotation, stored as the pair `(c, s)`.
///
/// The in-file tests use the pair as the matrix `[[c, -s], [s, c]]`. The fields are private,
/// so values are only created through the constructors of the `impl` block.
#[derive(Debug, Clone)]
pub struct GivensRotation<A> {
    // Diagonal entry of the rotation matrix (the cosine of the rotation angle once normalized).
    c: A,
    // Lower-left entry of the rotation matrix; the upper-right entry is `-s`.
    s: A,
}
impl<A: NdFloat> GivensRotation<A> {
    /// Builds the rotation `R = [[c, -s], [s, c]]` that zeroes the second component of
    /// `[x, y]`, i.e. `R * [x, y]^T = [r, 0]^T` with `r = hypot(x, y)`.
    ///
    /// Returns `None` when `y` is exactly zero (there is nothing to cancel); otherwise returns
    /// the rotation together with `r`.
    pub fn cancel_y(x: A, y: A) -> Option<(Self, A)> {
        if y.is_zero() {
            return None;
        }
        let norm = x.hypot(y);
        let rotation = Self {
            c: x / norm,
            s: -y / norm,
        };
        Some((rotation, norm))
    }

    /// Builds the rotation `R = [[c, -s], [s, c]]` that zeroes the first component of
    /// `[x, y]`, i.e. `R * [x, y]^T = [0, r]^T` with `r = hypot(x, y)`.
    ///
    /// Returns `None` when `x` is exactly zero; otherwise returns the rotation together
    /// with `r`.
    pub fn cancel_x(x: A, y: A) -> Option<(Self, A)> {
        // Cancelling `x` is cancelling the second component of the swapped pair, with the
        // sign of `s` flipped afterwards.
        let (swapped, norm) = Self::cancel_y(y, x)?;
        let rotation = Self {
            c: swapped.c,
            s: swapped.s * -A::one(),
        };
        Some((rotation, norm))
    }

    /// The rotation that leaves every vector unchanged (`c = 1`, `s = 0`).
    pub fn identity() -> Self {
        let (c, s) = (A::one(), A::zero());
        Self { c, s }
    }

    /// Normalizes the pair `(c, s)` into a rotation.
    ///
    /// Returns the rotation and the norm `hypot(c, s)` when that norm is strictly greater than
    /// `eps`; returns `None` otherwise (including when the comparison fails because the norm
    /// is NaN).
    pub fn try_new(c: A, s: A, eps: A) -> Option<(Self, A)> {
        let norm = c.hypot(s);
        (norm > eps).then(|| {
            let normalized = Self {
                c: c / norm,
                s: s / norm,
            };
            (normalized, norm)
        })
    }

    /// Normalizes the pair `(c, s)` into a rotation, never failing.
    ///
    /// Behaves like [`try_new`](Self::try_new) with a threshold of zero, except that the
    /// failure case yields the identity rotation paired with a norm of zero.
    pub fn new(c: A, s: A) -> (Self, A) {
        match Self::try_new(c, s, A::zero()) {
            Some(normalized) => normalized,
            None => (Self::identity(), A::zero()),
        }
    }

    /// Returns the `c` component of the rotation.
    pub fn c(&self) -> A {
        self.c
    }

    /// Returns the `s` component of the rotation.
    pub fn s(&self) -> A {
        self.s
    }

    /// Returns the inverse rotation: same `c`, negated `s`.
    pub fn inverse(&self) -> Self {
        let (c, s) = (self.c, self.s);
        Self { c, s: -s }
    }

    /// Applies the rotation to every row of `lhs` in place: `lhs = lhs * [[c, -s], [s, c]]`.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::WrongColumns` when `lhs` does not have exactly two columns.
    pub fn rotate_rows<S: DataMut<Elem = A>>(&self, lhs: &mut ArrayBase<S, Ix2>) -> Result<()> {
        let actual = lhs.ncols();
        if actual != 2 {
            return Err(LinalgError::WrongColumns {
                expected: 2,
                actual,
            });
        }

        let (cos, sin) = (self.c, self.s);
        for row in 0..lhs.nrows() {
            // SAFETY: `at`/`atm` come from `crate::index` (not visible here; presumably
            // unchecked element accessors, TODO confirm). The indices used are in range:
            // `row < lhs.nrows()` by the loop bound, and columns 0 and 1 exist because the
            // column count was checked to be exactly 2 above.
            unsafe {
                // Read both entries first so the second update sees the original left value.
                let left = *lhs.at((row, 0));
                let right = *lhs.at((row, 1));
                *lhs.atm((row, 0)) = left * cos + sin * right;
                *lhs.atm((row, 1)) = -sin * left + right * cos;
            }
        }
        Ok(())
    }

    /// Applies the rotation to every column of `rhs` in place: `rhs = [[c, -s], [s, c]] * rhs`.
    ///
    /// Implemented by rotating the rows of the transposed view with the inverse rotation.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::WrongRows` when `rhs` does not have exactly two rows.
    pub fn rotate_cols<S: DataMut<Elem = A>>(&self, rhs: &mut ArrayBase<S, Ix2>) -> Result<()> {
        let mut transposed = rhs.view_mut().reversed_axes();
        match self.inverse().rotate_rows(&mut transposed) {
            // The row-wise routine sees the transposed shape, so its column complaint is
            // really about the rows of `rhs`.
            Err(LinalgError::WrongColumns { expected, actual }) => {
                Err(LinalgError::WrongRows { expected, actual })
            }
            other => other,
        }
    }
}
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use super::*;

    /// `cancel_y` must produce a rotation mapping `[x, y]` onto `[r, 0]`, and must report
    /// `None` when `y` is already zero.
    #[test]
    fn cancel_y() {
        let (rot, r) = GivensRotation::cancel_y(1.0f64, 2.0).unwrap();
        assert_abs_diff_eq!(r, 5.0_f64.sqrt());
        assert_abs_diff_eq!(rot.c, 0.4472136, epsilon = 1e-5);
        assert_abs_diff_eq!(rot.s, -0.8944272, epsilon = 1e-5);
        assert_abs_diff_eq!(
            array![[rot.c, -rot.s], [rot.s, rot.c]].dot(&array![1., 2.]),
            array![r, 0.]
        );
        assert!(GivensRotation::cancel_y(3.0f64, 0.).is_none());
    }

    /// `cancel_x` must produce a rotation mapping `[x, y]` onto `[0, r]`, and must report
    /// `None` when `x` is already zero.
    #[test]
    fn cancel_x() {
        let (rot, r) = GivensRotation::cancel_x(1.0f64, 2.0).unwrap();
        assert_abs_diff_eq!(r, 5.0_f64.sqrt());
        assert_abs_diff_eq!(
            array![[rot.c, -rot.s], [rot.s, rot.c]].dot(&array![1., 2.]),
            array![0., r]
        );
        // Exercise the `None` branch of `cancel_x` itself (zero `x`), not of `cancel_y`.
        assert!(GivensRotation::cancel_x(0.0f64, 3.).is_none());
    }

    /// `rotate_rows` must equal right-multiplication by the rotation matrix and reject inputs
    /// that do not have exactly two columns.
    #[test]
    fn rotate_rows() {
        let (rot, _) = GivensRotation::cancel_y(1.0f64, 2.0).unwrap();
        let rows = array![[2., 3.], [4., 5.], [1., 2.], [3., 4.]];
        let mut out = rows.clone();
        rot.rotate_rows(&mut out).unwrap();
        assert_abs_diff_eq!(
            rows.dot(&array![[rot.c, -rot.s], [rot.s, rot.c]]),
            out,
            epsilon = 1e-5
        );
        assert!(matches!(
            rot.rotate_rows(&mut array![[1., 2., 3.]]).unwrap_err(),
            LinalgError::WrongColumns {
                expected: 2,
                actual: 3
            }
        ));
    }

    /// `rotate_cols` must equal left-multiplication by the rotation matrix and reject inputs
    /// that do not have exactly two rows.
    #[test]
    fn rotate_cols() {
        let (rot, _) = GivensRotation::cancel_y(1.0f64, 2.0).unwrap();
        let cols = array![[2., 3., 4.], [3., 4., 5.]];
        let mut out = cols.clone();
        rot.rotate_cols(&mut out).unwrap();
        assert_abs_diff_eq!(
            array![[rot.c, -rot.s], [rot.s, rot.c]].dot(&cols),
            out,
            epsilon = 1e-5
        );
        assert!(matches!(
            rot.rotate_cols(&mut array![[1., 2., 3.]]).unwrap_err(),
            LinalgError::WrongRows {
                expected: 2,
                actual: 1
            }
        ));
    }
}