use super::{element::FieldElement, errors::FieldError};
#[cfg(feature = "lambdaworks-serde-binary")]
use crate::traits::ByteConversion;
use crate::{errors::CreationError, unsigned_integer::traits::IsUnsignedInteger};
use core::fmt::Debug;
/// Selects which sequence of roots of unity is produced and in what order.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names (natural vs. bit-reversed ordering, plain vs. inverted roots); confirm
/// against the code that consumes this enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RootsConfig {
    /// Roots in natural order (presumably `w^0, w^1, w^2, ...`).
    Natural,
    /// Inverted roots in natural order (presumably `w^0, w^-1, w^-2, ...`).
    NaturalInversed,
    /// Like `Natural`, but with the exponents in bit-reversed order (presumably).
    BitReverse,
    /// Like `NaturalInversed`, but with the exponents in bit-reversed order (presumably).
    BitReverseInversed,
}
/// Declares `Self` to be a subfield of the field `F`.
///
/// Every mixed operation takes an element of `Self` as its left operand and an
/// element of `F` as its right operand, and produces an element of `F`. This lets
/// subfield elements be combined with elements of the larger field without an
/// explicit embedding step at each call site.
pub trait IsSubFieldOf<F: IsField>: IsField {
    /// Maps an element of `Self` to the corresponding element of `F`.
    fn embed(a: Self::BaseType) -> F::BaseType;

    /// Mixed addition `a + b`, with `a` in `Self`, `b` in `F` and the result in `F`.
    fn add(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType;

    /// Mixed subtraction `a - b`, with `a` in `Self`, `b` in `F` and the result in `F`.
    fn sub(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType;

    /// Mixed multiplication `a * b`, with `a` in `Self`, `b` in `F` and the result in `F`.
    fn mul(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType;

    /// Mixed division `a / b`, with `a` in `Self`, `b` in `F` and the result in `F`.
    fn div(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType;

    /// Splits an element of `F` into a vector of `Self` elements.
    ///
    /// NOTE(review): the number, order and basis of the components are
    /// implementation-defined; confirm per implementor.
    #[cfg(feature = "alloc")]
    fn to_subfield_vec(b: F::BaseType) -> alloc::vec::Vec<Self::BaseType>;
}
/// Every field is trivially a subfield of itself.
///
/// Embedding is the identity, and each mixed operation forwards to the
/// corresponding [`IsField`] operation of `F`.
impl<F> IsSubFieldOf<F> for F
where
    F: IsField,
{
    #[inline(always)]
    fn embed(a: Self::BaseType) -> F::BaseType {
        // `Self` and `F` are the same type, so the representation is shared unchanged.
        a
    }

    #[inline(always)]
    fn add(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType {
        <F as IsField>::add(a, b)
    }

    #[inline(always)]
    fn sub(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType {
        <F as IsField>::sub(a, b)
    }

    #[inline(always)]
    fn mul(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType {
        <F as IsField>::mul(a, b)
    }

    #[inline(always)]
    fn div(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType {
        <F as IsField>::div(a, b)
    }

    #[cfg(feature = "alloc")]
    fn to_subfield_vec(b: F::BaseType) -> alloc::vec::Vec<Self::BaseType> {
        // Over itself, an element of `F` has exactly one component: itself.
        alloc::vec![b]
    }
}
/// A field that exposes a root of unity of power-of-two order, from which
/// roots of every smaller power-of-two order are derived by repeated squaring.
pub trait IsFFTField: IsField {
    /// Base-2 logarithm of the order of [`Self::TWO_ADIC_PRIMITVE_ROOT_OF_UNITY`]
    /// (presumably the field's 2-adicity). It is also the largest `order`
    /// accepted by [`Self::get_primitive_root_of_unity`].
    const TWO_ADICITY: u64;

    /// Root of unity of order `2^TWO_ADICITY` (assumed primitive; not checked here).
    ///
    /// The identifier's spelling is part of the public API and is kept as is.
    const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: Self::BaseType;

    /// Human-readable name of the field; the empty string by default.
    fn field_name() -> &'static str {
        ""
    }

    /// Returns a primitive `2^order`-th root of unity.
    ///
    /// The result is [`Self::TWO_ADIC_PRIMITVE_ROOT_OF_UNITY`] squared
    /// `TWO_ADICITY - order` times. `order == 0` yields one.
    ///
    /// # Errors
    ///
    /// Returns `FieldError::RootOfUnityError(order)` when `order > TWO_ADICITY`.
    fn get_primitive_root_of_unity(order: u64) -> Result<FieldElement<Self>, FieldError> {
        if order == 0 {
            return Ok(FieldElement::one());
        }
        if order > Self::TWO_ADICITY {
            return Err(FieldError::RootOfUnityError(order));
        }

        // Each squaring halves the multiplicative order of the root, so
        // `TWO_ADICITY - order` squarings take 2^TWO_ADICITY down to 2^order.
        let mut root = FieldElement::new(Self::TWO_ADIC_PRIMITVE_ROOT_OF_UNITY);
        for _ in order..Self::TWO_ADICITY {
            root = root.square();
        }
        Ok(root)
    }
}
/// Core arithmetic interface of a field.
///
/// Implementors choose the concrete element representation `BaseType` and
/// provide the basic operations on it. `double`, `square` and `pow` have
/// default implementations built on `add` and `mul`.
pub trait IsField: Debug + Clone {
    /// Concrete representation of a field element.
    #[cfg(feature = "lambdaworks-serde-binary")]
    type BaseType: Clone + Debug + Unpin + ByteConversion;
    /// Concrete representation of a field element.
    #[cfg(not(feature = "lambdaworks-serde-binary"))]
    type BaseType: Clone + Debug + Unpin;

    /// Field addition `a + b`.
    fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType;

    /// Returns `a + a`. The default calls [`Self::add`].
    fn double(a: &Self::BaseType) -> Self::BaseType {
        Self::add(a, a)
    }

    /// Field multiplication `a * b`.
    fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType;

    /// Returns `a * a`. The default calls [`Self::mul`].
    fn square(a: &Self::BaseType) -> Self::BaseType {
        Self::mul(a, a)
    }

    /// Computes `a^exponent` by square-and-multiply.
    ///
    /// `a^0` is [`Self::one`] (also for `a == 0`), and `a^1` is a clone of `a`.
    fn pow<T>(a: &Self::BaseType, mut exponent: T) -> Self::BaseType
    where
        T: IsUnsignedInteger,
    {
        let zero = T::from(0);
        let one = T::from(1);

        if exponent == zero {
            return Self::one();
        }
        if exponent == one {
            return a.clone();
        }

        // Absorb the trailing zero bits first: a^(2^k * m) = (a^(2^k))^m.
        // The exponent is non-zero here, so this loop terminates and leaves
        // an odd (hence non-zero) exponent.
        let mut result = a.clone();
        while exponent & one == zero {
            result = Self::square(&result);
            exponent >>= 1;
        }

        // The lowest set bit is already accounted for by `result`; fold in
        // the remaining bits, squaring `base` once per bit position.
        let mut base = result.clone();
        exponent >>= 1;
        while exponent != zero {
            base = Self::square(&base);
            if exponent & one == one {
                result = Self::mul(&result, &base);
            }
            exponent >>= 1;
        }
        result
    }

    /// Field subtraction `a - b`.
    fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType;

    /// Additive inverse `-a`.
    fn neg(a: &Self::BaseType) -> Self::BaseType;

    /// Multiplicative inverse of `a`.
    ///
    /// # Errors
    ///
    /// Returns a [`FieldError`] when `a` cannot be inverted (presumably only
    /// for zero; the exact variant is implementation-defined).
    fn inv(a: &Self::BaseType) -> Result<Self::BaseType, FieldError>;

    /// Field division `a / b`.
    ///
    /// NOTE(review): the signature cannot report an error, so behavior for
    /// `b == 0` is implementation-defined.
    fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType;

    /// Whether `a` and `b` represent the same field element.
    fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool;

    /// Additive identity.
    fn zero() -> Self::BaseType;

    /// Multiplicative identity.
    fn one() -> Self::BaseType;

    /// Converts a `u64` into a field element (presumably reduced into the field).
    fn from_u64(x: u64) -> Self::BaseType;

    /// Builds a field element from a raw `BaseType` value.
    ///
    /// NOTE(review): whether this normalizes or validates `x` is
    /// implementation-defined.
    fn from_base_type(x: Self::BaseType) -> Self::BaseType;
}
/// Outcome of a quadratic-residuosity test, as returned by
/// `IsPrimeField::legendre_symbol`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LegendreSymbol {
    /// The element is a quadratic non-residue (it has no square root).
    MinusOne,
    /// The element is zero.
    Zero,
    /// The element is a nonzero quadratic residue.
    One,
}
/// A field of integers modulo a prime `p`, with access to each element's
/// integer representative.
pub trait IsPrimeField: IsField {
    /// Unsigned integer type used for representatives and exponents.
    type RepresentativeType: IsUnsignedInteger;

    /// Integer representative of `a`.
    ///
    /// For `-1` this is `p - 1`, which is how [`Self::modulus_minus_one`] uses it.
    fn representative(a: &Self::BaseType) -> Self::RepresentativeType;

    /// Returns `p - 1`, computed as the representative of `-1`.
    fn modulus_minus_one() -> Self::RepresentativeType {
        Self::representative(&Self::neg(&Self::one()))
    }

    /// Parses a hexadecimal string into a field element.
    ///
    /// # Errors
    ///
    /// Returns a [`CreationError`] when the string cannot be parsed.
    fn from_hex(hex_string: &str) -> Result<Self::BaseType, CreationError>;

    /// Formats `a` as a hexadecimal string.
    #[cfg(feature = "std")]
    fn to_hex(a: &Self::BaseType) -> String;

    /// Bit size of the field (presumably the bit length of the modulus).
    fn field_bit_size() -> usize;

    /// Legendre symbol of `a`, via Euler's criterion `a^((p-1)/2)`.
    ///
    /// Zero is detected up front rather than through the exponentiation. For
    /// `p = 2` the exponent `(p-1)/2` is `0`, so `a^0 = 1` would misreport
    /// zero as a nonzero residue.
    fn legendre_symbol(a: &Self::BaseType) -> LegendreSymbol {
        if Self::eq(a, &Self::zero()) {
            return LegendreSymbol::Zero;
        }
        let symbol = Self::pow(a, Self::modulus_minus_one() >> 1);
        if Self::eq(&symbol, &Self::one()) {
            LegendreSymbol::One
        } else {
            LegendreSymbol::MinusOne
        }
    }

    /// Square roots of `a` using the Tonelli–Shanks algorithm.
    ///
    /// Returns `Some((x, -x))` with `x^2 == a`, `Some((0, 0))` for `a == 0`,
    /// and `None` when `a` is a quadratic non-residue.
    fn sqrt(a: &Self::BaseType) -> Option<(Self::BaseType, Self::BaseType)> {
        match Self::legendre_symbol(a) {
            LegendreSymbol::Zero => return Some((Self::zero(), Self::zero())),
            LegendreSymbol::MinusOne => return None,
            LegendreSymbol::One => (),
        };

        // Factor p - 1 = q * 2^s with q odd.
        let integer_one = Self::RepresentativeType::from(1_u16);
        let mut s: usize = 0;
        let mut q = Self::modulus_minus_one();
        while q & integer_one != integer_one {
            s += 1;
            q >>= 1;
        }

        let one = Self::one();
        // Invariant: x^2 == a * t, and the order of t divides 2^(m-1).
        let mut x = Self::pow(a, (q + integer_one) >> 1);
        let mut t = Self::pow(a, q);

        // When t is already 1, x is a root. This always holds when s <= 1
        // (p = 2 or p ≡ 3 mod 4), so the non-residue search is skipped
        // entirely in those fields. For p = 2 that search would never end.
        if !Self::eq(&t, &one) {
            let minus_one = Self::neg(&one);

            // c = z^q for the first quadratic non-residue z >= 2.
            let mut c = {
                let mut non_qr = Self::from_u64(2);
                while Self::legendre_symbol(&non_qr) != LegendreSymbol::MinusOne {
                    non_qr = Self::add(&non_qr, &one);
                }
                Self::pow(&non_qr, q)
            };
            let mut m = s;

            while !Self::eq(&t, &one) {
                // Least i in 1..m with t^(2^i) == 1. Square until t reaches -1
                // (the only non-trivial square root of 1 in a field), then add one.
                let i = {
                    let mut i = 0;
                    let mut t_sq = t.clone();
                    while !Self::eq(&t_sq, &minus_one) {
                        i += 1;
                        t_sq = Self::square(&t_sq);
                    }
                    i + 1
                };
                let b = (0..(m - i - 1)).fold(c, |acc, _| Self::square(&acc));
                c = Self::square(&b);
                x = Self::mul(&x, &b);
                t = Self::mul(&t, &c);
                m = i;
            }
        }

        let neg_x = Self::neg(&x);
        Some((x, neg_x))
    }
}