use core::convert::Infallible;
use core::ops::Deref;
pub use embedded_hal::spi::{Mode, Phase, Polarity, MODE_0, MODE_1, MODE_2, MODE_3};
use crate::e310x::{QSPI0, QSPI1, QSPI2, qspi0};
use crate::clock::Clocks;
use crate::time::Hertz;
use nb;
/// Pin sets that can be handed to [`Spi::new`] for the peripheral `SPI`.
///
/// Implemented in this module for `()` (QSPI0) and for tuples of GPIO pins in
/// IOF0 mode (QSPI1/QSPI2). The associated constant records which hardware
/// chip-select line, if any, the pin set provides.
pub trait Pins<SPI> {
    /// Value written to the peripheral's `csid` register when the pin set
    /// includes a chip-select pin; `None` when no CS pin is part of the set.
    #[doc(hidden)]
    const CS_INDEX: Option<u32>;
}
// QSPI0 takes no pin arguments and always uses chip-select line 0.
// NOTE(review): presumably because QSPI0's pins are dedicated rather than
// GPIO-muxed — confirm against the board/SoC documentation.
impl Pins<QSPI0> for () {
    const CS_INDEX: Option<u32> = Some(0);
}
mod spi1_impl {
    // `Pins<QSPI1>` implementations: GPIO0 pins 3/4/5 carry MOSI/MISO/SCK and
    // pins 2, 8, 9, 10 can act as chip-select lines 0..=3 (all in IOF0 mode).
    use super::{Pins, QSPI1};
    use crate::gpio::gpio0;
    use crate::gpio::{NoInvert, IOF0};

    type MOSI = gpio0::Pin3<IOF0<NoInvert>>;
    type MISO = gpio0::Pin4<IOF0<NoInvert>>;
    type SCK = gpio0::Pin5<IOF0<NoInvert>>;
    type SS0 = gpio0::Pin2<IOF0<NoInvert>>;
    type SS1 = gpio0::Pin8<IOF0<NoInvert>>;
    type SS2 = gpio0::Pin9<IOF0<NoInvert>>;
    type SS3 = gpio0::Pin10<IOF0<NoInvert>>;

    // Emits one `impl Pins<QSPI1>` per listed tuple, all sharing `CS_INDEX = $cs`.
    macro_rules! impl_pins {
        ($cs:expr => $($pins:ty),+ $(,)?) => {
            $(
                impl Pins<QSPI1> for $pins {
                    const CS_INDEX: Option<u32> = $cs;
                }
            )+
        };
    }

    // Full-duplex, TX-only and RX-only variants without a hardware CS.
    impl_pins!(None => (MOSI, MISO, SCK), (MOSI, (), SCK), ((), MISO, SCK));
    // The same three variants paired with each available CS pin.
    impl_pins!(Some(0) => (MOSI, MISO, SCK, SS0), (MOSI, (), SCK, SS0), ((), MISO, SCK, SS0));
    impl_pins!(Some(1) => (MOSI, MISO, SCK, SS1), (MOSI, (), SCK, SS1), ((), MISO, SCK, SS1));
    impl_pins!(Some(2) => (MOSI, MISO, SCK, SS2), (MOSI, (), SCK, SS2), ((), MISO, SCK, SS2));
    impl_pins!(Some(3) => (MOSI, MISO, SCK, SS3), (MOSI, (), SCK, SS3), ((), MISO, SCK, SS3));
}
mod spi2_impl {
    // `Pins<QSPI2>` implementations: GPIO0 pins 27/28/29 carry MOSI/MISO/SCK
    // and pin 26 can act as chip-select line 0 (all in IOF0 mode).
    use super::{Pins, QSPI2};
    use crate::gpio::gpio0;
    use crate::gpio::{NoInvert, IOF0};

    type MOSI = gpio0::Pin27<IOF0<NoInvert>>;
    type MISO = gpio0::Pin28<IOF0<NoInvert>>;
    type SCK = gpio0::Pin29<IOF0<NoInvert>>;
    type SS0 = gpio0::Pin26<IOF0<NoInvert>>;

    // Emits one `impl Pins<QSPI2>` per listed tuple, all sharing `CS_INDEX = $cs`.
    macro_rules! impl_pins {
        ($cs:expr => $($pins:ty),+ $(,)?) => {
            $(
                impl Pins<QSPI2> for $pins {
                    const CS_INDEX: Option<u32> = $cs;
                }
            )+
        };
    }

    // Full-duplex, TX-only and RX-only variants, first without and then with CS0.
    impl_pins!(None => (MOSI, MISO, SCK), (MOSI, (), SCK), ((), MISO, SCK));
    impl_pins!(Some(0) => (MOSI, MISO, SCK, SS0), (MOSI, (), SCK, SS0), ((), MISO, SCK, SS0));
}
/// Marker trait for the QSPI peripheral instances this driver supports.
///
/// Every instance derefs to the same `qspi0::RegisterBlock`, which is what
/// lets [`Spi`] be written once, generically over the peripheral.
#[doc(hidden)]
pub trait SpiX: Deref<Target = qspi0::RegisterBlock> {}
impl SpiX for QSPI0 {}
impl SpiX for QSPI1 {}
impl SpiX for QSPI2 {}
/// SPI driver that owns a QSPI peripheral together with the pins it uses.
///
/// Construct with [`Spi::new`]; get the parts back with [`Spi::free`].
pub struct Spi<SPI, PINS> {
    /// Peripheral whose register block the driver programs.
    spi: SPI,
    /// Pin set owned by the driver; returned unchanged by `free`.
    pins: PINS,
}
impl<SPI: SpiX, PINS> Spi<SPI, PINS> {
    /// Configures `spi` and bundles it with `pins`.
    ///
    /// * `mode` – clock phase/polarity, written to the `mode` register.
    /// * `freq` – requested serial clock. The divisor written to `div` is
    ///   `tlclk / (2 * freq) - 1`, saturated at 0 when `freq` is higher than the
    ///   bus clock can produce, and clamped to the 12-bit field maximum when
    ///   `freq` is very low.
    /// * `clocks` – frozen clock configuration supplying `tlclk`.
    ///
    /// If the pin set provides a hardware chip-select (`PINS::CS_INDEX` is
    /// `Some`), that index is written to `csid` and CS mode 2 is selected;
    /// otherwise CS mode 3 is selected. Frames are 8 bits with `protocol = 0`,
    /// the TX watermark is set to 1 and the RX watermark to 0.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `freq` is 0 Hz.
    pub fn new(spi: SPI, pins: PINS, mode: Mode, freq: Hertz, clocks: Clocks) -> Self
    where
        PINS: Pins<SPI>
    {
        // NOTE(review): the divisor field is taken to be 12 bits wide
        // (sckdiv.div[11:0] in the FE310 manual) — confirm for other parts.
        const DIV_MAX: u32 = 0xFFF;
        // `saturating_sub` avoids an underflow when freq > tlclk / 2; the clamp
        // avoids spilling into reserved register bits when freq is tiny.
        let div = (clocks.tlclk().0 / (2 * freq.0))
            .saturating_sub(1)
            .min(DIV_MAX);
        // SAFETY: `div` has just been clamped to the width of the divisor field.
        spi.div.write(|w| unsafe { w.bits(div) });

        // NOTE(review): per the FE310 manual csmode 2 = HOLD and 3 = OFF — confirm.
        let cs_mode = if let Some(cs_index) = PINS::CS_INDEX {
            spi.csid.write(|w| unsafe { w.bits(cs_index) });
            2
        } else {
            3
        };
        spi.csmode.write(|w| unsafe { w.bits(cs_mode) });

        spi.csdef.reset();

        let phase = mode.phase == Phase::CaptureOnSecondTransition;
        let polarity = mode.polarity == Polarity::IdleHigh;
        spi.mode.write(|w| w
            .phase().bit(phase)
            .polarity().bit(polarity)
        );
        // NOTE(review): protocol 0 / endian clear are presumably single-lane,
        // MSB-first — confirm against the FE310 manual.
        spi.fmt.write(|w| unsafe { w
            .protocol().bits(0)
            .endian().clear_bit()
            .direction().rx()
            .length().bits(8)
        });

        spi.txmark.write(|w| unsafe { w.value().bits(1) });
        spi.rxmark.write(|w| unsafe { w.value().bits(0) });
        spi.delay0.reset();
        spi.delay1.reset();
        Self { spi, pins }
    }

    /// Sets the TX FIFO watermark (`txmark`).
    pub fn set_tx_watermark(&mut self, value: u8) {
        self.spi.txmark.write(|w| unsafe { w.value().bits(value) });
    }

    /// Sets the RX FIFO watermark (`rxmark`).
    pub fn set_rx_watermark(&mut self, value: u8) {
        self.spi.rxmark.write(|w| unsafe { w.value().bits(value) });
    }

    /// Returns whether the TX watermark interrupt is pending (`ip.txwm`).
    pub fn tx_wm_is_pending(&self) -> bool {
        self.spi.ip.read().txwm().bit()
    }

    /// Returns whether the RX watermark interrupt is pending (`ip.rxwm`).
    pub fn rx_wm_is_pending(&self) -> bool {
        self.spi.ip.read().rxwm().bit()
    }

    /// Enables the TX watermark interrupt, leaving the RX enable untouched.
    pub fn listen_tx_wm(&mut self) {
        // `modify` (read-modify-write) rather than `write`, which would reset
        // every other bit of `ie` and silently disable the RX interrupt.
        self.spi.ie.modify(|_, w| w.txwm().set_bit())
    }

    /// Enables the RX watermark interrupt, leaving the TX enable untouched.
    pub fn listen_rx_wm(&mut self) {
        self.spi.ie.modify(|_, w| w.rxwm().set_bit())
    }

    /// Disables the TX watermark interrupt, leaving the RX enable untouched.
    pub fn unlisten_tx_wm(&mut self) {
        self.spi.ie.modify(|_, w| w.txwm().clear_bit())
    }

    /// Disables the RX watermark interrupt, leaving the TX enable untouched.
    pub fn unlisten_rx_wm(&mut self) {
        self.spi.ie.modify(|_, w| w.rxwm().clear_bit())
    }

    /// Switches chip-select to mode 0 unless hardware CS is disabled (mode 3).
    // NOTE(review): mode 0 is AUTO (CS toggles per word) per the FE310 manual — confirm.
    pub fn cs_mode_word(&mut self) {
        if self.spi.csmode.read().bits() != 3 {
            self.spi.csmode.write(|w| unsafe { w.bits(0) });
        }
    }

    /// Switches chip-select to mode 2 unless hardware CS is disabled (mode 3).
    // NOTE(review): mode 2 is HOLD (CS stays asserted) per the FE310 manual — confirm.
    pub fn cs_mode_frame(&mut self) {
        if self.spi.csmode.read().bits() != 3 {
            self.spi.csmode.write(|w| unsafe { w.bits(2) });
        }
    }

    /// Ends a multi-word transfer by returning chip-select to per-word mode.
    pub fn end_transfer(&mut self) {
        self.cs_mode_word()
    }

    /// Releases the peripheral and the pins.
    pub fn free(self) -> (SPI, PINS) {
        (self.spi, self.pins)
    }
}
impl<SPI: SpiX, PINS> embedded_hal::spi::FullDuplex<u8> for Spi<SPI, PINS> {
    type Error = Infallible;
    fn read(&mut self) -> nb::Result<u8, Infallible> {
        let rxdata = self.spi.rxdata.read();
        if rxdata.empty().bit_is_set() {
            Err(nb::Error::WouldBlock)
        } else {
            Ok(rxdata.data().bits())
        }
    }
    fn send(&mut self, byte: u8) -> nb::Result<(), Infallible> {
        let txdata = self.spi.txdata.read();
        if txdata.full().bit_is_set() {
            Err(nb::Error::WouldBlock)
        } else {
            self.spi.txdata.write(|w| unsafe { w.data().bits(byte) });
            Ok(())
        }
    }
}
impl<SPI: SpiX, PINS> embedded_hal::blocking::spi::Transfer<u8> for Spi<SPI, PINS> {
    type Error = Infallible;

    /// Full-duplex transfer: sends every byte of `words` and overwrites each
    /// slot in place with the byte received for it.
    ///
    /// The TX/RX watermarks are temporarily forced to 1/0 and restored on exit.
    /// Bytes already sitting in the RX FIFO are discarded first so they are not
    /// mistaken for replies. Chip-select is held (frame mode) for the duration.
    fn transfer<'w>(&mut self, words: &'w mut [u8]) -> Result<&'w [u8], Self::Error> {
        // Save the caller's watermarks so they can be restored afterwards.
        let txmark = self.spi.txmark.read().value().bits();
        let rxmark = self.spi.rxmark.read().value().bits();

        // With rxmark = 0, `ip.rxwm` is polled below as "RX FIFO non-empty"
        // (NOTE(review): rxwm pending when entries > rxmark per FE310 manual — confirm).
        self.spi.txmark.write(|w| unsafe { w.value().bits(1) });
        self.spi.rxmark.write(|w| unsafe { w.value().bits(0) });

        // Drain stale RX data.
        while self.spi.ip.read().rxwm().bit_is_set() {
            let _ = self.spi.rxdata.read();
        }
        self.cs_mode_frame();

        // `iread <= iwrite` always holds, so the transfer is finished exactly
        // when every byte has been read back.
        let mut iwrite = 0;
        let mut iread = 0;
        while iread < words.len() {
            if iwrite < words.len() && self.spi.txdata.read().full().bit_is_clear() {
                let byte = words[iwrite];
                iwrite += 1;
                // SAFETY: the data field accepts any 8-bit value.
                self.spi.txdata.write(|w| unsafe { w.data().bits(byte) });
            }
            if iread < iwrite && self.spi.ip.read().rxwm().bit_is_set() {
                words[iread] = self.spi.rxdata.read().data().bits();
                iread += 1;
            }
        }
        self.cs_mode_word();

        // Restore the caller's watermarks.
        self.spi.txmark.write(|w| unsafe { w.value().bits(txmark) });
        self.spi.rxmark.write(|w| unsafe { w.value().bits(rxmark) });
        Ok(words)
    }
}
impl<SPI: SpiX, PINS> embedded_hal::blocking::spi::Write<u8> for Spi<SPI, PINS> {
    type Error = Infallible;

    /// Sends every byte of `words`, discarding the bytes received meanwhile.
    ///
    /// The TX/RX watermarks are temporarily forced to 1/0 and restored on exit.
    /// Stale RX FIFO contents are flushed first, and chip-select is held
    /// (frame mode) until every sent byte's echo has been drained.
    fn write(&mut self, words: &[u8]) -> Result<(), Self::Error> {
        // Save the caller's watermarks so they can be restored afterwards.
        let txmark = self.spi.txmark.read().value().bits();
        let rxmark = self.spi.rxmark.read().value().bits();

        self.spi.txmark.write(|w| unsafe { w.value().bits(1) });
        self.spi.rxmark.write(|w| unsafe { w.value().bits(0) });

        // Drain stale RX data.
        while self.spi.ip.read().rxwm().bit_is_set() {
            let _ = self.spi.rxdata.read();
        }
        self.cs_mode_frame();

        // `iread <= iwrite` always holds, so the transfer is finished exactly
        // when every byte's echo has been drained.
        let mut iwrite = 0;
        let mut iread = 0;
        while iread < words.len() {
            if iwrite < words.len() && self.spi.txdata.read().full().bit_is_clear() {
                let byte = words[iwrite];
                iwrite += 1;
                // SAFETY: the data field accepts any 8-bit value.
                self.spi.txdata.write(|w| unsafe { w.data().bits(byte) });
            }
            if iread < iwrite && self.spi.ip.read().rxwm().bit_is_set() {
                let _ = self.spi.rxdata.read();
                iread += 1;
            }
        }
        self.cs_mode_word();

        // Restore the caller's watermarks.
        self.spi.txmark.write(|w| unsafe { w.value().bits(txmark) });
        self.spi.rxmark.write(|w| unsafe { w.value().bits(rxmark) });
        Ok(())
    }
}
impl<SPI: SpiX, PINS> embedded_hal::blocking::spi::WriteIter<u8> for Spi<SPI, PINS> {
    type Error = Infallible;

    /// Streams every byte yielded by `words` out of the peripheral, discarding
    /// the bytes clocked in during the transfer.
    ///
    /// Watermarks are forced to TX=1 / RX=0 for the duration and restored on
    /// exit; stale RX FIFO contents are flushed before the first byte is sent.
    fn write_iter<WI>(&mut self, words: WI) -> Result<(), Self::Error>
    where
        WI: IntoIterator<Item = u8>
    {
        let mut source = words.into_iter();

        // Remember the caller's watermark configuration.
        let saved_txmark = self.spi.txmark.read().value().bits();
        let saved_rxmark = self.spi.rxmark.read().value().bits();

        self.spi.txmark.write(|w| unsafe { w.value().bits(1) });
        self.spi.rxmark.write(|w| unsafe { w.value().bits(0) });

        // Flush anything left in the RX FIFO from earlier traffic.
        while self.spi.ip.read().rxwm().bit_is_set() {
            let _ = self.spi.rxdata.read();
        }
        self.cs_mode_frame();

        // Bytes already queued for transmit whose echo is still to be drained.
        let mut in_flight: usize = 0;
        let mut exhausted = false;
        while !exhausted || in_flight != 0 {
            if !exhausted && self.spi.txdata.read().full().bit_is_clear() {
                match source.next() {
                    Some(byte) => {
                        self.spi.txdata.write(|w| unsafe { w.data().bits(byte) });
                        in_flight += 1;
                    }
                    None => exhausted = true,
                }
            }
            if in_flight != 0 && self.spi.ip.read().rxwm().bit_is_set() {
                let _ = self.spi.rxdata.read();
                in_flight -= 1;
            }
        }
        self.cs_mode_word();

        // Put the caller's watermarks back.
        self.spi.txmark.write(|w| unsafe { w.value().bits(saved_txmark) });
        self.spi.rxmark.write(|w| unsafe { w.value().bits(saved_rxmark) });
        Ok(())
    }
}
impl<PINS> Spi<QSPI0, PINS> {
    
    #[deprecated(note = "Please use Spi::new function instead")]
    pub fn spi0(spi: QSPI0, pins: PINS, mode: Mode, freq: Hertz, clocks: Clocks) -> Self
    where
        PINS: Pins<QSPI0>
    {
        Self::new(spi, pins, mode, freq, clocks)
    }
}
impl<PINS> Spi<QSPI1, PINS> {
    
    #[deprecated(note = "Please use Spi::new function instead")]
    pub fn spi1(spi: QSPI1, pins: PINS, mode: Mode, freq: Hertz, clocks: Clocks) -> Self
        where
            PINS: Pins<QSPI1>
    {
        Self::new(spi, pins, mode, freq, clocks)
    }
}
impl<PINS> Spi<QSPI2, PINS> {
    
    #[deprecated(note = "Please use Spi::new function instead")]
    pub fn spi2(spi: QSPI2, pins: PINS, mode: Mode, freq: Hertz, clocks: Clocks) -> Self
        where
            PINS: Pins<QSPI2>
    {
        Self::new(spi, pins, mode, freq, clocks)
    }
}