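//! Async write-once cell: [`OnceLock`] holds a value that is written once and then shared,
//! and its initializer may be an awaited future.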
use core::cell::UnsafeCell;
use core::future::Future;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, Waker};

use alloc::boxed::Box;

use crate::sync::mutual::Mutual;
use crate::sync::{InitState, Waiter, Waiters, WaitersExt};

/// A cell that is written at most once (until [`take`](OnceLock::take) empties it again)
/// and then read through shared references. The value can be supplied directly with
/// [`set`](OnceLock::set) or produced by an async initializer via
/// [`get_or_init`](OnceLock::get_or_init).
pub struct OnceLock<T> {
    /// Tracks whether `value` is uninitialized, currently being initialized by one caller,
    /// or initialized and readable.
    state: Mutual<InitState>,
    /// Storage for the value; it only holds a live `T` once `state` has been set to
    /// `Initialized` (and until `take` moves it out).
    value: UnsafeCell<MaybeUninit<T>>,
    /// Wakers of tasks waiting for an in-flight initialization to finish.
    waiters: Waiters,
}

// SAFETY: `OnceLock<T>` owns its `T`, so sending the cell to another thread sends the `T`.
// NOTE(review): this also asserts that `Mutual<InitState>` and `Waiters` are safe to send;
// confirm against `crate::sync`.
unsafe impl<T: Send> Send for OnceLock<T> {}
// SAFETY: through `&OnceLock<T>` any thread can obtain `&T` (`get`), which needs `T: Sync`,
// and any thread can store a `T` that a different thread later drops or takes (`set`,
// `get_or_init`), which needs `T: Send`. These bounds mirror `std::sync::OnceLock`.
// NOTE(review): this also asserts that `Mutual<InitState>` and `Waiters` are safe to share.
unsafe impl<T: Send + Sync> Sync for OnceLock<T> {}

impl<T> OnceLock<T> {
    pub const fn new() -> Self {
        Self {
            state: Mutual::new(),
            value: UnsafeCell::new(MaybeUninit::uninit()),
            waiters: Waiters::new(),
        }
    }

    pub fn get(&self) -> Option<&T> {
        if self.state.is(&InitState::Initialized) {
            Some(unsafe { (*self.value.get()).assume_init_ref() })
        } else {
            None
        }
    }

    pub fn get_mut(&mut self) -> Option<&mut T> {
        if self.state.is(&InitState::Initialized) {
            Some(unsafe { (*self.value.get()).assume_init_mut() })
        } else {
            None
        }
    }

    pub async fn get_or_init<F, Fut>(&self, f: F) -> &T
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = T>,
    {
        if self.state.is(&InitState::Initialized) {
            return unsafe { (*self.value.get()).assume_init_ref() };
        }

        if self.state.is(&InitState::Uninitialized) {
            match self.state.compare_exchange(
                InitState::Uninitialized,
                InitState::Initializing,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    let value = f().await;
                    unsafe {
                        (*self.value.get()).write(value);
                    }
                    self.state.set(InitState::Initialized);
                    self.waiters.notify_all();
                    return unsafe { (*self.value.get()).assume_init_ref() };
                }
                Err(_) => {}
            }
        }

        Wait { once: self }.await
    }

    pub async fn get_or_try_init<F, Fut, E>(&self, f: F) -> Result<&T, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        if self.state.is(&InitState::Initialized) {
            return Ok(unsafe { (*self.value.get()).assume_init_ref() });
        }

        if self.state.is(&InitState::Uninitialized) {
            match self.state.compare_exchange(
                InitState::Uninitialized,
                InitState::Initializing,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => match f().await {
                    Ok(value) => {
                        unsafe {
                            (*self.value.get()).write(value);
                        }
                        self.state.set(InitState::Initialized);
                        self.waiters.notify_all();
                        return Ok(unsafe { (*self.value.get()).assume_init_ref() });
                    }
                    Err(e) => {
                        self.state.set(InitState::Uninitialized);
                        self.waiters.notify_all();
                        return Err(e);
                    }
                },
                Err(_) => {}
            }
        }

        Ok(Wait { once: self }.await)
    }

    pub fn set(&self, value: T) -> Result<(), T> {
        match self.state.compare_exchange(
            InitState::Uninitialized,
            InitState::Initializing,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                unsafe {
                    (*self.value.get()).write(value);
                }
                self.state.set(InitState::Initialized);
                self.waiters.notify_all();
                Ok(())
            }
            Err(_) => Err(value),
        }
    }

    pub fn take(&mut self) -> Option<T> {
        if self.state.is(&InitState::Initialized) {
            self.state.set(InitState::Uninitialized);
            Some(unsafe { (*self.value.get()).assume_init_read() })
        } else {
            None
        }
    }

    pub fn into_inner(mut self) -> Option<T> {
        self.take()
    }
}

impl<T> Drop for OnceLock<T> {
    /// Drops the stored value, if one was published and not taken back out.
    fn drop(&mut self) {
        let holds_value = self.state.is(&InitState::Initialized);
        if !holds_value {
            return;
        }
        // SAFETY: `Initialized` is only set after the value has been written, and `take`
        // resets the state before moving the value out. So the slot holds a live `T` that
        // nothing else will drop, and `&mut self` rules out concurrent access.
        unsafe { self.value.get_mut().assume_init_drop() }
    }
}

/// Presumably the outcome of an initialization attempt (inferred from the variant names).
///
/// NOTE(review): nothing in this file constructs or matches on `InitResult`. It appears to
/// be unused and may be a candidate for removal.
enum InitResult {
    Initialized,
    Failed,
}

/// Future that resolves to the cell's value once a concurrent initializer publishes it.
///
/// # Panics
///
/// Panics if it observes the cell back in `Uninitialized`, i.e. the in-flight
/// initialization was abandoned instead of completing.
struct Wait<'a, T> {
    once: &'a OnceLock<T>,
}

impl<'a, T> Future for Wait<'a, T> {
    type Output = &'a T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let once = self.once;
        loop {
            if once.state.is(&InitState::Initialized) {
                // SAFETY: `Initialized` is only set after the value has been written, and
                // the value can only be moved out through `&mut OnceLock`.
                return Poll::Ready(unsafe { (*once.value.get()).assume_init_ref() });
            }

            if once.state.is(&InitState::Uninitialized) {
                panic!("AsyncOnceLock initialization failed");
            }

            if once.state.is(&InitState::Initializing) {
                // SAFETY: NOTE(review) `enqueue`'s contract is defined in `crate::sync`.
                unsafe {
                    once.waiters
                        .enqueue(Waiter::from_waker(cx.waker().clone()))
                };

                // Re-check after registering. If the initializer settled (success OR
                // failure) between the check above and the enqueue, its `notify_all` may
                // already have run and this waker would never fire. Loop so the new state
                // is handled by the branches above.
                if once.state.is(&InitState::Initializing) {
                    return Poll::Pending;
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sync_get() {
        let once = OnceLock::<i32>::new();
        assert_eq!(once.get(), None);

        assert_eq!(once.set(42), Ok(()));
        assert_eq!(once.get(), Some(&42));
    }

    #[test]
    fn second_set_returns_rejected_value() {
        let once = OnceLock::<i32>::new();
        assert_eq!(once.set(1), Ok(()));
        assert_eq!(once.set(2), Err(2));
        assert_eq!(once.get(), Some(&1));
    }

    #[test]
    fn get_mut_take_and_into_inner() {
        let mut once = OnceLock::<i32>::new();
        assert_eq!(once.get_mut(), None);
        assert_eq!(once.take(), None);

        assert_eq!(once.set(10), Ok(()));
        if let Some(value) = once.get_mut() {
            *value += 1;
        }
        assert_eq!(once.take(), Some(11));
        assert_eq!(once.get(), None);

        // After `take` the cell accepts a new value.
        assert_eq!(once.set(7), Ok(()));
        assert_eq!(once.into_inner(), Some(7));
    }
}