#![cfg_attr(target_family = "wasm", allow(dead_code))]
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use anyhow::bail;
use fedimint_core::time::now;
use fedimint_logging::LOG_TASK;
#[cfg(target_family = "wasm")]
use futures::channel::oneshot;
use futures::lock::Mutex;
pub use imp::*;
use thiserror::Error;
#[cfg(not(target_family = "wasm"))]
use tokio::sync::oneshot;
use tokio::sync::watch;
#[cfg(not(target_family = "wasm"))]
use tokio::task::{JoinError, JoinHandle};
use tracing::{error, info, warn};
#[cfg(target_family = "wasm")]
type JoinHandle<T> = futures::future::Ready<anyhow::Result<T>>;
#[cfg(target_family = "wasm")]
type JoinError = anyhow::Error;
/// Error returned by `timeout` when the wrapped future did not complete before
/// the given duration ran out.
#[derive(Debug, Error)]
#[error("deadline has elapsed")]
pub struct Elapsed;
/// State shared by a [`TaskGroup`], its clones and every [`TaskHandle`] made
/// from it.
#[derive(Debug)]
struct TaskGroupInner {
    /// Sending half of the shutdown flag; `true` is sent when shutdown is
    /// requested.
    on_shutdown_tx: watch::Sender<bool>,
    /// Receiving half of the shutdown flag. Keeping it here guarantees that
    /// sending never fails; it is cloned for every shutdown token.
    on_shutdown_rx: watch::Receiver<bool>,
    /// Names and join handles of the spawned tasks, in spawn order.
    join: Mutex<VecDeque<(String, JoinHandle<()>)>>,
    /// Groups created through `make_subgroup`; they are shut down and joined
    /// together with this group.
    subgroups: std::sync::Mutex<Vec<TaskGroup>>,
}
impl Default for TaskGroupInner {
    fn default() -> Self {
        let (on_shutdown_tx, on_shutdown_rx) = watch::channel(false);
        Self {
            on_shutdown_tx,
            on_shutdown_rx,
            join: Mutex::new(Default::default()),
            subgroups: std::sync::Mutex::new(vec![]),
        }
    }
}
impl TaskGroupInner {
    pub fn shutdown(&self) {
        self.on_shutdown_tx
            .send(true)
            .expect("We must have on_shutdown_rx around so this never fails");
        let subgroups = self.subgroups.lock().expect("locking failed").clone();
        for subgroup in subgroups {
            subgroup.inner.shutdown();
        }
    }
}
/// A group of spawned tasks that can be asked to shut down and be joined
/// together.
///
/// Cloning is cheap: every clone refers to the same shared state.
#[derive(Clone, Default, Debug)]
pub struct TaskGroup {
    /// Shared state of the group.
    inner: Arc<TaskGroupInner>,
}
impl TaskGroup {
    /// Creates an empty task group: no tasks, no subgroups, shutdown not yet
    /// requested.
    pub fn new() -> Self {
        Self::default()
    }
    pub fn make_handle(&self) -> TaskHandle {
        TaskHandle {
            inner: self.inner.clone(),
        }
    }
    /// Creates a fresh group and registers it as a subgroup of this one.
    ///
    /// Shutting down or joining this group also shuts down or joins the
    /// returned subgroup.
    ///
    /// # Panics
    ///
    /// Panics if the subgroup list lock is poisoned.
    pub async fn make_subgroup(&self) -> TaskGroup {
        let child = Self::new();
        {
            let mut registered = self.inner.subgroups.lock().expect("locking failed");
            registered.push(child.clone());
        }
        child
    }
    /// Requests shutdown of this group and all of its subgroups.
    ///
    /// This only raises the shutdown flag; it does not wait for tasks to end.
    pub fn shutdown(&self) {
        self.inner.shutdown()
    }
    /// Requests shutdown and then waits for all tasks, as `join_all` does.
    ///
    /// * `join_timeout` - overall time budget for joining; `None` waits without
    ///   a limit.
    ///
    /// # Errors
    ///
    /// Returns an error if at least one task's join handle reported an error.
    pub async fn shutdown_join_all(
        self,
        join_timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error> {
        self.shutdown();
        self.join_all(join_timeout).await
    }
    /// Spawns a background task that waits for a termination signal and then
    /// calls [`TaskGroup::shutdown`] on this group.
    ///
    /// The task waits for Ctrl+C and, on unix, for the terminate signal. It is
    /// started through the free `spawn` function and its join handle is
    /// discarded, so it is not among the tasks awaited by `join_all`.
    ///
    /// # Panics
    ///
    /// The spawned task panics if a signal handler cannot be installed.
    #[cfg(not(target_family = "wasm"))]
    pub fn install_kill_handler(&self) {
        use tokio::signal;
        // Resolves as soon as either of the watched signals arrives.
        async fn wait_for_shutdown_signal() {
            let ctrl_c = async {
                signal::ctrl_c()
                    .await
                    .expect("failed to install Ctrl+C handler");
            };
            #[cfg(unix)]
            let terminate = async {
                signal::unix::signal(signal::unix::SignalKind::terminate())
                    .expect("failed to install signal handler")
                    .recv()
                    .await;
            };
            // No terminate signal off unix: use a future that never resolves so
            // only Ctrl+C can complete the select below.
            #[cfg(not(unix))]
            let terminate = std::future::pending::<()>();
            tokio::select! {
                _ = ctrl_c => {},
                _ = terminate => {},
            }
        }
        spawn("kill handlers", {
            // The task needs its own handle to the group's shared state.
            let task_group = self.clone();
            async move {
                wait_for_shutdown_signal().await;
                info!(
                    target: LOG_TASK,
                    "signal received, starting graceful shutdown"
                );
                task_group.shutdown();
            }
        });
    }
    /// Spawns `f` as a named task belonging to this group.
    ///
    /// `f` receives a [`TaskHandle`] through which the task can observe the
    /// group's shutdown flag. The task's join handle is queued so that
    /// `join_all` waits for it.
    ///
    /// * `name` - task name, used for the runtime task name and in log output.
    /// * `f` - closure producing the future to run.
    ///
    /// Returns a receiver that yields the value the task's future resolved to.
    /// Dropping the receiver is fine; the result is then discarded.
    #[cfg(not(target_family = "wasm"))]
    pub async fn spawn<Fut, R>(
        &mut self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + Send + 'static,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + Send + 'static,
        R: Send + 'static,
    {
        use tracing::debug;

        let name = name.into();
        // The guard lives in this function, not in the spawned task: it shuts
        // the group down only if this call ends before reaching
        // `guard.completed = true` below.
        let mut guard = TaskPanicGuard {
            name: name.clone(),
            inner: self.inner.clone(),
            completed: false,
        };
        let handle = self.make_handle();
        let (tx, rx) = oneshot::channel();
        if let Some(handle) = self::imp::spawn(name.as_str(), {
            let name = name.clone();
            async move {
                // Use the task log target, like every other log line in this file.
                debug!(target: LOG_TASK, "Starting task {name}");
                let r = f(handle).await;
                debug!(target: LOG_TASK, "Finished task {name}");
                // The caller may have dropped the receiver; that is not an error.
                let _ = tx.send(r);
            }
        }) {
            self.inner.join.lock().await.push_back((name, handle));
        }
        guard.completed = true;
        rx
    }
    #[cfg(not(target_family = "wasm"))]
    pub async fn spawn_local<Fut>(
        &mut self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + 'static,
    ) where
        Fut: Future<Output = ()> + 'static,
    {
        let name = name.into();
        let mut guard = TaskPanicGuard {
            name: name.clone(),
            inner: self.inner.clone(),
            completed: false,
        };
        let handle = self.make_handle();
        if let Some(handle) = self::imp::spawn_local(name.as_str(), async move {
            f(handle).await;
        }) {
            self.inner.join.lock().await.push_back((name, handle));
        }
        guard.completed = true;
    }
    /// Spawns `f` as a named task belonging to this group (wasm flavour).
    ///
    /// No `Send` bounds are required here. `f` receives a [`TaskHandle`] for
    /// observing shutdown. If the runtime hands back a join handle it is queued
    /// for `join_all`.
    ///
    /// Returns a receiver that yields the value the task's future resolved to;
    /// if the receiver was dropped the result is discarded.
    #[cfg(target_family = "wasm")]
    pub async fn spawn<Fut, R>(
        &mut self,
        name: impl Into<String>,
        f: impl FnOnce(TaskHandle) -> Fut + 'static,
    ) -> oneshot::Receiver<R>
    where
        Fut: Future<Output = R> + 'static,
        R: 'static,
    {
        let task_name: String = name.into();
        // Shuts the group down if this call ends before the guard is disarmed.
        let mut panic_guard = TaskPanicGuard {
            name: task_name.clone(),
            inner: Arc::clone(&self.inner),
            completed: false,
        };
        let task_handle = self.make_handle();
        let (result_tx, result_rx) = oneshot::channel();
        let spawned = self::imp::spawn(&task_name, async move {
            let output = f(task_handle).await;
            let _ = result_tx.send(output);
        });
        if let Some(join_handle) = spawned {
            let mut queue = self.inner.join.lock().await;
            queue.push_back((task_name, join_handle));
        }
        panic_guard.completed = true;
        result_rx
    }
    /// Waits for all tasks of this group and its subgroups to finish.
    ///
    /// * `timeout` - overall time budget, turned into a single deadline shared
    ///   by every join; `None` waits without a limit.
    ///
    /// # Errors
    ///
    /// Returns an error listing the collected join errors if at least one task
    /// did not finish cleanly. Tasks that merely timed out are not counted.
    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
        let deadline = timeout.map(|budget| now() + budget);
        let mut errors = Vec::new();
        self.join_all_inner(deadline, &mut errors).await;

        if !errors.is_empty() {
            let num_errors = errors.len();
            bail!("{num_errors} tasks did not finish cleanly: {errors:?}");
        }
        Ok(())
    }
    /// Waits for every subgroup, then for every task of this group, to finish.
    ///
    /// Subgroups are joined first (recursively), then this group's own tasks in
    /// the order they were spawned. A task whose join handle reports an error is
    /// logged and the error is appended to `errors`. A task that does not finish
    /// in time is logged and skipped.
    ///
    /// * `deadline` - optional point in time shared by all joins. Each task gets
    ///   whatever remains until the deadline; once it has passed, every
    ///   remaining task still gets a 10ms grace period. `None` waits without a
    ///   limit.
    /// * `errors` - receives the join errors of tasks that did not finish
    ///   cleanly.
    ///
    /// # Panics
    ///
    /// Panics if the subgroup list lock is poisoned.
    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
        let subgroups = self.inner.subgroups.lock().expect("locking failed").clone();
        for subgroup in subgroups {
            info!(target: LOG_TASK, "Waiting for subgroup to finish");
            subgroup.join_all_inner(deadline, errors).await;
            info!(target: LOG_TASK, "Subgroup finished");
        }

        loop {
            // Pop in a statement of its own so the queue lock is released before
            // the task is awaited. A `while let` on the locked queue would keep
            // the guard alive for the whole loop body and block any task that
            // spawns into this group while it is being joined.
            let next = self.inner.join.lock().await.pop_front();
            let (name, join) = match next {
                Some(entry) => entry,
                None => break,
            };

            info!(target: LOG_TASK, task=%name, "Waiting for task to finish");
            // Time left until the deadline; if it already passed, still allow a
            // short grace period instead of giving up immediately.
            let timeout = deadline.map(|deadline| {
                deadline
                    .duration_since(now())
                    .unwrap_or(Duration::from_millis(10))
            });
            // `self::timeout` is the module-level function; the local `timeout`
            // above shadows its name.
            #[cfg(not(target_family = "wasm"))]
            let join_future: Pin<Box<dyn Future<Output = _> + Send>> =
                if let Some(timeout) = timeout {
                    Box::pin(self::timeout(timeout, join))
                } else {
                    Box::pin(async move { Ok(join.await) })
                };
            #[cfg(target_family = "wasm")]
            let join_future: Pin<Box<dyn Future<Output = _>>> = if let Some(timeout) = timeout {
                Box::pin(self::timeout(timeout, join))
            } else {
                Box::pin(async move { Ok(join.await) })
            };
            match join_future.await {
                Ok(Ok(())) => {
                    info!(target: LOG_TASK, task=%name, "Task finished");
                }
                Ok(Err(e)) => {
                    error!(target: LOG_TASK, task=%name, error=%e, "Task panicked");
                    errors.push(e);
                }
                Err(Elapsed) => {
                    warn!(
                        target: LOG_TASK, task=%name,
                        "Timeout waiting for task to shut down"
                    )
                }
            }
        }
    }
}
/// Drop guard that shuts down its task group unless it was marked as completed
/// before being dropped.
pub struct TaskPanicGuard {
    /// Task name, used in the log message emitted on an unclean drop.
    name: String,
    /// Shared state of the group to shut down.
    inner: Arc<TaskGroupInner>,
    /// Set to `true` to disarm the guard.
    completed: bool,
}
impl TaskPanicGuard {
    /// Returns `true` once shutdown has been requested for the guarded group.
    pub fn is_shutting_down(&self) -> bool {
        // Reads the current value of the shutdown flag from the sending half.
        *self.inner.on_shutdown_tx.borrow()
    }
}
impl Drop for TaskPanicGuard {
    /// Shuts the task group down if the guard was never marked as completed.
    fn drop(&mut self) {
        if self.completed {
            // Disarmed: the guarded section ran to its end.
            return;
        }
        info!(
            target: LOG_TASK,
            "Task {} shut down uncleanly. Shutting down task group.", self.name
        );
        self.inner.shutdown();
    }
}
/// Handle given to a spawned task so it can observe its group's shutdown flag.
#[derive(Clone, Debug)]
pub struct TaskHandle {
    /// Shared state of the group the task belongs to.
    inner: Arc<TaskGroupInner>,
}
impl TaskHandle {
    /// Returns `true` once shutdown has been requested for the task group.
    pub fn is_shutting_down(&self) -> bool {
        let shutdown_requested = self.inner.on_shutdown_tx.borrow();
        *shutdown_requested
    }

    /// Creates a [`TaskShutdownToken`]: a future that resolves once shutdown has
    /// been requested for the task group.
    pub async fn make_shutdown_rx(&self) -> TaskShutdownToken {
        let shutdown_rx = self.inner.on_shutdown_rx.clone();
        TaskShutdownToken::new(shutdown_rx)
    }
}
/// Future that resolves once the shutdown flag it was created from is `true`.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
impl TaskShutdownToken {
    /// Wraps `rx` in a boxed future that completes when the watched flag is
    /// `true`. An error from waiting on the channel is ignored, so the future
    /// completes in that case as well.
    fn new(mut rx: watch::Receiver<bool>) -> Self {
        let wait_for_shutdown = async move {
            let _ = rx.wait_for(|shutting_down| *shutting_down).await;
        };
        Self(Box::pin(wait_for_shutdown))
    }
}
impl Future for TaskShutdownToken {
    type Output = ();

    /// Polls the boxed inner future.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // The only field is a pinned box, so the token itself is `Unpin` and a
        // plain mutable reference to it can be taken.
        let token = self.get_mut();
        token.0.as_mut().poll(cx)
    }
}
/// Runtime primitives for non-wasm targets, implemented on top of tokio.
#[cfg(not(target_family = "wasm"))]
mod imp {
    pub use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
    use super::*;

    /// Spawns `future` as a named tokio task and returns its join handle.
    ///
    /// # Panics
    ///
    /// Panics if the task cannot be spawned.
    pub fn spawn<F, V: Send + 'static>(name: &str, future: F) -> Option<JoinHandle<V>>
    where
        F: Future<Output = V> + Send + 'static,
    {
        let join_handle = tokio::task::Builder::new()
            .name(name)
            .spawn(future)
            .expect("spawn failed");
        Some(join_handle)
    }

    /// Spawns `future` as a named thread-local tokio task; the future does not
    /// have to be `Send`.
    ///
    /// # Panics
    ///
    /// Panics if the task cannot be spawned.
    pub(crate) fn spawn_local<F>(name: &str, future: F) -> Option<JoinHandle<()>>
    where
        F: Future<Output = ()> + 'static,
    {
        let join_handle = tokio::task::Builder::new()
            .name(name)
            .spawn_local(future)
            .expect("spawn failed");
        Some(join_handle)
    }

    /// Runs the blocking closure `f` via `tokio::task::block_in_place` and
    /// returns its result.
    pub fn block_in_place<F, R>(f: F) -> R
    where
        F: FnOnce() -> R,
    {
        tokio::task::block_in_place(f)
    }

    /// Suspends the current task for `duration`.
    pub async fn sleep(duration: Duration) {
        tokio::time::sleep(duration).await;
    }

    /// Suspends the current task until `deadline`.
    pub async fn sleep_until(deadline: Instant) {
        let tokio_deadline: tokio::time::Instant = deadline.into();
        tokio::time::sleep_until(tokio_deadline).await;
    }

    /// Awaits `future` for at most `duration`.
    ///
    /// # Errors
    ///
    /// Returns [`Elapsed`] if `future` did not complete in time.
    pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
    where
        T: Future,
    {
        let outcome = tokio::time::timeout(duration, future).await;
        outcome.map_err(|_too_late| Elapsed)
    }
}
/// Runtime primitives for wasm targets, mirroring the non-wasm `imp` module.
#[cfg(target_family = "wasm")]
mod imp {
    pub use async_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
    use futures::FutureExt;
    use super::*;
    /// Spawns `future` on the local wasm executor.
    ///
    /// Always returns `None`: no join handle is available here, so callers have
    /// nothing to queue for joining. The name is unused.
    pub fn spawn<F>(_name: &str, future: F) -> Option<JoinHandle<()>>
    where
        F: Future<Output = ()> + 'static,
    {
        wasm_bindgen_futures::spawn_local(future);
        None
    }
    /// Same as [`spawn`]; on wasm there is no separate local spawn.
    pub(crate) fn spawn_local<F>(_name: &str, future: F) -> Option<JoinHandle<()>>
    where
        F: Future<Output = ()> + 'static,
    {
        self::spawn(_name, future)
    }
    /// Runs `f` directly on the current thread and returns its result.
    pub fn block_in_place<F, R>(f: F) -> R
    where
        F: FnOnce() -> R,
    {
        f()
    }
    /// Suspends the current task for `duration`, capped at `i32::MAX`
    /// milliseconds.
    pub async fn sleep(duration: Duration) {
        // NOTE(review): the cap presumably matches the range the underlying
        // timer accepts — confirm against the gloo-timers documentation.
        gloo_timers::future::sleep(duration.min(Duration::from_millis(i32::MAX as _))).await
    }
    /// Suspends the current task until `deadline`; returns without a real wait
    /// if the deadline already passed (the sleep duration saturates to zero).
    pub async fn sleep_until(deadline: Instant) {
        // NOTE(review): `std::time::Instant::now()` may not be supported on every
        // wasm target — confirm it works on the targets this is built for.
        sleep(deadline.saturating_duration_since(Instant::now())).await
    }
    /// Awaits `future` for at most `duration`.
    ///
    /// Returns `Err(Elapsed)` if the sleep finishes first. The select is biased,
    /// so when both are ready the future's value wins.
    pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
    where
        T: Future,
    {
        futures::pin_mut!(future);
        futures::select_biased! {
            value = future.fuse() => Ok(value),
            _ = sleep(duration).fuse() => Err(Elapsed),
        }
    }
}
/// Applies `async_trait` to the wrapped item: with its default form on non-wasm
/// targets and as `async_trait(?Send)` on wasm targets.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($tt:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($tt)*
    };
}
/// Expands to the given tokens followed by `+ Send` (non-wasm variant).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)* + Send
    };
}
/// Expands to the given tokens unchanged (wasm variant: no `Send` bound added).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
/// Expands to the given tokens followed by `+ Send + Sync` (non-wasm variant).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($tt:tt)*) => {
        $($tt)* + Send + Sync
    };
}
/// Expands to the given tokens unchanged (wasm variant: no `Send + Sync` bounds
/// added).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
/// Marker trait that requires `Send` on non-wasm targets and nothing on wasm.
///
/// It is implemented for every type that meets the requirement of the current
/// target, so it can be used as a bound in code shared by both.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}
/// Marker trait that requires `Send` on non-wasm targets and nothing on wasm.
///
/// It is implemented for every type that meets the requirement of the current
/// target, so it can be used as a bound in code shared by both.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}
// Blanket impl: every `Send` type is `MaybeSend` on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
impl<T: Send> MaybeSend for T {}
// Blanket impl: every type is `MaybeSend` on wasm targets.
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
/// Marker trait that requires `Sync` on non-wasm targets and nothing on wasm.
///
/// It is implemented for every type that meets the requirement of the current
/// target, so it can be used as a bound in code shared by both.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}
/// Marker trait that requires `Sync` on non-wasm targets and nothing on wasm.
///
/// It is implemented for every type that meets the requirement of the current
/// target, so it can be used as a bound in code shared by both.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}
// Blanket impl: every `Sync` type is `MaybeSync` on non-wasm targets.
#[cfg(not(target_family = "wasm"))]
impl<T: Sync> MaybeSync for T {}
// Blanket impl: every type is `MaybeSync` on wasm targets.
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
// Each test spawns one task that waits on its shutdown token and checks that
// `shutdown_join_all` without a timeout returns cleanly.
#[cfg(test)]
mod tests {
    use super::*;
    // Shutdown is requested after the task has had 10ms to start waiting on its
    // shutdown token.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let mut tg = TaskGroup::new();
        tg.spawn("shutdown waiter", |handle| async move {
            handle.make_shutdown_rx().await.await
        })
        .await;
        sleep(Duration::from_millis(10)).await;
        tg.shutdown_join_all(None).await?;
        Ok(())
    }
    // Shutdown is requested right away; the task only creates its shutdown token
    // after sleeping 10ms, so the token must resolve even though the flag was
    // raised before it existed.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let mut tg = TaskGroup::new();
        tg.spawn("shutdown waiter", |handle| async move {
            sleep(Duration::from_millis(10)).await;
            handle.make_shutdown_rx().await.await
        })
        .await;
        tg.shutdown_join_all(None).await?;
        Ok(())
    }
    // Same as `shutdown_task_group_after`, but the task lives in a subgroup and
    // shutdown is requested on the parent group.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.make_subgroup()
            .await
            .spawn("shutdown waiter", |handle| async move {
                handle.make_shutdown_rx().await.await
            })
            .await;
        sleep(Duration::from_millis(10)).await;
        tg.shutdown_join_all(None).await?;
        Ok(())
    }
    // Same as `shutdown_task_group_before`, but the task lives in a subgroup and
    // shutdown is requested on the parent group.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.make_subgroup()
            .await
            .spawn("shutdown waiter", |handle| async move {
                sleep(Duration::from_millis(10)).await;
                handle.make_shutdown_rx().await.await
            })
            .await;
        tg.shutdown_join_all(None).await?;
        Ok(())
    }
}