use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use crate::io;
use crate::task::{JoinHandle, Task, TaskLocalsWrapper};
/// Configuration for a task that is about to be created.
///
/// A builder is consumed by one of its terminal methods (`spawn`, `local`,
/// `blocking`), each of which wraps the supplied future via `build`.
#[derive(Debug, Default)]
pub struct Builder {
/// Optional name handed to `Task::new` when the task is built;
/// `None` means the task is unnamed.
pub(crate) name: Option<String>,
}
impl Builder {
/// Creates a builder with no task name configured.
#[inline]
pub fn new() -> Builder {
    // A fresh builder starts out unnamed.
    let name = None;
    Builder { name }
}
/// Sets the name the built task will carry.
///
/// Takes the builder by value and returns it so calls can be chained.
/// Any previously configured name is replaced.
#[inline]
pub fn name(mut self, name: String) -> Builder {
    let label = Some(name);
    self.name = label;
    self
}
/// Wraps `future` together with a freshly created task handle.
///
/// Consumes the builder: the configured name (if any) is moved into an
/// [`Arc`] and passed to `Task::new`. The resulting task is stored in a
/// `TaskLocalsWrapper`, which travels alongside the future inside the
/// returned `SupportTaskLocals`.
fn build<F, T>(self, future: F) -> SupportTaskLocals<F>
where
    F: Future<Output = T>,
{
    let name = self.name.map(Arc::new);
    let task = Task::new(name);

    // Force the lazily-initialised `RUNTIME` static before the wrapped future
    // is handed to an executor. Compiled out when `target_os` is "unknown".
    #[cfg(not(target_os = "unknown"))]
    once_cell::sync::Lazy::force(&crate::rt::RUNTIME);

    // `task` is not used again in this function, so move it into the wrapper
    // instead of cloning it and dropping the original.
    let tag = TaskLocalsWrapper::new(task);

    SupportTaskLocals { tag, future }
}
/// Spawns `future` as a task configured by this builder.
///
/// The future is wrapped by `build`, a `"spawn"` trace event is emitted with
/// the new task's id and the current task's id (`0` when there is none), and
/// the wrapped future is handed to `smol::Task::spawn`.
///
/// # Errors
///
/// The return type is `io::Result`, but this implementation always returns
/// `Ok`.
#[cfg(not(target_os = "unknown"))]
pub fn spawn<F, T>(self, future: F) -> io::Result<JoinHandle<T>>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let tagged = self.build(future);

    kv_log_macro::trace!("spawn", {
        task_id: tagged.tag.id().0,
        parent_task_id: TaskLocalsWrapper::get_current(|current| current.id().0).unwrap_or(0),
    });

    // Copy the task handle out first: `tagged` is consumed by the executor below.
    let handle = tagged.tag.task().clone();
    Ok(JoinHandle::new(smol::Task::spawn(tagged).into(), handle))
}
/// Spawns `future` via `smol::Task::local`, configured by this builder.
///
/// Unlike `spawn`, neither the future nor its output needs to be `Send`.
/// A `"spawn_local"` trace event is emitted with the new task's id and the
/// current task's id (`0` when there is none).
///
/// # Errors
///
/// The return type is `io::Result`, but this implementation always returns
/// `Ok`.
// NOTE(review): this variant is gated on `target_os`, while the wasm variants
// of `local` are gated on `target_arch = "wasm32"`. A wasm32 target whose OS
// is not "unknown" would enable both definitions - confirm such targets are
// unsupported.
#[cfg(all(not(target_os = "unknown"), feature = "unstable"))]
pub fn local<F, T>(self, future: F) -> io::Result<JoinHandle<T>>
where
    F: Future<Output = T> + 'static,
    T: 'static,
{
    let tagged = self.build(future);

    kv_log_macro::trace!("spawn_local", {
        task_id: tagged.tag.id().0,
        parent_task_id: TaskLocalsWrapper::get_current(|current| current.id().0).unwrap_or(0),
    });

    // Copy the task handle out first: `tagged` is consumed by the executor below.
    let handle = tagged.tag.task().clone();
    Ok(JoinHandle::new(smol::Task::local(tagged).into(), handle))
}
/// Spawns `future` on wasm32 through `wasm_bindgen_futures::spawn_local`.
///
/// The future's output is forwarded over a oneshot channel; the receiving
/// half backs the returned `JoinHandle`. A `"spawn_local"` trace event is
/// emitted with the new task's id and the current task's id (`0` when there
/// is none).
///
/// # Errors
///
/// The return type is `io::Result`, but this implementation always returns
/// `Ok`.
#[cfg(all(target_arch = "wasm32", feature = "unstable"))]
pub fn local<F, T>(self, future: F) -> io::Result<JoinHandle<T>>
where
    F: Future<Output = T> + 'static,
    T: 'static,
{
    let (result_tx, result_rx) = futures_channel::oneshot::channel();

    let tagged = self.build(async move {
        let output = future.await;
        // The result of `send` is deliberately ignored.
        let _ = result_tx.send(output);
    });

    kv_log_macro::trace!("spawn_local", {
        task_id: tagged.tag.id().0,
        parent_task_id: TaskLocalsWrapper::get_current(|current| current.id().0).unwrap_or(0),
    });

    // Copy the task handle out first: `tagged` is consumed by `spawn_local`.
    let handle = tagged.tag.task().clone();
    wasm_bindgen_futures::spawn_local(tagged);

    Ok(JoinHandle::new(result_rx, handle))
}
/// Crate-internal wasm32 variant of `local`, used when the `unstable`
/// feature is disabled.
///
/// Spawns `future` through `wasm_bindgen_futures::spawn_local`. The future's
/// output is forwarded over a oneshot channel; the receiving half backs the
/// returned `JoinHandle`. A `"spawn_local"` trace event is emitted with the
/// new task's id and the current task's id (`0` when there is none).
///
/// # Errors
///
/// The return type is `io::Result`, but this implementation always returns
/// `Ok`.
#[cfg(all(target_arch = "wasm32", not(feature = "unstable")))]
pub(crate) fn local<F, T>(self, future: F) -> io::Result<JoinHandle<T>>
where
    F: Future<Output = T> + 'static,
    T: 'static,
{
    let (result_tx, result_rx) = futures_channel::oneshot::channel();

    let tagged = self.build(async move {
        let output = future.await;
        // The result of `send` is deliberately ignored.
        let _ = result_tx.send(output);
    });

    kv_log_macro::trace!("spawn_local", {
        task_id: tagged.tag.id().0,
        parent_task_id: TaskLocalsWrapper::get_current(|current| current.id().0).unwrap_or(0),
    });

    // Copy the task handle out first: `tagged` is consumed by `spawn_local`.
    let handle = tagged.tag.task().clone();
    wasm_bindgen_futures::spawn_local(tagged);

    Ok(JoinHandle::new(result_rx, handle))
}
/// Drives `future` with `smol::run` and returns its output.
///
/// The future is wrapped by `build`, a `"block_on"` trace event is emitted
/// with the new task's id and the current task's id (`0` when there is
/// none), and the wrapped future is run inside
/// `TaskLocalsWrapper::set_current` so that its tag is the current one for
/// the duration of the call.
#[cfg(not(target_os = "unknown"))]
pub fn blocking<F, T>(self, future: F) -> T
where
F: Future<Output = T>,
{
let wrapped = self.build(future);
kv_log_macro::trace!("block_on", {
task_id: wrapped.tag.id().0,
parent_task_id: TaskLocalsWrapper::get_current(|t| t.id().0).unwrap_or(0),
});
// NOTE(review): `&wrapped.tag` is taken in the same call that moves
// `wrapped` into the closure, so `set_current` presumably accepts the tag
// as a raw pointer rather than a borrow. If so, that pointer refers to
// `wrapped`'s location before the move - confirm this against the safety
// contract of `set_current`.
unsafe { TaskLocalsWrapper::set_current(&wrapped.tag, || smol::run(wrapped)) }
}
}
pin_project! {
/// Future wrapper that carries a `TaskLocalsWrapper` alongside the future
/// it wraps.
struct SupportTaskLocals<F> {
// Tag passed to `TaskLocalsWrapper::set_current` around every poll of `future`.
tag: TaskLocalsWrapper,
// The wrapped future; marked `#[pin]` so `project()` yields it pinned.
#[pin]
future: F,
}
}
/// Polls the inner future with this wrapper's tag installed via
/// `TaskLocalsWrapper::set_current`.
impl<F: Future> Future for SupportTaskLocals<F> {
type Output = F::Output;
/// Forwards the poll to the wrapped future from inside
/// `TaskLocalsWrapper::set_current`, passing `self.tag` as the task to
/// make current, and returns whatever the inner poll returns.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// NOTE(review): `&self.tag` is taken in the same call that moves `self`
// into the closure, so `set_current` presumably accepts the tag as a raw
// pointer rather than a borrow - confirm against its safety contract.
unsafe {
TaskLocalsWrapper::set_current(&self.tag, || {
// Project the pinned wrapper to reach the pinned inner future.
let this = self.project();
this.future.poll(cx)
})
}
}
}