use std::{fmt::Debug, ops::Sub};
use ndarray::{Data, Dimension, RemoveAxis, Zip};
use num_traits::{Num, NumCast};
use crate::{interp1d::Linear, InterpolateError};
use super::{Interp2DStrategy, Interp2DStrategyBuilder};
/// Strategy marker for bilinear interpolation of data on a 2D grid.
///
/// Create one with [`Biliniar::new`] and optionally adjust it with
/// [`Biliniar::extrapolate`].
#[derive(Debug)]
pub struct Biliniar {
    /// When `false`, a query whose `x` or `y` is not inside the
    /// interpolator's range is rejected with an out-of-bounds error.
    /// When `true`, the range checks are skipped entirely.
    extrapolate: bool,
}
/// Builder half of the bilinear strategy.
///
/// `Biliniar` needs no precomputed state, so the builder and the finished
/// strategy are the same type.
impl<Sd, Sx, Sy, D> Interp2DStrategyBuilder<Sd, Sx, Sy, D> for Biliniar
where
    Sd: Data,
    Sd::Elem: Num + NumCast + PartialOrd + Sub + Copy + Debug,
    Sx: Data<Elem = Sd::Elem>,
    Sy: Data<Elem = Sd::Elem>,
    D: Dimension + RemoveAxis,
    D::Smaller: RemoveAxis,
{
    // NOTE(review): presumably the smallest number of grid points per axis
    // that the surrounding builder accepts (two points are needed to span
    // one cell) — confirm against the trait definition.
    const MINIMUM_DATA_LENGHT: usize = 2;

    type FinishedStrat = Self;

    /// Finishes the strategy by handing `self` back unchanged.
    ///
    /// The axes and the data are not inspected here, so this
    /// implementation never returns an error.
    fn build(
        self,
        _x_axis: &ndarray::ArrayBase<Sx, ndarray::Ix1>,
        _y_axis: &ndarray::ArrayBase<Sy, ndarray::Ix1>,
        _values: &ndarray::ArrayBase<Sd, D>,
    ) -> Result<Self::FinishedStrat, crate::BuilderError> {
        Ok(self)
    }
}
/// Evaluation half of the bilinear strategy.
impl<Sd, Sx, Sy, D> Interp2DStrategy<Sd, Sx, Sy, D> for Biliniar
where
    Sd: Data,
    Sd::Elem: Num + NumCast + PartialOrd + Sub + Copy + Debug,
    Sx: Data<Elem = Sd::Elem>,
    Sy: Data<Elem = Sd::Elem>,
    D: Dimension + RemoveAxis,
    D::Smaller: RemoveAxis,
{
    /// Interpolates the data at `(x, y)` and writes the result into `target`.
    ///
    /// The four grid points surrounding the cell selected by
    /// `get_index_left_of` are combined element-wise: two evaluations along
    /// `x` (one per `y` edge of the cell) followed by one evaluation along
    /// `y` between those intermediate values.
    ///
    /// # Errors
    ///
    /// Returns [`InterpolateError::OutOfBounds`] when extrapolation is
    /// disabled and `x` or `y` is reported as out of range by the
    /// interpolator. `x` is checked before `y`.
    fn interp_into(
        &self,
        interpolator: &crate::interp2d::Interp2D<Sd, Sx, Sy, D, Self>,
        target: ndarray::ArrayViewMut<'_, <Sd>::Elem, <D::Smaller as Dimension>::Smaller>,
        x: <Sx>::Elem,
        y: <Sy>::Elem,
    ) -> Result<(), crate::InterpolateError> {
        // Range checks only apply while extrapolation is switched off.
        if !self.extrapolate {
            if !interpolator.is_in_x_range(x) {
                return Err(InterpolateError::OutOfBounds(format!(
                    "x = {x:?} is not in range"
                )));
            }
            if !interpolator.is_in_y_range(y) {
                return Err(InterpolateError::OutOfBounds(format!(
                    "y = {y:?} is not in range"
                )));
            }
        }

        // Corners of the grid cell; the suffixes read as (x edge, y edge).
        let (ix, iy) = interpolator.get_index_left_of(x, y);
        let (x_lo, y_lo, corner_lo_lo) = interpolator.index_point(ix, iy);
        let (_, _, corner_lo_hi) = interpolator.index_point(ix, iy + 1);
        let (_, _, corner_hi_lo) = interpolator.index_point(ix + 1, iy);
        let (x_hi, y_hi, corner_hi_hi) = interpolator.index_point(ix + 1, iy + 1);

        Zip::from(target)
            .and(corner_lo_lo)
            .and(corner_lo_hi)
            .and(corner_hi_lo)
            .and(corner_hi_hi)
            .for_each(|out, &c_lo_lo, &c_lo_hi, &c_hi_lo, &c_hi_hi| {
                // `Linear::calc_frac` presumably evaluates the line through
                // the two given points at the third argument — confirm in
                // `interp1d::Linear`.
                let at_y_lo = Linear::calc_frac((x_lo, c_lo_lo), (x_hi, c_hi_lo), x);
                let at_y_hi = Linear::calc_frac((x_lo, c_lo_hi), (x_hi, c_hi_hi), x);
                *out = Linear::calc_frac((y_lo, at_y_lo), (y_hi, at_y_hi), y);
            });

        Ok(())
    }
}
impl Biliniar {
    pub fn new() -> Self {
        Biliniar { extrapolate: false }
    }
    pub fn extrapolate(mut self, yes: bool) -> Self {
        self.extrapolate = yes;
        self
    }
}
impl Default for Biliniar {
    fn default() -> Self {
        Self::new()
    }
}