// Documentation: sending half of the channel (`Sender`) and its asynchronous send future (`SendFut`).
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use crate::error::{SendError, SendTimeoutError, TrySendError};
use futures::task::AtomicWaker;
use std::future::Future;

use crate::{inner::Inner, util::async_send};

// static SEND_ID: AtomicUsize = AtomicUsize::new(0);

/// The sending half of a channel.
///
/// A `Sender` can deliver values synchronously ([`send`](Sender::send),
/// [`send_timeout`](Sender::send_timeout), [`try_send`](Sender::try_send)) and
/// asynchronously ([`send_async`](Sender::send_async)).
/// It can be cloned to create multiple senders for the same channel.
///
/// # Examples
///
/// ```rust
/// use linch::bounded;
///
/// let (sender, receiver) = bounded(5);
///
/// // Send synchronously
/// sender.send(42).unwrap();
///
/// // Clone the sender
/// let sender2 = sender.clone();
/// sender2.send(43).unwrap();
/// ```
pub struct Sender<T> {
    /// Underlying crossbeam sender that actually carries the values.
    pub(crate) tx: crossbeam_channel::Sender<T>,
    /// State shared between the channel's handles: the sender count
    /// (`inc_tx` / `dec_tx`) and the queues of parked wakers (`signal_queues`).
    inner: Arc<Inner>,
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        self.inner.inc_tx();

        Self {
            tx: self.tx.clone(),
            inner: self.inner.clone(),
        }
    }
}

impl<T> Drop for Sender<T> {
    /// Decrements the shared sender count and, when `dec_tx` reports `1`
    /// (this was the last sender), wakes every parked waker so that pending
    /// operations can observe the disconnection.
    fn drop(&mut self) {
        // Other senders are still alive: nothing to signal.
        if self.inner.dec_tx() != 1 {
            return;
        }

        let mut queues = self.inner.signal_queues();

        // Receivers first, then senders.
        while let Some(parked) = queues.pop_recv() {
            parked.as_ref().wake();
        }
        while let Some(parked) = queues.pop_send() {
            parked.as_ref().wake();
        }
    }
}

impl<T> Sender<T> {
    /// Creates a new sender from the underlying crossbeam channel and inner state.
    ///
    /// This is typically not called directly by users, but rather through [`channel`](crate::channel).
    pub(crate) fn new(tx: crossbeam_channel::Sender<T>, inner: Arc<Inner>) -> Self {
        Self { tx, inner }
    }

    /// Sends a value, blocking the current thread until it is accepted.
    ///
    /// The call returns once there is room in the channel buffer, or fails as
    /// soon as all receivers have been dropped. After a successful send one
    /// parked asynchronous receiver (if any) is woken.
    ///
    /// # Arguments
    ///
    /// * `value` - The value to send
    ///
    /// # Returns
    ///
    /// * `Ok(())` if the value was sent successfully
    /// * `Err(SendError(value))` if all receivers have been dropped
    ///
    /// # Examples
    ///
    /// ```rust
    /// use linch::bounded;
    ///
    /// let (sender, receiver) = bounded(1);
    /// sender.send(42).unwrap();
    /// assert_eq!(receiver.recv().unwrap(), 42);
    /// ```
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        match self.tx.send(value) {
            Ok(()) => {
                self.signal_recv();
                Ok(())
            }
            Err(err) => Err(err.into()),
        }
    }

    /// Sends a value, blocking the current thread for at most `timeout`.
    ///
    /// The call returns once there is room in the channel buffer, when the
    /// timeout expires, or when all receivers have been dropped. After a
    /// successful send one parked asynchronous receiver (if any) is woken.
    ///
    /// # Arguments
    ///
    /// * `value` - The value to send
    /// * `timeout` - The maximum duration to wait
    ///
    /// # Returns
    ///
    /// * `Ok(())` if the value was sent successfully
    /// * `Err(SendTimeoutError::Timeout(value))` if the timeout expired
    /// * `Err(SendTimeoutError::Disconnected(value))` if all receivers have been dropped
    ///
    /// # Examples
    ///
    /// ```rust
    /// use linch::bounded;
    /// use std::time::Duration;
    ///
    /// let (sender, _receiver) = bounded(1);
    /// sender.send(1).unwrap(); // Fill the buffer
    ///
    /// // This will timeout since the buffer is full
    /// let result = sender.send_timeout(2, Duration::from_millis(10));
    /// assert!(result.is_err());
    /// ```
    pub fn send_timeout(&self, value: T, timeout: Duration) -> Result<(), SendTimeoutError<T>> {
        match self.tx.send_timeout(value, timeout) {
            Ok(()) => {
                self.signal_recv();
                Ok(())
            }
            Err(err) => Err(err.into()),
        }
    }

    /// Tries to send a value without ever blocking.
    ///
    /// The value is handed over immediately when the channel buffer has room;
    /// otherwise it is returned to the caller inside the error. After a
    /// successful send one parked asynchronous receiver (if any) is woken.
    ///
    /// # Arguments
    ///
    /// * `value` - The value to send
    ///
    /// # Returns
    ///
    /// * `Ok(())` if the value was sent successfully
    /// * `Err(TrySendError::Full(value))` if the channel buffer is full
    /// * `Err(TrySendError::Disconnected(value))` if all receivers have been dropped
    ///
    /// # Examples
    ///
    /// ```rust
    /// use linch::bounded;
    /// use linch::TrySendError;
    ///
    /// let (sender, receiver) = bounded(1);
    ///
    /// // First send succeeds
    /// assert!(sender.try_send(1).is_ok());
    ///
    /// // Second send fails because buffer is full
    /// match sender.try_send(2) {
    ///     Err(TrySendError::Full(value)) => {
    ///         assert_eq!(value, 2);
    ///     }
    ///     _ => panic!("Expected Full error"),
    /// }
    ///
    /// // After receiving, we can send again
    /// receiver.recv().unwrap();
    /// assert!(sender.try_send(2).is_ok());
    /// ```
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        match self.tx.try_send(value) {
            Ok(()) => {
                self.signal_recv();
                Ok(())
            }
            Err(err) => Err(err.into()),
        }
    }

    /// Wakes at most one parked receiver, if one is queued.
    ///
    /// The waker is popped in its own `let` statement so that the temporary
    /// returned by `signal_queues()` is dropped at the end of that statement,
    /// i.e. before `wake()` runs. Writing the pop as a block inside the
    /// `if let` scrutinee does not achieve this on editions before 2024:
    /// there, temporaries of a block tail expression and of an `if let`
    /// scrutinee live until the end of the whole `if let`, so the queue guard
    /// would still be held while waking. Popping first and waking afterwards
    /// also matches how `SendFut::poll` signals receivers.
    #[inline(always)]
    pub(crate) fn signal_recv(&self) {
        let waker = self.inner.signal_queues().pop_recv();
        if let Some(waker) = waker {
            waker.as_ref().wake();
        }
    }

    /// Sends a value asynchronously.
    ///
    /// This method returns a future that will complete when there is space in the
    /// channel buffer or when all receivers have been dropped.
    ///
    /// # Arguments
    ///
    /// * `value` - The value to send
    ///
    /// # Returns
    ///
    /// A [`SendFut`] future that resolves to:
    /// * `Ok(())` if the value was sent successfully
    /// * `Err(SendError(value))` if all receivers have been dropped
    ///
    /// # Examples
    ///
    /// ```rust
    /// use linch::bounded;
    ///
    /// # tokio_test::block_on(async {
    /// let (sender, receiver) = bounded(1);
    /// sender.send_async(42).await.unwrap();
    /// assert_eq!(receiver.recv().unwrap(), 42);
    /// # });
    /// ```
    pub fn send_async(&self, value: T) -> SendFut<'_, T> {
        SendFut {
            tx: &self.tx,
            inner: &self.inner,
            value: Some(value),
            poll_cnt: 0,
            waker: AtomicWaker::new(),
        }
    }
}

/// A future representing an asynchronous send operation.
///
/// This future is created by the [`send_async`](Sender::send_async) method and will
/// complete when the value has been sent or when all receivers have been dropped.
/// It borrows the [`Sender`] it was created from for its whole lifetime.
///
/// # Examples
///
/// ```rust
/// use linch::bounded;
///
/// # tokio_test::block_on(async {
/// let (sender, receiver) = bounded(1);
/// let send_fut = sender.send_async(42);
/// send_fut.await.unwrap();
/// assert_eq!(receiver.recv().unwrap(), 42);
/// # });
/// ```
pub struct SendFut<'a, T> {
    /// Underlying crossbeam sender borrowed from the originating [`Sender`].
    tx: &'a crossbeam_channel::Sender<T>,
    /// Shared channel state, used to reach the queues of parked wakers.
    inner: &'a Arc<Inner>,
    /// The value still waiting to be sent; taken for each send attempt and
    /// put back when the channel reports it is full.
    value: Option<T>,
    /// Number of times `poll` has enqueued this future and returned
    /// `Poll::Pending`; non-zero means the send queue may hold an entry for it.
    poll_cnt: u32,
    /// Holds the task waker registered by `poll`; the address of this field
    /// is the key under which the future is stored in the send queue.
    waker: AtomicWaker,
}

// `SendFut` is declared `Unpin` regardless of `T`, which is what allows `poll`
// to obtain `&mut Self` through `Pin::get_mut`.
// NOTE(review): `poll` stores the address of the `waker` field in the shared
// send queue, and an `Unpin` future may be moved between polls, which changes
// that address — confirm that a stale entry cannot be left behind.
impl<'a, T> Unpin for SendFut<'a, T> {}

impl<'a, T> Future for SendFut<'a, T> {
    type Output = Result<(), SendError<T>>;

    /// Drives the asynchronous send.
    ///
    /// 1. A first attempt is made through `async_send` without touching the
    ///    signal queues.
    /// 2. If the channel is full, the task waker is registered and a second
    ///    attempt is made with `try_send` while the signal queues are held, so
    ///    that the attempt and the enqueueing happen under the same lock.
    /// 3. If the channel is still full, this future is (re-)enqueued in the
    ///    send queue under the address of its waker and `Poll::Pending` is
    ///    returned.
    ///
    /// # Returns
    ///
    /// * `Poll::Ready(Ok(()))` once the value has been sent; one parked
    ///   receiver (if any) is woken.
    /// * `Poll::Ready(Err(SendError(value)))` if the channel is disconnected;
    ///   every parked sender and receiver is woken.
    /// * `Poll::Pending` if the channel is full.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`, because the
    /// value has already been handed off.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Key under which this future is stored in the send queue.
        let waker_key = &this.waker as *const AtomicWaker as usize;

        // Fast path: try to send without taking the signal queues.
        let value = this
            .value
            .take()
            .expect("`SendFut` polled after completion");
        match async_send(this.tx, value) {
            Ok(()) => {
                let mut signal_queues = this.inner.signal_queues();
                // A previous poll may have left us enqueued.
                if this.poll_cnt > 0 {
                    signal_queues.remove_send(waker_key);
                }

                // Release the queues before waking the receiver.
                let recv_waker = signal_queues.pop_recv();
                drop(signal_queues);
                if let Some(waker) = recv_waker {
                    waker.as_ref().wake();
                }

                return Poll::Ready(Ok(()));
            }
            Err(TrySendError::Full(value)) => {
                // Keep the value for the second attempt below.
                this.value = Some(value);
            }
            Err(TrySendError::Disconnected(value)) => {
                let mut signal_queues = this.inner.signal_queues();

                while let Some(waker) = signal_queues.pop_send() {
                    waker.as_ref().wake();
                }

                while let Some(waker) = signal_queues.pop_recv() {
                    waker.as_ref().wake();
                }

                drop(signal_queues);
                return Poll::Ready(Err(SendError(value)));
            }
        }

        this.waker.register(cx.waker());
        let mut signal_queues = this.inner.signal_queues();
        // we failed to send, so we try one last time and enqueue the future. this is done inside the lock to prevent data races.
        // benchmarks show that explicitly dropping the mutex here is faster
        let value = this
            .value
            .take()
            .expect("value is restored after a `Full` send attempt");
        match this.tx.try_send(value) {
            Ok(()) => {
                if this.poll_cnt > 0 {
                    signal_queues.remove_send(waker_key);
                }

                let recv_waker = signal_queues.pop_recv();
                drop(signal_queues);
                if let Some(waker) = recv_waker {
                    waker.as_ref().wake();
                }

                return Poll::Ready(Ok(()));
            }
            Err(crossbeam_channel::TrySendError::Full(value)) => {
                this.value = Some(value);
            }
            Err(crossbeam_channel::TrySendError::Disconnected(value)) => {
                while let Some(waker) = signal_queues.pop_send() {
                    waker.as_ref().wake();
                }

                while let Some(waker) = signal_queues.pop_recv() {
                    waker.as_ref().wake();
                }

                drop(signal_queues);
                return Poll::Ready(Err(SendError(value)));
            }
        }

        // Still full: replace any entry from a previous poll with a fresh one.
        if this.poll_cnt > 0 {
            signal_queues.remove_send(waker_key);
        }
        // Saturating: the counter is only ever compared against zero, and a
        // wrap-around to 0 would make `Drop` skip removing the queue entry.
        this.poll_cnt = this.poll_cnt.saturating_add(1);
        signal_queues.add_send(waker_key);

        // Wake one parked receiver even though nothing was sent by this poll.
        let recv_waker = signal_queues.pop_recv();
        drop(signal_queues);
        if let Some(waker) = recv_waker {
            waker.as_ref().wake();
        }

        Poll::Pending
    }
}

impl<'a, T> Drop for SendFut<'a, T> {
    /// Removes this future's entry from the send queue if `poll` ever
    /// enqueued it, so the queue does not keep the address of a waker that is
    /// about to be destroyed.
    fn drop(&mut self) {
        // Never enqueued: nothing to clean up, and no need to take the queues.
        if self.poll_cnt == 0 {
            return;
        }

        let waker_key = &self.waker as *const AtomicWaker as usize;
        self.inner.signal_queues().remove_send(waker_key);
    }
}