Documentation
use core::ops::{Add, Div, Mul, Sub};

use num_traits::{Float, FromPrimitive, One, Zero};
use vec3;

/// Builds a transform matrix from a translation, a per-axis scale and a
/// rotation quaternion, overwriting all 16 elements of `out` and returning it.
///
/// `rotation` is read as `[x, y, z, w]`. The rotation terms are written to
/// elements `0..=2`, `4..=6` and `8..=10`. Those three groups are multiplied by
/// `scale[0]`, `scale[1]` and `scale[2]` respectively. Elements 3, 7 and 11 are
/// zero, elements 12..=14 are copied from `position`, and element 15 is one.
///
/// The quaternion is not normalised here; presumably callers pass a unit
/// quaternion (TODO confirm).
///
/// # Example
/// ```
/// let mut m = mat4::new_identity();
/// let position = [0f32, 0f32, 0f32];
/// let scale = [1f32, 1f32, 1f32];
/// let rotation = [0f32, 0f32, 0f32, 1f32];
/// mat4::compose(&mut m, &position, &scale, &rotation);
/// assert_eq!(m, mat4::new_identity());
/// ```
#[inline]
pub fn compose<'out, T>(
    out: &'out mut [T; 16],
    position: &[T; 3],
    scale: &[T; 3],
    rotation: &[T; 4],
) -> &'out mut [T; 16]
where
    T: Clone + Zero + One + Add<T, Output = T> + Sub<T, Output = T>,
    for<'a, 'b> &'a T: Mul<&'b T, Output = T> + Add<&'b T, Output = T> + Sub<&'b T, Output = T>,
{
    let [qx, qy, qz, qw] = rotation;

    // Doubled components: every rotation term below has the form 2·a·b.
    let (dx, dy, dz) = (qx + qx, qy + qy, qz + qz);

    let (xx, xy, xz) = (qx * &dx, qx * &dy, qx * &dz);
    let (yy, yz, zz) = (qy * &dy, qy * &dz, qz * &dz);
    let (wx, wy, wz) = (qw * &dx, qw * &dy, qw * &dz);

    let [sx, sy, sz] = scale;

    // Diagonal rotation terms are `1 - (a + b)`.
    let one_minus_sum = |a: &T, b: &T| T::one() - (a + b);

    *out = [
        // Elements 0..=3: scaled by sx, then a zero.
        &one_minus_sum(&yy, &zz) * sx,
        &(&xy + &wz) * sx,
        &(&xz - &wy) * sx,
        T::zero(),
        // Elements 4..=7: scaled by sy, then a zero.
        &(&xy - &wz) * sy,
        &one_minus_sum(&xx, &zz) * sy,
        &(&yz + &wx) * sy,
        T::zero(),
        // Elements 8..=11: scaled by sz, then a zero.
        &(&xz + &wy) * sz,
        &(&yz - &wx) * sz,
        &one_minus_sum(&xx, &yy) * sz,
        T::zero(),
        // Elements 12..=15: translation followed by one.
        position[0].clone(),
        position[1].clone(),
        position[2].clone(),
        T::one(),
    ];
    out
}

/// Splits a transform matrix into translation, per-axis scale and rotation
/// quaternion. It undoes [`compose`].
///
/// * `out`: the matrix to read (despite its name, it is only read). It uses the
///   layout `compose` writes: scaled basis vectors in elements `0..=2`, `4..=6`,
///   `8..=10`, and the translation in elements `12..=14`.
/// * `position`: receives elements 12, 13 and 14.
/// * `scale`: receives the lengths of the three basis vectors. If the upper-left
///   3x3 has a negative determinant (the matrix contains a reflection), the x
///   scale is negated, so that what remains is a proper rotation.
/// * `rotation`: receives the quaternion `[x, y, z, w]` of the upper-left 3x3
///   after each basis vector has been divided by its scale.
///
/// A zero-length basis vector (zero scale on that axis) is not divided by zero.
/// It is left as the zero vector, so the resulting rotation is not meaningful
/// in that case.
///
/// # Example
/// ```
/// let m = mat4::new_identity();
/// let mut position = [0f32, 0f32, 0f32];
/// let mut scale = [1f32, 1f32, 1f32];
/// let mut rotation = [0f32, 0f32, 0f32, 1f32];
/// mat4::decompose(&m, &mut position, &mut scale, &mut rotation);
/// assert_eq!(position, [0f32, 0f32, 0f32]);
/// assert_eq!(scale, [1f32, 1f32, 1f32]);
/// assert_eq!(rotation, [0f32, 0f32, 0f32, 1f32]);
///
/// // Scale is removed before the rotation is extracted.
/// let mut scaled = [0f32; 16];
/// mat4::compose(&mut scaled, &[1f32, 2f32, 3f32], &[2f32, 2f32, 2f32], &[0f32, 0f32, 0f32, 1f32]);
/// mat4::decompose(&scaled, &mut position, &mut scale, &mut rotation);
/// assert_eq!(position, [1f32, 2f32, 3f32]);
/// assert_eq!(scale, [2f32, 2f32, 2f32]);
/// assert_eq!(rotation, [0f32, 0f32, 0f32, 1f32]);
/// ```
#[inline]
pub fn decompose<T>(out: &[T; 16], position: &mut [T; 3], scale: &mut [T; 3], rotation: &mut [T; 4])
where
    T: Clone
        + Float
        + Zero
        + FromPrimitive
        + One
        + Add<T, Output = T>
        + Sub<T, Output = T>
        + PartialOrd,
    for<'a, 'b> &'a T: Mul<&'b T, Output = T>
        + Div<&'b T, Output = T>
        + Add<&'b T, Output = T>
        + Sub<&'b T, Output = T>,
{
    let zero = T::zero();
    let one = T::one();
    let two = one + one;
    let half = one / two;
    let quarter = half * half;

    // Reciprocal that maps zero to zero instead of producing infinity.
    let inv = |s: T| if s.is_zero() { zero } else { one / s };

    // Lengths of the three basis vectors are the per-axis scales.
    let sx = vec3::len_values(&out[0], &out[1], &out[2]);
    let sy = vec3::len_values(&out[4], &out[5], &out[6]);
    let sz = vec3::len_values(&out[8], &out[9], &out[10]);

    // m<row><col> naming: m11 = out[0], m12 = out[4], m21 = out[1], ...
    let (a11, a21, a31) = (out[0], out[1], out[2]);
    let (a12, a22, a32) = (out[4], out[5], out[6]);
    let (a13, a23, a33) = (out[8], out[9], out[10]);

    // A negative determinant means a reflection. Fold it into the x scale so the
    // normalised 3x3 below is a proper rotation (determinant +1).
    let det = a11 * (a22 * a33 - a32 * a23)
        + a21 * (a32 * a13 - a12 * a33)
        + a31 * (a12 * a23 - a22 * a13);
    let sx = if det < zero { -sx } else { sx };

    *scale = [sx, sy, sz];
    position.copy_from_slice(&out[12..15]);

    // Remove the scale from each basis vector before extracting the rotation.
    let (ix, iy, iz) = (inv(sx), inv(sy), inv(sz));
    let (m11, m21, m31) = (a11 * ix, a21 * ix, a31 * ix);
    let (m12, m22, m32) = (a12 * iy, a22 * iy, a32 * iy);
    let (m13, m23, m33) = (a13 * iz, a23 * iz, a33 * iz);

    let trace = m11 + m22 + m33;

    // Branch on the largest of trace / diagonal entries so the square root is
    // taken of the largest available quantity (numerically stable).
    *rotation = if trace > zero {
        let s = half / (trace + one).sqrt();
        [(m32 - m23) * s, (m13 - m31) * s, (m21 - m12) * s, quarter * inv(s)]
    } else if m11 > m22 && m11 > m33 {
        let s = two * (one + m11 - m22 - m33).sqrt();
        let is = inv(s);
        [quarter * s, (m12 + m21) * is, (m13 + m31) * is, (m32 - m23) * is]
    } else if m22 > m33 {
        let s = two * (one + m22 - m11 - m33).sqrt();
        let is = inv(s);
        [(m12 + m21) * is, quarter * s, (m23 + m32) * is, (m13 - m31) * is]
    } else {
        let s = two * (one + m33 - m11 - m22).sqrt();
        let is = inv(s);
        [(m13 + m31) * is, (m23 + m32) * is, quarter * s, (m21 - m12) * is]
    };
}