use std::{cell::UnsafeCell, future::Future, sync::atomic::Ordering};
use atomic_enum::atomic_enum;
use tokio::sync::{
oneshot::{self, Receiver},
Mutex,
};
#[derive(Debug, thiserror::Error)]
pub enum PreloaderError {
#[error("Preloader is not loaded")]
NotLoaded,
#[error("Preloader is loading")]
Loading,
}
type Result<T> = std::result::Result<T, PreloaderError>;
/// Lifecycle of a [`Preloader`]. The `#[atomic_enum]` attribute provides the
/// `AtomicPreloaderState` wrapper that the preloader stores this state in.
#[atomic_enum]
enum PreloaderState {
    /// No load has been requested yet.
    Idle,
    /// `load` has claimed the preloader but has not yet stored the receiver in `handle`.
    Start,
    /// The loader task has been spawned and its receiver was stored in `handle`.
    Loading,
    /// The value has been received and stored in `value`.
    Loaded,
}
/// A value that is computed by a background task and can later be awaited or polled.
pub struct Preloader<T>
where
    T: Send + 'static,
{
    /// Current lifecycle stage, readable without taking any lock.
    state: AtomicPreloaderState,
    /// Receiving end of the channel the loader task sends its result on.
    handle: Mutex<Option<Receiver<T>>>,
    /// The loaded value, once it has been received.
    value: UnsafeCell<Option<T>>,
}
// SAFETY: a `Preloader<T>` owns its `T` (inside the cell or in flight through the channel), so
// moving it to another thread only moves owned `T` values, which `T: Send` permits.
unsafe impl<T: Send + 'static> Send for Preloader<T> {}
// SAFETY: through `&Preloader<T>` any thread can obtain `&T`, which is only sound for `T: Sync`.
// The value is also produced on another thread (the loader task), which needs `T: Send`.
// The implementation must write the `UnsafeCell` only while holding the `handle` mutex, and
// must read it only under that mutex or after observing `Loaded`.
unsafe impl<T: Send + Sync + 'static> Sync for Preloader<T> {}
impl<T: Send + 'static> Preloader<T> {
pub fn new() -> Self {
Self {
state: AtomicPreloaderState::new(PreloaderState::Idle),
handle: Mutex::new(None),
value: UnsafeCell::new(None),
}
}
pub async fn load(&self, future: impl Future<Output = T> + Send + 'static) {
let Ok(PreloaderState::Idle) = self.state.compare_exchange(
PreloaderState::Idle,
PreloaderState::Start,
Ordering::Relaxed,
Ordering::Relaxed,
) else {
return;
};
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
let value = future.await;
_ = tx.send(value);
});
self.set_handle(rx).await;
}
pub async fn get(&self) -> Result<&T> {
match self.state.load(Ordering::Relaxed) {
PreloaderState::Idle | PreloaderState::Start => {
return Err(PreloaderError::NotLoaded);
}
PreloaderState::Loading => {
let mut handle = self.handle.lock().await;
if let Some(handle) = handle.take() {
let value = handle.await.map_err(|_| PreloaderError::Loading)?;
self.set_value(value);
return Ok(self.get_value());
} else {
return Ok(self.get_value());
}
}
PreloaderState::Loaded => {
return Ok(self.get_value());
}
}
}
pub async fn take(self) -> Result<T> {
match self.get().await {
Ok(_) => self.take_value(),
Err(e) => Err(e),
}
}
pub unsafe fn get_unchecked(&self) -> &T {
match self.state.load(Ordering::Relaxed) {
PreloaderState::Idle | PreloaderState::Start => {
panic!("Preloader is not loaded");
}
PreloaderState::Loading => {
return self.get_value();
}
PreloaderState::Loaded => {
return self.get_value();
}
}
}
pub fn try_get(&self) -> Result<&T> {
match self.state.load(Ordering::Relaxed) {
PreloaderState::Idle | PreloaderState::Start => {
return Err(PreloaderError::NotLoaded);
}
PreloaderState::Loading => {
let mut handle = self
.handle
.try_lock()
.map_err(|_| PreloaderError::Loading)?;
if let Some(handle) = handle.as_mut() {
let value = handle.try_recv().map_err(|_| PreloaderError::Loading)?;
self.set_value(value);
}
return Ok(self.get_value());
}
PreloaderState::Loaded => {
return Ok(self.get_value());
}
}
}
pub unsafe fn try_get_unchecked(&self) -> &T {
match self.state.load(Ordering::Relaxed) {
PreloaderState::Idle | PreloaderState::Start => {
panic!("Preloader is not loaded");
}
PreloaderState::Loading => {
panic!("Preloader is loading");
}
PreloaderState::Loaded => self.get_value(),
}
}
#[inline]
async fn set_handle(&self, handle: Receiver<T>) {
*self.handle.lock().await = Some(handle);
self.state.store(PreloaderState::Loading, Ordering::Release);
}
#[inline]
fn get_value(&self) -> &T {
unsafe { &*self.value.get() }.as_ref().unwrap()
}
#[inline]
fn set_value(&self, value: T) {
unsafe { *self.value.get() = Some(value) };
if let Ok(mut handle) = self.handle.try_lock() {
*handle = None;
}
self.state.store(PreloaderState::Loaded, Ordering::Release);
}
#[inline]
fn take_value(self) -> Result<T> {
unsafe {
let value = (*self.value.get()).take();
if let Some(value) = value {
Ok(value)
} else {
Err(PreloaderError::NotLoaded)
}
}
}
pub fn is_loaded(&self) -> bool {
self.try_get().is_ok()
}
}