//! Async mutex: a contended [`Mutex::lock`] future registers a [`Waiter`] on the mutex's wait queue.
use core::pin::pin;
use core::sync::atomic::Ordering::*;
use core::{
    cell::UnsafeCell,
    marker::PhantomData,
    mem,
    ops::{Deref, DerefMut},
    pin::Pin,
    ptr,
    sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize},
    task::{Poll, Waker},
};

use alloc::boxed::Box;

use super::q::Queue;
use crate::Guard;
use crate::sync::{Waiter, Waiters};

/// Mutual-exclusion lock around a value of type `T`.
///
/// A contended acquisition through [`Mutex::lock`] registers a [`Waiter`] on
/// this mutex's wait queue and returns `Poll::Pending`.
pub struct Mutex<T> {
    /// Lock flag, switched from `false` to `true` by a successful acquisition.
    /// NOTE(review): the release path is not visible in this file — confirm
    /// where the flag is cleared.
    locked: AtomicBool,
    /// Queue of waiters registered by contended lock futures.
    waiters: Waiters,
    /// The protected value, accessed through `Guard`'s `Deref`/`DerefMut`.
    value: UnsafeCell<T>,
}

// SAFETY: sending the mutex to another thread sends the `T` it owns, so `T`
// must be `Send`. NOTE(review): this also asserts that `Waiters` may be sent
// across threads — confirm against its definition.
unsafe impl<T> Send for Mutex<T> where T: Send {}

// SAFETY: the value is reached only through a `Guard`, which is handed out
// after the `locked` flag is won. Assuming the flag is cleared only when that
// guard is released, one thread at a time touches `T`. As with
// `std::sync::Mutex`, that requires only `T: Send`. NOTE(review): this also
// asserts that `Waiters` is safe to share between threads — confirm.
unsafe impl<T> Sync for Mutex<T> where T: Send {}
impl<'a, T> Deref for Guard<'a, Mutex<T>> {
    type Target = T;

    /// Returns a shared reference to the value protected by the mutex.
    fn deref(&self) -> &Self::Target {
        // SAFETY: `UnsafeCell::get` never returns null and the cell lives as
        // long as the borrowed mutex. Exclusivity relies on every `Guard` over
        // a `Mutex` being created only after winning the `locked`
        // compare-exchange, as `Mutex::try_lock` does.
        unsafe { &*self.0.value.get() }
    }
}

impl<'a, T> DerefMut for Guard<'a, Mutex<T>> {
    /// Returns a mutable reference to the value protected by the mutex.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: `UnsafeCell::get` never returns null. Uniqueness of the
        // returned `&mut T` relies on this `Guard` being the only one alive for
        // the mutex (it was created after winning the `locked`
        // compare-exchange) and on `&mut self` excluding other borrows of it.
        unsafe { &mut *self.0.value.get() }
    }
}

impl<T> Mutex<T> {
    /// Creates an unlocked mutex that owns `value`, with a fresh wait queue.
    pub const fn new(value: T) -> Self {
        Self {
            locked: AtomicBool::new(false),
            waiters: Queue::new(),
            value: UnsafeCell::new(value),
        }
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// Returns `Some(guard)` if this call switched the `locked` flag from
    /// `false` to `true`. Returns `None` if the mutex was already locked.
    pub fn try_lock(&self) -> Option<Guard<'_, Self>> {
        // `Acquire` on success presumably pairs with a `Release` store on
        // unlock (not visible here), so the previous holder's writes are seen.
        self.locked
            .compare_exchange(false, true, Acquire, Relaxed)
            .ok()
            .map(|_| Guard(self))
    }

    /// Returns a future that resolves to a [`Guard`] once the lock is held.
    ///
    /// If the lock is contended when the future is polled, `waiter` is
    /// registered on this mutex's wait queue. The inner future is pinned on
    /// the heap with `Box::pin`.
    pub fn lock(
        &self,
        waiter: Waiter,
    ) -> Lock<'_, Self, impl Future<Output = Guard<'_, Self>> + Send>
    where
        T: Send,
    {
        Lock(Box::pin(Key {
            // `self` is already a `&Mutex<T>`; no extra borrow is needed.
            lock: self,
            node: waiter.into(),
            queued: false,
            marker: PhantomData,
        }))
    }

    /// Adds `waiter` to this mutex's wait queue.
    pub fn enqueue(&self, waiter: Waiter) {
        // NOTE(review): no SAFETY justification is visible for this call. This
        // safe method forwards to the `unsafe` `enqueue` of the wait queue, so
        // that contract must hold for every caller — confirm it against
        // `super::q::Queue` and document it here.
        unsafe {
            self.waiters.enqueue(waiter);
        }
    }
}

/// Future that acquires a [`Mutex`], registering a [`Waiter`] on the mutex's
/// wait queue when the lock is contended.
pub struct Key<'a, T> {
    /// The mutex being acquired.
    lock: &'a Mutex<T>,
    /// Waiter supplied to `Mutex::lock`. It is moved out (`Option::take`) on a
    /// contended poll, then either handed to the wait queue or dropped.
    node: Option<Waiter>,
    /// Whether the waiter has already been handed to `lock`'s wait queue.
    queued: bool,
    /// NOTE(review): presumably here to make `Key` behave like a `&'a mut T`
    /// borrow (invariant over `T`) — confirm intent.
    marker: PhantomData<&'a mut T>,
}

impl<'a, T> Future for Key<'a, T> {
    type Output = Guard<'a, Mutex<T>>;

    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let Self {
            lock, node, queued, ..
        } = &mut *self;
        if lock.try_lock().is_some() {
            return Poll::Ready(Guard(lock));
        }

        let node = node.take().expect("lock() requires a waiter");

        node.assign_waker(cx.waker().clone());

        if node.signaled() {
            loop {
                let locked = lock.locked.load(Relaxed);
                if locked {
                    break;
                }
                let Some(poll) = lock
                    .locked
                    .compare_exchange_weak(false, true, Acquire, Relaxed)
                    .map(|_| Guard(*lock))
                    .map(Poll::Ready)
                    .ok()
                else {
                    continue;
                };
                return poll;
            }
        }

        if !*queued {
            unsafe {
                lock.enqueue(node);
                *queued = true;
            }
        }

        Poll::Pending
    }
}

/// [`Lock`] wrapping the concrete [`Key`] future for a `Mutex<T>`.
pub type LockKey<'a, T> = Lock<'a, Mutex<T>, Key<'a, T>>;

/// Future that resolves to a [`Guard`] by polling a heap-pinned inner future.
pub struct Lock<'a, T: 'a, F>(Pin<Box<F>>)
where
    F: Future<Output = Guard<'a, T>>;

impl<'a, T: 'a, F: Future<Output = Guard<'a, T>>> Future for Lock<'a, T, F> {
    type Output = F::Output;

    /// Forwards the poll to the boxed inner future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
        // `Pin<Box<F>>` is `Unpin`, so the field is reachable through
        // `DerefMut` and `as_mut` yields the `Pin<&mut F>` that `poll` needs.
        // No additional stack pinning is required.
        self.0.as_mut().poll(cx)
    }
}