use crate::float_trait::Float;
use crate::sorted_array::SortedArray;
use conv::{ConvAsUtil, ConvUtil, RoundToNearest};
use enum_dispatch::enum_dispatch;
use itertools::Itertools;
use macro_const::macro_const;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
// Documentation text shared by `NyquistFreqTrait` and `NyquistFreq`: both items attach it via
// `#[doc = NYQUIST_FREQ_DOC!()]`, which is why it is declared through `macro_const!` rather
// than as a plain `const`.
// NOTE(review): the text reads "unevenly time series" and "consequent observations";
// "unevenly sampled time series" and "consecutive observations" were probably intended. The
// string is left untouched here because it is emitted verbatim into the generated docs.
macro_const! {
    const NYQUIST_FREQ_DOC: &'static str = r"Derive Nyquist frequency from time series
Nyquist frequency for unevenly time series is not well-defined. Here we define it as
$\pi / \delta t$, where $\delta t$ is some typical interval between consequent observations
";
}
// Private trait implemented by every Nyquist-frequency strategy in this file; the
// `enum_dispatch` attribute pairs it with the `NyquistFreq` enum.
#[doc = NYQUIST_FREQ_DOC!()]
#[enum_dispatch]
trait NyquistFreqTrait: Send + Sync + Clone + Debug {
    /// Computes the Nyquist frequency for the time moments `t`.
    ///
    /// `T` is the floating-point type of both the time values and the returned frequency.
    /// Implementations in this file either derive the value from the differences between
    /// elements of `t` or ignore `t` altogether (see `FixedNyquistFreq`).
    fn nyquist_freq<T: Float>(&self, t: &[T]) -> T;
}
#[doc = NYQUIST_FREQ_DOC!()]
#[enum_dispatch(NyquistFreqTrait)]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[non_exhaustive]
pub enum NyquistFreq {
    /// `π (N - 1) / (t[N - 1] - t[0])`, see [`AverageNyquistFreq`].
    Average(AverageNyquistFreq),
    /// `π` divided by the median of the differences between neighbouring time moments, see
    /// [`MedianNyquistFreq`].
    Median(MedianNyquistFreq),
    /// `π` divided by the value `SortedArray::ppf` returns for the configured `quantile` over
    /// the differences between neighbouring time moments, see [`QuantileNyquistFreq`].
    Quantile(QuantileNyquistFreq),
    /// A preset value that does not depend on the time series, see [`FixedNyquistFreq`].
    Fixed(FixedNyquistFreq),
}
impl NyquistFreq {
    pub fn average() -> Self {
        Self::Average(AverageNyquistFreq)
    }
    pub fn median() -> Self {
        Self::Median(MedianNyquistFreq)
    }
    pub fn quantile(quantile: f32) -> Self {
        Self::Quantile(QuantileNyquistFreq { quantile })
    }
    pub fn fixed(freq: f32) -> Self {
        Self::Fixed(FixedNyquistFreq(freq))
    }
}
/// Nyquist frequency based on the mean interval between the first and the last time moments.
///
/// For `N = t.len()` the result is `π (N - 1) / (t[N - 1] - t[0])`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename = "Average")]
pub struct AverageNyquistFreq;

impl NyquistFreqTrait for AverageNyquistFreq {
    /// Returns `π (N - 1) / (t[N - 1] - t[0])`.
    ///
    /// Only the first and the last elements of `t` are read, so `t` is presumably expected
    /// to be sorted in ascending order — confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if `t` is empty or if `N - 1` cannot be converted to `T`.
    fn nyquist_freq<T: Float>(&self, t: &[T]) -> T {
        let last = t.len() - 1;
        // The number of intervals between N observations is N - 1.
        let intervals: T = last.value_as().unwrap();
        let duration = t[last] - t[0];
        T::PI() * intervals / duration
    }
}
/// Returns the differences `x[i + 1] - x[i]` for every pair of neighbouring elements.
///
/// The output has one element fewer than `x`; it is empty when `x` has fewer than two
/// elements.
fn diff<T: Float>(x: &[T]) -> Vec<T> {
    let neighbours = x.iter().copied().tuple_windows();
    neighbours.map(|(prev, next)| next - prev).collect()
}
/// Nyquist frequency based on the median interval between neighbouring time moments.
///
/// The result is `π / median(dt)`, where `dt` are the differences of neighbouring elements
/// of the time series.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename = "Median")]
pub struct MedianNyquistFreq;

impl NyquistFreqTrait for MedianNyquistFreq {
    /// Returns `π` divided by the median of the neighbouring differences of `t`.
    ///
    /// The differences are moved into a `SortedArray`, whose `median()` supplies the
    /// typical interval.
    fn nyquist_freq<T: Float>(&self, t: &[T]) -> T {
        let intervals = diff(t);
        let sorted_intervals: SortedArray<_> = intervals.into();
        T::PI() / sorted_intervals.median()
    }
}
/// Nyquist frequency based on a chosen quantile of the intervals between neighbouring time
/// moments.
///
/// The result is `π / dt_q`, where `dt_q` is what `SortedArray::ppf` returns for
/// [`quantile`](Self::quantile) over the neighbouring differences of the time series.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename = "Quantile")]
pub struct QuantileNyquistFreq {
    /// Argument forwarded to `SortedArray::ppf`; presumably a probability within `[0, 1]` —
    /// TODO confirm against `SortedArray::ppf`.
    pub quantile: f32,
}

impl NyquistFreqTrait for QuantileNyquistFreq {
    /// Returns `π` divided by the `quantile`-level value of the neighbouring differences of
    /// `t`, as reported by `SortedArray::ppf`.
    fn nyquist_freq<T: Float>(&self, t: &[T]) -> T {
        let intervals = diff(t);
        let sorted_intervals: SortedArray<_> = intervals.into();
        T::PI() / sorted_intervals.ppf(self.quantile)
    }
}
/// Preset Nyquist frequency that does not depend on the time series.
///
/// The wrapped `f32` is the value returned for any input.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename = "Fixed")]
pub struct FixedNyquistFreq(pub f32);

impl FixedNyquistFreq {
    /// Builds the fixed frequency `π / dt` from a time interval `dt`.
    ///
    /// # Panics
    ///
    /// Panics if `dt` cannot be approximated by `f32` or if the approximated value is not
    /// strictly positive (this includes NaN).
    pub fn from_dt<T: Float>(dt: T) -> Self {
        let dt: f32 = dt.approx().unwrap();
        assert!(dt > 0.0);
        let freq = core::f32::consts::PI / dt;
        FixedNyquistFreq(freq)
    }
}

impl NyquistFreqTrait for FixedNyquistFreq {
    /// Returns the stored frequency converted to `T`; the time series is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the stored `f32` cannot be converted to `T`.
    fn nyquist_freq<T: Float>(&self, _t: &[T]) -> T {
        let freq: f32 = self.0;
        freq.value_as::<T>().unwrap()
    }
}
/// Frequency grid described by its step and its number of steps.
#[derive(Clone, Debug)]
pub struct FreqGrid<T> {
    /// Frequency increment of the grid.
    pub step: T,
    /// Number of `step`s; [`FreqGrid::from_t`] sets it to the rounded ratio of the maximum
    /// frequency to `step`.
    pub size: usize,
}

impl<T> FreqGrid<T>
where
    T: Float,
{
    /// Builds a grid from the time moments `t`.
    ///
    /// With `N = t.len()` and `duration = t[N - 1] - t[0]`:
    ///
    /// - `step = 2π (N - 1) / (N · resolution · duration)`;
    /// - `size = round(max_freq / step)`, where `max_freq` is `nyquist.nyquist_freq(t)`
    ///   multiplied by `max_freq_factor`.
    ///
    /// Only the first and the last elements of `t` define `duration`, so `t` is presumably
    /// expected to be sorted in ascending order — confirm against callers.
    ///
    /// # Panics
    ///
    /// - if `resolution` has a negative sign or is not finite;
    /// - if `t` is empty;
    /// - if any numeric conversion fails, e.g. the rounded `max_freq / step` cannot be
    ///   converted to `usize`.
    pub fn from_t(t: &[T], resolution: f32, max_freq_factor: f32, nyquist: NyquistFreq) -> Self {
        assert!(resolution.is_sign_positive() && resolution.is_finite());

        let sizef: T = t.len().approx().unwrap();
        let duration = t[t.len() - 1] - t[0];

        // step = 2π (N - 1) / (N · resolution · duration)
        let numerator = T::two() * T::PI() * (sizef - T::one());
        let resolution_t: T = resolution.value_as().unwrap();
        let denominator = sizef * resolution_t * duration;
        let step = numerator / denominator;

        // The grid extends up to the Nyquist frequency scaled by `max_freq_factor`.
        let nyquist_freq = nyquist.nyquist_freq(t);
        let factor: T = max_freq_factor.value_as().unwrap();
        let max_freq = nyquist_freq * factor;

        let size: usize = (max_freq / step).approx_by::<RoundToNearest>().unwrap();
        FreqGrid { step, size }
    }
}