Documentation
use alloc::boxed::Box;
use pin_project::pin_project;

use super::ring::Ring;
use crate::sync::{Closure, Discriminant, Flags, Signal, Waiter, Waiters, WaitersExt};
use core::{
    pin::{Pin, pin},
    sync::atomic::Ordering::*,
    task::{Context, Poll},
};

/// Heap storage shared by the two [`Endpoint`]s returned from [`duplex`]:
/// one [`Half`] per direction of traffic.
///
/// Field `.0` carries `A` values (sent by the first endpoint), field `.1`
/// carries `B` values (sent by the second endpoint).
pub struct Duplex<A, B>(Half<A>, Half<B>);

/// One direction of a duplex channel: a ring buffer created with a fixed
/// capacity, plus the queues of tasks parked on it.
pub struct Half<T> {
    /// Buffered values; `push` can reject a value, which surfaces as `Full`.
    ring: Ring<T>,
    /// Tasks parked because the ring was full when they tried to send.
    senders: Waiters,
    /// Tasks parked because the ring was empty when they tried to receive.
    receivers: Waiters,
    /// Set to `Closure::Closed` by `Half::close` (either endpoint's `close`
    /// closes both halves).
    is_closed: Flags<Closure>,
}

/// One side of a duplex channel: sends `S` values to the peer and receives
/// `R` values from it.
///
/// Both pointers refer into the `Duplex` allocated by [`duplex`].
/// NOTE(review): nothing in this module frees that allocation (it is created
/// with `Box::into_raw`), which is what keeps these pointers valid — confirm
/// no `Drop` impl elsewhere relies on a different scheme.
pub struct Endpoint<S, R> {
    /// Half this endpoint pushes into (the peer's `recv_half`).
    send_half: *const Half<S>,
    /// Half this endpoint pops from (the peer's `send_half`).
    recv_half: *const Half<R>,
}

// SAFETY (review): these impls let an `Endpoint` be moved to, and shared
// between, threads. Values only cross threads by value, hence the `Send`
// bounds. Sharing `&Endpoint` is only sound if `Ring`, `Waiters` and `Flags`
// tolerate concurrent access through `&Half`: two threads calling `try_send`
// on the same `&Endpoint` act as two producers on one ring.
// NOTE(review): confirm `Ring` supports that (i.e. is not single-producer /
// single-consumer only).
unsafe impl<S: Send, R: Send> Send for Endpoint<S, R> {}
unsafe impl<S: Send, R: Send> Sync for Endpoint<S, R> {}

impl<T> Half<T> {
    /// Creates an empty half backed by `Ring::new(capacity)`.
    ///
    /// `is_closed` starts as `Flags::default()` (presumably "not closed").
    fn new(capacity: usize) -> Self {
        Self {
            ring: Ring::new(capacity),
            senders: Waiters::new(),
            receivers: Waiters::new(),
            is_closed: Flags::default(),
        }
    }

    /// Pushes `value` without waiting; on success wakes one parked receiver.
    ///
    /// # Errors
    /// Hands `value` back in `TrySendError::Disconnected` if the half has been
    /// closed, or in `TrySendError::Full` if the ring rejects it.
    ///
    /// # Safety
    /// NOTE(review): no contract is documented. Presumably callers must respect
    /// whatever producer-side rules `Ring::push` has — confirm and state them.
    unsafe fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        if self.is_closed.is_set(&Closure::Closed) {
            return Err(TrySendError::Disconnected(value));
        }

        match self.ring.push(value) {
            Ok(()) => {
                if let Some(waiter) = self.receivers.dequeue() {
                    waiter.wake_by_ref();
                }
                Ok(())
            }
            Err(value) => Err(TrySendError::Full(value)),
        }
    }

    /// Pops one value without waiting; on success wakes one parked sender.
    ///
    /// # Errors
    /// `TryRecvError::Empty` if nothing is buffered and the half is open;
    /// `TryRecvError::Disconnected` if nothing is buffered and it is closed.
    ///
    /// # Safety
    /// NOTE(review): no contract is documented. Presumably mirrors the
    /// consumer-side rules of `Ring::pop` — confirm and state them.
    unsafe fn try_recv(&self) -> Result<T, TryRecvError> {
        // SAFETY: forwards this function's own contract.
        if let Some(value) = unsafe { self.pop_and_wake_sender() } {
            return Ok(value);
        }

        if !self.is_closed.is_set(&Closure::Closed) {
            return Err(TryRecvError::Empty);
        }

        // The half is closed, but a sender may have pushed and then closed
        // right after the pop above came back empty. Pop once more so that
        // value is delivered instead of being reported as `Disconnected` and
        // stranded in the ring. (Relies on `Flags`/`Ring` using at least
        // acquire/release ordering — NOTE(review): confirm.)
        // SAFETY: forwards this function's own contract.
        unsafe { self.pop_and_wake_sender() }.ok_or(TryRecvError::Disconnected)
    }

    /// Pops one value and, if there was one, wakes one parked sender because a
    /// slot has just been freed.
    ///
    /// # Safety
    /// Same contract as `try_recv`.
    unsafe fn pop_and_wake_sender(&self) -> Option<T> {
        let value = self.ring.pop()?;
        if let Some(waiter) = self.senders.dequeue() {
            waiter.wake_by_ref();
        }
        Some(value)
    }

    /// Marks the half closed, then wakes every parked receiver and sender so
    /// that their next attempt observes the closure.
    ///
    /// # Safety
    /// NOTE(review): no contract is documented — confirm.
    unsafe fn close(&self) {
        self.is_closed.set(&Closure::Closed);
        self.receivers.notify_all();
        self.senders.notify_all();
    }
}

impl<S: Send + Unpin, R: Send> Endpoint<S, R> {
    /// Offers `value` to the peer without waiting.
    ///
    /// # Errors
    /// Returns the value inside [`TrySendError::Full`] when the outgoing buffer
    /// has no room, or inside [`TrySendError::Disconnected`] once the channel
    /// has been closed.
    pub fn try_send(&self, value: S) -> Result<(), TrySendError<S>> {
        // SAFETY: `send_half` points into the `Duplex` allocated by `duplex()`,
        // which nothing in this module frees.
        let outgoing = unsafe { &*self.send_half };
        unsafe { outgoing.try_send(value) }
    }

    /// Takes the next value from the peer without waiting.
    ///
    /// # Errors
    /// [`TryRecvError::Empty`] when nothing is buffered yet, or
    /// [`TryRecvError::Disconnected`] when nothing is buffered and the channel
    /// has been closed.
    pub fn try_recv(&self) -> Result<R, TryRecvError> {
        // SAFETY: `recv_half` points into the `Duplex` allocated by `duplex()`,
        // which nothing in this module frees.
        let incoming = unsafe { &*self.recv_half };
        unsafe { incoming.try_recv() }
    }

    /// Sends `value`, waiting while the outgoing buffer is full.
    ///
    /// Resolves to `Err(value)` if the channel is closed.
    pub async fn send(&self, value: S) -> Result<(), S> {
        let pending = SendFuture {
            half: self.send_half,
            value: Some(value),
            waiter: None,
        };
        pending.await
    }

    /// Receives the next value, waiting while nothing is buffered.
    ///
    /// Resolves to `None` once the channel is closed and drained.
    pub async fn recv(&self) -> Option<R> {
        let pending = RecvFuture {
            half: self.recv_half,
        };
        pending.await
    }

    /// Closes both directions of the channel, waking every parked task on
    /// either half.
    pub fn close(&self) {
        let (outgoing, incoming) = (self.send_half, self.recv_half);
        // SAFETY: both pointers refer into the `Duplex` allocated by
        // `duplex()`, which nothing in this module frees.
        unsafe {
            (*outgoing).close();
            (*incoming).close();
        }
    }
}

/// Future that pushes one value into a `Half`, parking a waiter in
/// `half.senders` while the ring is full.
///
/// Created by `Endpoint::send` and `Half::send`.
#[pin_project]
struct SendFuture<T> {
    /// Half the value is being pushed into.
    half: *const Half<T>,
    /// The value still to be sent; `poll` takes it out on each attempt and
    /// panics if it is already gone.
    #[pin]
    value: Option<T>,
    /// Slot for a waiter node. NOTE(review): waiters are handed to
    /// `half.senders` by value, so this slot does not appear to be `Some`
    /// between polls — confirm whether it is still needed.
    #[pin]
    waiter: Option<Waiter>,
}

impl<T: Send> Half<T> {
    /// Builds a `SendFuture` that pushes `value` into this half when awaited.
    ///
    /// The returned future implements `Future` only when `T: Unpin` as well.
    /// NOTE(review): the future stores a raw pointer to `self`; the caller must
    /// keep this `Half` alive until the future completes.
    pub fn send(&self, value: T) -> SendFuture<T> {
        let half: *const Self = self;
        SendFuture {
            waiter: None,
            value: Some(value),
            half,
        }
    }
}
impl<T: Send + Unpin> Future for SendFuture<T> {
    type Output = Result<(), T>;

    /// Attempts to push the pending value into the target half.
    ///
    /// Resolves to `Ok(())` once the ring accepts the value, or to `Err(value)`
    /// (returning ownership) once the half is closed. While the ring is full it
    /// parks a waiter carrying `cx`'s waker in `half.senders`, keeps the value
    /// for the next attempt, and returns `Poll::Pending`.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // `T: Unpin`, so moving the value out of its pinned slot is allowed.
        let Some(value) = this.value.as_mut().get_mut().take() else {
            panic!("SendFuture polled after completion");
        };

        // SAFETY: `half` is always built from a reference (in `duplex()` or
        // `Half::send`), so it is aligned and non-null. NOTE(review): nothing
        // ties this future's lifetime to that `Half`; callers must keep it alive.
        let half = unsafe { (*this.half).as_ref() }.expect("SendFuture: null half pointer");

        match unsafe { half.try_send(value) } {
            Ok(()) => Poll::Ready(Ok(())),
            Err(TrySendError::Disconnected(value)) => Poll::Ready(Err(value)),
            Err(TrySendError::Full(value)) => {
                // Park until `Half::try_recv` frees a slot (it wakes one sender)
                // or `Half::close` wakes every sender. A fresh waiter is handed
                // over by value, exactly as `RecvFuture` does.
                //
                // NOTE(review): a receiver that pops between the failed push
                // and this enqueue finds no sender to wake, and each pending
                // poll enqueues another waiter; fixing both needs a `Waiters`
                // API that can re-check or remove an entry.
                half.senders.enqueue(Waiter::from_waker(cx.waker().clone()));

                // Store the value back in the future's own pinned field so the
                // next poll retries it instead of panicking.
                this.value.set(Some(value));
                Poll::Pending
            }
        }
    }
}

/// Future returned by `Endpoint::recv`: pops the next value from `half`,
/// parking a waiter in `half.receivers` while nothing is buffered.
struct RecvFuture<T> {
    /// Half values are popped from.
    half: *const Half<T>,
}

impl<T: Send> Future for RecvFuture<T> {
    type Output = Option<T>;

    /// Pops the next value from the half.
    ///
    /// Resolves to `Some(value)` when one is buffered and to `None` once the
    /// half is closed and drained. Otherwise it parks a waiter carrying `cx`'s
    /// waker in `half.receivers` and returns `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `half` points into the `Duplex` allocated by `duplex()`,
        // which nothing in this module frees.
        let half = unsafe { &*self.half };
        let outcome = unsafe { half.try_recv() };

        if matches!(outcome, Err(TryRecvError::Empty)) {
            // NOTE(review): a sender that pushes (or a `close`) between the
            // empty `try_recv` above and this enqueue finds no receiver to
            // wake, and each pending poll enqueues another waiter.
            half.receivers
                .enqueue(Waiter::from_waker(cx.waker().clone()));
            return Poll::Pending;
        }

        // Remaining cases: `Ok(v)` -> `Some(v)`, `Disconnected` -> `None`.
        Poll::Ready(outcome.ok())
    }
}

/// Error returned by a non-blocking send. Both variants hand the rejected
/// value back to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// The outgoing ring had no room for the value.
    Full(T),
    /// The channel has been closed.
    Disconnected(T),
}

impl<T> TrySendError<T> {
    /// Recovers the value that could not be sent, whichever the reason.
    pub fn into_inner(self) -> T {
        match self {
            Self::Full(value) | Self::Disconnected(value) => value,
        }
    }
}

impl<T> core::fmt::Display for TrySendError<T> {
    // Deliberately omits the payload so `T` needs no `Display`/`Debug` bound.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Self::Full(_) => "sending on a full channel",
            Self::Disconnected(_) => "sending on a closed channel",
        })
    }
}

impl<T: core::fmt::Debug> core::error::Error for TrySendError<T> {}

/// Error returned by a non-blocking receive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TryRecvError {
    /// Nothing is buffered, but the channel is still open.
    Empty,
    /// Nothing is buffered and the channel has been closed.
    Disconnected,
}

impl core::fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Self::Empty => "receiving on an empty channel",
            Self::Disconnected => "receiving on a closed channel",
        })
    }
}

impl core::error::Error for TryRecvError {}

pub fn duplex<A: Send, B: Send>(capacity: usize) -> (Endpoint<A, B>, Endpoint<B, A>) {
    let duplex = Box::new(Duplex(Half::new(capacity), Half::new(capacity)));

    let ptr = Box::into_raw(duplex);

    let a = Endpoint {
        send_half: unsafe { &(*ptr).0 },
        recv_half: unsafe { &(*ptr).1 },
    };

    let b = Endpoint {
        send_half: unsafe { &(*ptr).1 },
        recv_half: unsafe { &(*ptr).0 },
    };

    (a, b)
}