use std::fmt::Display;
use std::ops::{Add, Mul, Sub};
use crate::util::{ErrorKind, NmlError};
use rand::Rng;
use crate::util::ErrorKind::{CreateMatrix, InvalidCols, InvalidRows};
/// A dense, row-major matrix of `f64` values.
///
/// Element `(row, col)` lives at `data[row * num_cols + col]`; the methods on
/// this type index `data` that way and so expect
/// `data.len() == num_rows * num_cols`.
#[derive(Debug, Clone)]
pub struct NmlMatrix {
    /// Number of rows.
    pub num_rows: u32,
    /// Number of columns.
    pub num_cols: u32,
    /// Elements in row-major order.
    pub data: Vec<f64>,
    /// Cached `num_rows == num_cols`. Code that edits the dimensions directly
    /// must keep this flag in sync, since some methods consult it instead of
    /// recomputing it.
    pub is_square: bool,
}
impl NmlMatrix {
    /// Row-major offset of element `(row, col)` inside `data`.
    ///
    /// The arithmetic is done in `usize` so `row * num_cols` cannot overflow
    /// `u32` before it is used as an index.
    fn flat_index(&self, row: u32, col: u32) -> usize {
        row as usize * self.num_cols as usize + col as usize
    }

    /// Creates a `num_rows` x `num_cols` matrix with every element set to `0.0`.
    ///
    /// The buffer is fully populated (not merely reserved) so that accessors
    /// such as [`NmlMatrix::at`] and [`NmlMatrix::set_value`] work immediately.
    pub fn new(num_rows: u32, num_cols: u32) -> Self {
        Self {
            num_rows,
            num_cols,
            data: vec![0.0; num_rows as usize * num_cols as usize],
            is_square: num_rows == num_cols,
        }
    }

    /// Builds a matrix from a vector of rows.
    ///
    /// On success the elements are moved out of `data_2d` (each inner vector is
    /// left empty). On error `data_2d` is left untouched.
    ///
    /// # Errors
    /// `CreateMatrix` unless `data_2d` holds exactly `num_rows` rows of exactly
    /// `num_cols` elements each.
    pub fn new_with_2d_vec(
        num_rows: u32,
        num_cols: u32,
        data_2d: &mut Vec<Vec<f64>>,
    ) -> Result<Self, NmlError> {
        // Validate the whole shape before draining anything, so a failed call
        // does not leave the caller's rows half-consumed.
        let shape_ok = data_2d.len() == num_rows as usize
            && data_2d.iter().all(|row| row.len() == num_cols as usize);
        if !shape_ok {
            return Err(NmlError::new(ErrorKind::CreateMatrix));
        }
        let mut data: Vec<f64> = Vec::with_capacity(num_rows as usize * num_cols as usize);
        for row in data_2d.iter_mut() {
            data.append(row);
        }
        Ok(Self {
            num_rows,
            num_cols,
            data,
            is_square: num_rows == num_cols,
        })
    }

    /// Wraps an existing row-major buffer as a `num_rows` x `num_cols` matrix.
    ///
    /// # Errors
    /// `CreateMatrix` if `data.len() != num_rows * num_cols`.
    pub fn new_with_data(num_rows: u32, num_cols: u32, data: Vec<f64>) -> Result<Self, NmlError> {
        // Multiply in `usize` so large dimensions cannot overflow `u32`.
        if num_rows as usize * num_cols as usize != data.len() {
            return Err(NmlError::new(ErrorKind::CreateMatrix));
        }
        Ok(Self {
            num_rows,
            num_cols,
            data,
            is_square: num_rows == num_cols,
        })
    }

    /// Creates a `num_rows` x `num_cols` matrix whose elements are sampled from
    /// the half-open range `minimum..maximum` using the thread-local RNG.
    ///
    /// # Panics
    /// `rand`'s `gen_range` panics on an empty range (e.g. `minimum >= maximum`),
    /// unless the matrix has no elements at all.
    pub fn nml_mat_rnd(num_rows: u32, num_cols: u32, minimum: f64, maximum: f64) -> Self {
        let mut rng = rand::thread_rng();
        // Exactly one sample per element of the requested shape.
        let len = num_rows as usize * num_cols as usize;
        let data: Vec<f64> = (0..len).map(|_| rng.gen_range(minimum..maximum)).collect();
        Self {
            num_rows,
            num_cols,
            data,
            is_square: num_rows == num_cols,
        }
    }

    /// Creates a `size` x `size` matrix of zeros.
    pub fn nml_mat_sqr(size: u32) -> Self {
        Self::new(size, size)
    }

    /// Creates the `size` x `size` identity matrix.
    pub fn nml_mat_eye(size: u32) -> Self {
        let mut eye = Self::new(size, size);
        let n = size as usize;
        for i in 0..n {
            eye.data[i * n + i] = 1.0;
        }
        eye
    }

    /// Returns an independent deep copy of `matrix`.
    pub fn nml_mat_cp(matrix: &NmlMatrix) -> Self {
        Self {
            num_rows: matrix.num_rows,
            num_cols: matrix.num_cols,
            data: matrix.data.clone(),
            is_square: matrix.is_square,
        }
    }

    /// Placeholder for loading a matrix from a file.
    ///
    /// # Panics
    /// Always panics; not implemented yet.
    pub fn nml_mat_from_file() -> Self {
        unimplemented!("Not implemented yet")
    }

    /// Exact element-wise equality.
    ///
    /// Returns `false` when the dimensions differ. Elements are compared with
    /// `f64` `==`, so a NaN element never compares equal.
    pub fn equality(&self, matrix: &NmlMatrix) -> bool {
        if self.num_rows != matrix.num_rows || self.num_cols != matrix.num_cols {
            return false;
        }
        let len = self.num_rows as usize * self.num_cols as usize;
        self.data[..len] == matrix.data[..len]
    }

    /// Element-wise equality within an absolute `tolerance`.
    ///
    /// Returns `false` when the dimensions differ. Two elements match when they
    /// are exactly equal (which covers equal infinities) or when their absolute
    /// difference is at most `tolerance`. A NaN element never matches.
    pub fn equality_in_tolerance(&self, matrix: NmlMatrix, tolerance: f64) -> bool {
        if self.num_rows != matrix.num_rows || self.num_cols != matrix.num_cols {
            return false;
        }
        let len = self.num_rows as usize * self.num_cols as usize;
        // `<=` (rather than `!(diff > tolerance)`) makes a NaN difference fail.
        self.data[..len]
            .iter()
            .zip(&matrix.data[..len])
            .all(|(a, b)| a == b || (a - b).abs() <= tolerance)
    }

    /// Returns the element at `(row, col)` (0-based).
    ///
    /// # Errors
    /// `InvalidRows` if `row >= num_rows`, `InvalidCols` if `col >= num_cols`.
    pub fn at(&self, row: u32, col: u32) -> Result<f64, NmlError> {
        if row >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        if col >= self.num_cols {
            return Err(NmlError::new(InvalidCols));
        }
        Ok(self.data[self.flat_index(row, col)])
    }

    /// Returns column `column` (0-based) as a `num_rows` x 1 matrix.
    ///
    /// # Errors
    /// `InvalidCols` if `column >= num_cols`.
    pub fn get_column(&self, column: u32) -> Result<Self, NmlError> {
        if column >= self.num_cols {
            return Err(NmlError::new(InvalidCols));
        }
        // Successive elements of a column are `num_cols` (one row) apart.
        let data: Vec<f64> = (0..self.num_rows)
            .map(|row| self.data[self.flat_index(row, column)])
            .collect();
        Ok(Self {
            num_rows: self.num_rows,
            num_cols: 1,
            data,
            is_square: self.num_rows == 1,
        })
    }

    /// Returns row `row` (0-based) as a 1 x `num_cols` matrix.
    ///
    /// # Errors
    /// `InvalidRows` if `row >= num_rows`.
    pub fn get_row(&self, row: u32) -> Result<Self, NmlError> {
        if row >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        let start = self.flat_index(row, 0);
        let data = self.data[start..start + self.num_cols as usize].to_vec();
        Ok(Self {
            num_rows: 1,
            num_cols: self.num_cols,
            data,
            is_square: self.num_cols == 1,
        })
    }

    /// Overwrites the element at `(row, col)` (0-based) with `data`.
    ///
    /// # Errors
    /// `InvalidRows` if `row >= num_rows` (checked first), otherwise
    /// `InvalidCols` if `col >= num_cols`.
    pub fn set_value(&mut self, row: u32, col: u32, data: f64) -> Result<(), NmlError> {
        if row >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        if col >= self.num_cols {
            return Err(NmlError::new(InvalidCols));
        }
        let idx = self.flat_index(row, col);
        self.data[idx] = data;
        Ok(())
    }

    /// Sets every element of the matrix to `value`.
    pub fn set_all_values(&mut self, value: f64) {
        let len = self.num_rows as usize * self.num_cols as usize;
        for element in &mut self.data[..len] {
            *element = value;
        }
    }

    /// Sets every main-diagonal element to `value`.
    ///
    /// # Errors
    /// `MatrixNotSquare` if `is_square` is `false`; the matrix is left unchanged.
    pub fn set_dig_values(&mut self, value: f64) -> Result<(), NmlError> {
        if !self.is_square {
            return Err(NmlError::new(ErrorKind::MatrixNotSquare));
        }
        for i in 0..self.num_rows {
            let idx = self.flat_index(i, i);
            self.data[idx] = value;
        }
        Ok(())
    }

    /// Multiplies every element of row `row` (0-based) by `scalar`, in place.
    ///
    /// # Errors
    /// `InvalidRows` if `row >= num_rows`.
    pub fn multiply_row_scalar(&mut self, row: u32, scalar: f64) -> Result<(), NmlError> {
        if row >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        let start = self.flat_index(row, 0);
        let end = start + self.num_cols as usize;
        for element in &mut self.data[start..end] {
            *element *= scalar;
        }
        Ok(())
    }

    /// Multiplies every element of column `col` (0-based) by `scalar`, in place.
    ///
    /// # Errors
    /// `InvalidCols` if `col >= num_cols`.
    pub fn multiply_col_scalar(&mut self, col: u32, scalar: f64) -> Result<(), NmlError> {
        if col >= self.num_cols {
            return Err(NmlError::new(InvalidCols));
        }
        for row in 0..self.num_rows {
            let idx = self.flat_index(row, col);
            self.data[idx] *= scalar;
        }
        Ok(())
    }

    /// Multiplies every element of the matrix by `scalar`, in place.
    pub fn multiply_matrix_scalar(&mut self, scalar: f64) {
        for element in self.data.iter_mut() {
            *element *= scalar;
        }
    }

    /// Replaces row `row_1` with `scalar_1 * row_1 + scalar_2 * row_2`, in place.
    ///
    /// `row_1 == row_2` is allowed. Each element is read before it is written.
    ///
    /// # Errors
    /// `InvalidRows` if either row index is out of range.
    pub fn add_rows(
        &mut self,
        row_1: u32,
        scalar_1: f64,
        row_2: u32,
        scalar_2: f64,
    ) -> Result<(), NmlError> {
        if row_1 >= self.num_rows || row_2 >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        for col in 0..self.num_cols {
            let target = self.flat_index(row_1, col);
            let source = self.flat_index(row_2, col);
            self.data[target] = self.data[target] * scalar_1 + self.data[source] * scalar_2;
        }
        Ok(())
    }

    /// Swaps rows `row_1` and `row_2` (0-based), in place.
    ///
    /// # Errors
    /// `InvalidRows` if either row index is out of range.
    pub fn swap_rows(&mut self, row_1: u32, row_2: u32) -> Result<(), NmlError> {
        if row_1 >= self.num_rows || row_2 >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        for col in 0..self.num_cols {
            let a = self.flat_index(row_1, col);
            let b = self.flat_index(row_2, col);
            self.data.swap(a, b);
        }
        Ok(())
    }

    /// Swaps columns `col_1` and `col_2` (0-based), in place.
    ///
    /// # Errors
    /// `InvalidCols` if either column index is out of range.
    pub fn swap_columns(&mut self, col_1: u32, col_2: u32) -> Result<(), NmlError> {
        if col_1 >= self.num_cols || col_2 >= self.num_cols {
            return Err(NmlError::new(InvalidCols));
        }
        for row in 0..self.num_rows {
            let a = self.flat_index(row, col_1);
            let b = self.flat_index(row, col_2);
            self.data.swap(a, b);
        }
        Ok(())
    }

    /// Returns a copy of the matrix with column `col` (0-based) removed.
    ///
    /// The result is `num_rows` x `(num_cols - 1)`.
    ///
    /// # Errors
    /// `InvalidCols` if `col >= num_cols`.
    pub fn remove_column(&self, col: u32) -> Result<Self, NmlError> {
        if col >= self.num_cols {
            return Err(NmlError::new(InvalidCols));
        }
        let width = self.num_cols as usize;
        let skip = col as usize;
        // Single pass over the row-major buffer: an element sits in column
        // `col` exactly when its index is congruent to `col` modulo the width.
        let data: Vec<f64> = self
            .data
            .iter()
            .enumerate()
            .filter(|&(i, _)| i % width != skip)
            .map(|(_, &value)| value)
            .collect();
        let num_cols = self.num_cols - 1;
        Ok(Self {
            num_rows: self.num_rows,
            num_cols,
            data,
            is_square: self.num_rows == num_cols,
        })
    }

    /// Returns a copy of the matrix with row `row` (0-based) removed.
    ///
    /// The result is `(num_rows - 1)` x `num_cols`.
    ///
    /// # Errors
    /// `InvalidRows` if `row >= num_rows`.
    pub fn remove_row(&self, row: u32) -> Result<Self, NmlError> {
        if row >= self.num_rows {
            return Err(NmlError::new(InvalidRows));
        }
        let start = self.flat_index(row, 0);
        let end = start + self.num_cols as usize;
        // Keep everything before and after the removed row.
        let mut data: Vec<f64> = Vec::with_capacity(self.data.len().saturating_sub(end - start));
        data.extend_from_slice(&self.data[..start]);
        data.extend_from_slice(&self.data[end..]);
        let num_rows = self.num_rows - 1;
        Ok(Self {
            num_rows,
            num_cols: self.num_cols,
            data,
            is_square: num_rows == self.num_cols,
        })
    }

    /// Returns the sub-matrix covering rows `row_start..=row_end` and columns
    /// `col_start..=col_end`.
    ///
    /// NOTE: unlike the other accessors on this type, all four bounds are
    /// **1-based** and inclusive. That is the element selection this method has
    /// always performed (`row_start - 1..row_end`). The reported dimensions and
    /// the bounds check now agree with it.
    ///
    /// # Errors
    /// `InvalidRows` unless `1 <= row_start <= row_end <= num_rows`.
    /// `InvalidCols` unless `1 <= col_start <= col_end <= num_cols`.
    pub fn get_sub_mtr(
        &self,
        row_start: u32,
        row_end: u32,
        col_start: u32,
        col_end: u32,
    ) -> Result<Self, NmlError> {
        if !(1 <= row_start && row_start <= row_end && row_end <= self.num_rows) {
            return Err(NmlError::new(InvalidRows));
        }
        if !(1 <= col_start && col_start <= col_end && col_end <= self.num_cols) {
            return Err(NmlError::new(InvalidCols));
        }
        let num_rows = row_end - row_start + 1;
        let num_cols = col_end - col_start + 1;
        let mut data: Vec<f64> = Vec::with_capacity(num_rows as usize * num_cols as usize);
        for row in (row_start - 1)..row_end {
            // Columns `col_start - 1 .. col_end` (0-based, end-exclusive) of this row.
            let first = self.flat_index(row, col_start - 1);
            let last = self.flat_index(row, col_end);
            data.extend_from_slice(&self.data[first..last]);
        }
        Ok(Self {
            num_rows,
            num_cols,
            data,
            is_square: num_rows == num_cols,
        })
    }

    /// Returns the transpose, a `num_cols` x `num_rows` matrix.
    pub fn transpose(&self) -> Self {
        let mut data: Vec<f64> = Vec::with_capacity(self.data.len());
        for col in 0..self.num_cols {
            for row in 0..self.num_rows {
                data.push(self.data[self.flat_index(row, col)]);
            }
        }
        Self {
            num_rows: self.num_cols,
            num_cols: self.num_rows,
            data,
            is_square: self.is_square,
        }
    }

    /// Matrix product `self * other`, computed by first transposing `other`.
    ///
    /// After the transpose, the inner loop reads contiguous rows of both
    /// operands. Each element is summed in the same order as
    /// [`NmlMatrix::mul_naive`], so the results are identical.
    ///
    /// # Errors
    /// `InvalidCols` if `self.num_cols != other.num_rows`.
    pub fn mul_transpose(&self, other: &Self) -> Result<Self, NmlError> {
        if self.num_cols != other.num_rows {
            return Err(NmlError::new(InvalidCols));
        }
        let m = self.num_rows as usize;
        let n = self.num_cols as usize;
        let p = other.num_cols as usize;
        // `other_t` is p x n: its row `j` is column `j` of `other`, so it must be
        // indexed with stride `n`, not with `other`'s stride `p`.
        let other_t = other.transpose();
        let mut data: Vec<f64> = vec![0.0; m * p];
        for i in 0..m {
            let lhs_row = &self.data[i * n..(i + 1) * n];
            for j in 0..p {
                let rhs_row = &other_t.data[j * n..(j + 1) * n];
                data[i * p + j] = lhs_row
                    .iter()
                    .zip(rhs_row)
                    .fold(0.0, |acc, (a, b)| acc + a * b);
            }
        }
        Ok(Self {
            num_rows: self.num_rows,
            num_cols: other.num_cols,
            data,
            is_square: self.num_rows == other.num_cols,
        })
    }

    /// Matrix product `self * other` using the textbook triple loop.
    ///
    /// # Errors
    /// `CreateMatrix` if `self.num_cols != other.num_rows`.
    pub fn mul_naive(&self, other: &Self) -> Result<Self, NmlError> {
        if self.num_cols != other.num_rows {
            return Err(NmlError::new(CreateMatrix));
        }
        let m = self.num_rows as usize;
        let n = self.num_cols as usize;
        let p = other.num_cols as usize;
        let mut data: Vec<f64> = vec![0.0; m * p];
        for i in 0..m {
            for j in 0..p {
                data[i * p + j] =
                    (0..n).fold(0.0, |acc, k| acc + self.data[i * n + k] * other.data[k * p + j]);
            }
        }
        Ok(Self {
            num_rows: self.num_rows,
            num_cols: other.num_cols,
            data,
            is_square: self.num_rows == other.num_cols,
        })
    }
}
/// Element-wise difference `self - rhs`, consuming both operands.
///
/// The result keeps `self`'s dimensions and `is_square` flag.
///
/// # Errors
/// Returns a `CreateMatrix` error when the operands' dimensions differ.
impl Sub for NmlMatrix {
    type Output = Result<Self, NmlError>;

    fn sub(self, rhs: Self) -> Self::Output {
        if self.num_rows != rhs.num_rows || self.num_cols != rhs.num_cols {
            return Err(NmlError::new(CreateMatrix));
        }
        // Cover every element. A `0..len - 1` bound would drop the last one and
        // underflow on an empty matrix.
        let data: Vec<f64> = self
            .data
            .iter()
            .zip(&rhs.data)
            .map(|(a, b)| a - b)
            .collect();
        Ok(Self {
            num_rows: self.num_rows,
            num_cols: self.num_cols,
            data,
            is_square: self.is_square,
        })
    }
}
/// Element-wise difference `self - rhs` of two borrowed matrices.
///
/// The operands are left untouched. The result keeps `self`'s dimensions and
/// `is_square` flag.
///
/// # Errors
/// Returns a `CreateMatrix` error when the operands' dimensions differ.
impl Sub for &NmlMatrix {
    type Output = Result<NmlMatrix, NmlError>;

    fn sub(self, rhs: Self) -> Self::Output {
        if self.num_rows != rhs.num_rows || self.num_cols != rhs.num_cols {
            return Err(NmlError::new(CreateMatrix));
        }
        // Cover every element. A `0..len - 1` bound would drop the last one and
        // underflow on an empty matrix.
        let data: Vec<f64> = self
            .data
            .iter()
            .zip(&rhs.data)
            .map(|(a, b)| a - b)
            .collect();
        Ok(NmlMatrix {
            num_rows: self.num_rows,
            num_cols: self.num_cols,
            data,
            is_square: self.is_square,
        })
    }
}
impl Display for NmlMatrix {
    /// Renders one matrix row per line. Each element is formatted with `{}` and
    /// followed by a single space, and every row (including the last) ends
    /// with `\n`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let width = self.num_cols as usize;
        for r in 0..self.num_rows as usize {
            let row = &self.data[r * width..(r + 1) * width];
            for value in row {
                write!(f, "{} ", value)?;
            }
            f.write_str("\n")?;
        }
        Ok(())
    }
}
impl Eq for NmlMatrix {}
impl PartialEq for NmlMatrix {
    fn eq(&self, other: &Self) -> bool {
        self.equality(other)
    }
}
/// Element-wise sum of two equally sized matrices, consuming both operands.
///
/// The result keeps `self`'s dimensions and `is_square` flag. A `CreateMatrix`
/// error is returned when the dimensions differ.
impl Add for NmlMatrix {
    type Output = Result<Self, NmlError>;

    fn add(self, rhs: Self) -> Self::Output {
        let same_shape = self.num_rows == rhs.num_rows && self.num_cols == rhs.num_cols;
        if !same_shape {
            return Err(NmlError::new(CreateMatrix));
        }
        let sums: Vec<f64> = (0..self.data.len())
            .map(|k| self.data[k] + rhs.data[k])
            .collect();
        Ok(Self {
            num_rows: self.num_rows,
            num_cols: self.num_cols,
            data: sums,
            is_square: self.is_square,
        })
    }
}
/// Element-wise sum of two borrowed, equally sized matrices.
///
/// Neither operand is modified. The result keeps `self`'s dimensions and
/// `is_square` flag. A `CreateMatrix` error is returned when the dimensions
/// differ.
impl Add for &NmlMatrix {
    type Output = Result<NmlMatrix, NmlError>;

    fn add(self, rhs: Self) -> Self::Output {
        let same_shape = self.num_rows == rhs.num_rows && self.num_cols == rhs.num_cols;
        if !same_shape {
            return Err(NmlError::new(CreateMatrix));
        }
        let sums: Vec<f64> = (0..self.data.len())
            .map(|k| self.data[k] + rhs.data[k])
            .collect();
        Ok(NmlMatrix {
            num_rows: self.num_rows,
            num_cols: self.num_cols,
            data: sums,
            is_square: self.is_square,
        })
    }
}
/// Matrix product `self * rhs`, consuming both operands.
///
/// Delegates to [`NmlMatrix::mul_naive`], so it fails under the same condition
/// (`self.num_cols != rhs.num_rows`) with the same error.
impl Mul for NmlMatrix {
    type Output = Result<NmlMatrix, NmlError>;

    fn mul(self, rhs: Self) -> Self::Output {
        self.mul_naive(&rhs)
    }
}