use crate::runtime::execution::ExecutionState;
use crate::runtime::task::TaskId;
use crate::runtime::thread;
use std::marker::PhantomData;
use std::time::Duration;
/// Identifier for a thread managed by the execution state; wraps the runtime's task id.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ThreadId {
    task_id: TaskId,
}

impl From<ThreadId> for usize {
    /// Converts a thread id into the numeric form of its underlying task id.
    fn from(ThreadId { task_id }: ThreadId) -> Self {
        task_id.into()
    }
}
/// A handle to a thread running under the execution state.
///
/// Handles are produced by [`current`] and by spawning (see [`JoinHandle::thread`]).
#[derive(Debug, Clone)]
pub struct Thread {
    /// Name given when the thread was spawned, if any.
    name: Option<String>,
    /// Identifier wrapping the underlying task id.
    id: ThreadId,
}

impl Thread {
    /// Returns this thread's name, or `None` if it was spawned without one.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Returns this thread's identifier.
    pub fn id(&self) -> ThreadId {
        self.id
    }

    /// Unparks this thread, then yields to the scheduler.
    ///
    /// Calls `unpark` on the target task's state and then always goes through
    /// `thread::switch`, so every unpark is a scheduling point.
    pub fn unpark(&self) {
        let target = self.id.task_id;
        ExecutionState::with(|state| {
            state.get_mut(target).unpark();
        });
        thread::switch();
    }
}
/// Spawns a new thread that runs `f` and returns a [`JoinHandle`] for it.
///
/// The new thread is unnamed and uses the execution config's default stack size. Spawning is a
/// scheduling point: the caller yields before this function returns.
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    // No explicit name and no explicit stack size.
    let (name, stack_size) = (None, None);
    spawn_named(f, name, stack_size)
}
/// Shared implementation behind [`spawn`] and [`Builder::spawn`].
///
/// Creates a new task via `ExecutionState::spawn_thread` that runs `f` through [`thread_fn`],
/// which writes the return value into a shared result slot. The returned [`JoinHandle`] reads
/// that slot on `join`. Before returning, the caller yields via `thread::switch`.
///
/// `stack_size` falls back to the execution config's `stack_size` when `None`.
fn spawn_named<F, T>(f: F, name: Option<String>, stack_size: Option<usize>) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let stack_size = match stack_size {
        Some(size) => size,
        None => ExecutionState::with(|s| s.config.stack_size),
    };

    // One reference to the result slot goes to the child; the other stays in the handle.
    let result = std::sync::Arc::new(std::sync::Mutex::new(None));
    let child_result = std::sync::Arc::clone(&result);
    let task_id = ExecutionState::spawn_thread(
        move || thread_fn(f, child_result),
        stack_size,
        name.clone(),
        None,
    );

    // Spawning is a scheduling point.
    thread::switch();

    JoinHandle {
        task_id,
        thread: Thread {
            id: ThreadId { task_id },
            name,
        },
        result,
    }
}
/// Body run by every spawned thread.
///
/// The steps run in this order:
/// 1. Execute `f`.
/// 2. Drop all of the current task's thread-local values.
/// 3. Store `Some(Ok(ret))` into `result`.
/// 4. Unblock the task registered as this task's waiter, if any.
///
/// If `f` panics, nothing in this function catches the unwind, so `result` is not written here.
/// NOTE(review): confirm where panics of spawned threads are handled.
pub(crate) fn thread_fn<F, T>(f: F, result: std::sync::Arc<std::sync::Mutex<Option<std::thread::Result<T>>>>)
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let ret = f();
    tracing::trace!("thread finished, dropping thread locals");
    // Each value is popped inside `ExecutionState::with` but dropped outside of it. This is
    // presumably so a destructor may itself call into the execution state (e.g. touch another
    // thread local) without re-entering that borrow. NOTE(review): confirm.
    while let Some(local) = ExecutionState::with(|state| state.current_mut().pop_local()) {
        tracing::trace!("dropping thread local {:p}", local);
        drop(local);
    }
    tracing::trace!("done dropping thread locals");
    // The result is published only after the thread locals are gone. A joiner that reads it
    // therefore never observes the result before this thread's destructors have run.
    *result.lock().unwrap() = Some(Ok(ret));
    // Wake the task (if any) that registered itself as our waiter in `JoinHandle::join`.
    ExecutionState::with(|state| {
        if let Some(waiter) = state.current_mut().take_waiter() {
            state.get_mut(waiter).unblock();
        }
    });
}
/// An owned handle to a spawned thread, used to wait for it and obtain its result.
#[derive(Debug)]
pub struct JoinHandle<T> {
    // Task that runs the spawned closure.
    task_id: TaskId,
    // Handle returned by `thread()`.
    thread: Thread,
    // Written by `thread_fn` once the closure returns; emptied by `join`.
    result: std::sync::Arc<std::sync::Mutex<Option<std::thread::Result<T>>>>,
}
impl<T> JoinHandle<T> {
    /// Waits for the associated thread to finish and returns its result.
    ///
    /// Steps:
    /// 1. Register the current task as the target's waiter, and block if `set_waiter` returns
    ///    `true`.
    /// 2. Yield.
    /// 3. Once resumed, merge the target's clock into the current task's clock.
    /// 4. Take the stored result.
    ///
    /// NOTE(review): the clock is presumably a vector clock recording happens-before from the
    /// target's completion to this join.
    ///
    /// # Panics
    ///
    /// Panics with "target should have finished" if the result slot is empty after the switch.
    /// Also panics if the result mutex is poisoned.
    pub fn join(self) -> std::thread::Result<T> {
        ExecutionState::with(|state| {
            let me = state.current().id();
            let target = state.get_mut(self.task_id);
            // NOTE(review): `true` presumably means the target has not finished yet, so this
            // task must block until `thread_fn` unblocks it; confirm against `set_waiter`.
            if target.set_waiter(me) {
                state.current_mut().block();
            }
        });
        // Always a scheduling point, whether or not this task blocked above.
        thread::switch();
        ExecutionState::with(|state| {
            let target = state.get_mut(self.task_id);
            // Cloned so the mutable borrow of `target` ends before `state` is borrowed again.
            let clock = target.clock.clone();
            state.update_clock(&clock);
        });
        self.result.lock().unwrap().take().expect("target should have finished")
    }
    /// Returns the [`Thread`] handle of the spawned thread.
    pub fn thread(&self) -> &Thread {
        &self.thread
    }
}
/// Cooperatively yields the current thread to the scheduler.
///
/// Steps:
/// 1. Fetch the current task's waker and wake it.
/// 2. Ask the execution state for a yield.
/// 3. Switch.
///
/// NOTE(review): waking our own waker presumably keeps this task runnable across the yield;
/// confirm against the waker implementation.
pub fn yield_now() {
    // The waker is fetched inside `ExecutionState::with` but woken outside of it.
    // `waker` stays alive until the end of the function, i.e. past the switch.
    let waker = ExecutionState::with(|state| state.current().waker());
    waker.wake_by_ref();
    ExecutionState::request_yield();
    thread::switch();
}
pub fn sleep(_dur: Duration) {
thread::switch();
}
/// Returns a [`Thread`] handle for the currently executing thread.
///
/// The handle carries the current task's id and name, read from the execution state.
pub fn current() -> Thread {
    ExecutionState::with(|state| {
        let task = state.current();
        Thread {
            id: ThreadId { task_id: task.id() },
            name: task.name(),
        }
    })
}
/// Parks the current thread.
///
/// Calls `park` on the current task's state and yields via `thread::switch` only when that call
/// returns `true`. NOTE(review): presumably `false` means a prior `unpark` already made the
/// thread runnable, so no context switch is needed.
pub fn park() {
    let must_switch = ExecutionState::with(|s| s.current_mut().park());
    if !must_switch {
        return;
    }
    thread::switch();
}
pub fn park_timeout(_dur: Duration) {
park();
}
/// Builder-style configuration for spawning a thread, using the same method names as
/// `std::thread::Builder`.
///
/// Both settings start out unset.
#[derive(Default, Debug)]
pub struct Builder {
    /// Name for the spawned thread; `None` leaves it unnamed.
    name: Option<String>,
    /// Requested stack size; `None` falls back to the execution config's `stack_size`.
    stack_size: Option<usize>,
}
impl Builder {
    /// Creates a builder with no name and no explicit stack size.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the name that the spawned thread reports through [`Thread::name`].
    pub fn name(self, name: String) -> Self {
        Self {
            name: Some(name),
            ..self
        }
    }

    /// Sets the stack size passed to the runtime when spawning.
    ///
    /// If this is never called, the execution config's `stack_size` is used.
    pub fn stack_size(self, stack_size: usize) -> Self {
        Self {
            stack_size: Some(stack_size),
            ..self
        }
    }

    /// Spawns a thread running `f` with this configuration.
    ///
    /// This always returns `Ok`; the `io::Result` return type matches
    /// `std::thread::Builder::spawn`.
    pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
    where
        F: FnOnce() -> T,
        F: Send + 'static,
        T: Send + 'static,
    {
        let Self { name, stack_size } = self;
        Ok(spawn_named(f, name, stack_size))
    }
}
/// A key for a thread-local value, analogous to `std::thread::LocalKey`.
///
/// The key itself stores no `T`. It holds only the initializer and a type marker. Each thread's
/// value is created on first access and stored in that thread's task state.
pub struct LocalKey<T: 'static> {
    // Initializer run the first time a thread accesses this key. Public but hidden,
    // presumably so a `thread_local!`-style macro can build keys. NOTE(review): confirm.
    #[doc(hidden)]
    pub init: fn() -> T,
    // Zero-sized marker tying the key to its value type.
    #[doc(hidden)]
    pub _p: PhantomData<T>,
}

// SAFETY: a `LocalKey` contains only a `fn() -> T` (always `Send + Sync`) and a zero-sized
// `PhantomData<T>`; no `T` value is stored inside the key. Per-thread values live in each task's
// own state and are reached through `ExecutionState`, so sending or sharing the key never shares
// a `T` between threads. This is what allows keys to sit in `static`s (`with`/`try_with` take
// `&'static self`) regardless of whether `T` is `Send` or `Sync`.
unsafe impl<T> Send for LocalKey<T> {}
unsafe impl<T> Sync for LocalKey<T> {}

impl<T: 'static> std::fmt::Debug for LocalKey<T> {
    /// Formats as `LocalKey { .. }`; the initializer and marker are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("LocalKey");
        dbg.finish_non_exhaustive()
    }
}
impl<T: 'static> LocalKey<T> {
    /// Runs `f` with a reference to the current thread's value for this key.
    ///
    /// The value is initialized on first access.
    ///
    /// # Panics
    ///
    /// Panics if `try_with` returns [`AccessError`], i.e. the value has already been destroyed.
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        self.try_with(f).expect(
            "cannot access a Thread Local Storage value \
            during or after destruction",
        )
    }
    /// Runs `f` with a reference to the current thread's value for this key.
    ///
    /// The value is initialized on first access.
    ///
    /// # Errors
    ///
    /// Returns [`AccessError`] if the current thread's value for this key has already been
    /// destroyed.
    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
    where
        F: FnOnce(&T) -> R,
    {
        // `get` returns `None` when no value exists yet for this thread (NOTE(review): per the
        // contract of `local`; confirm), and `Some(Err(_))` once the value has been destroyed.
        let value = self.get().unwrap_or_else(|| {
            // `init` runs outside `ExecutionState::with`, presumably so the initializer may
            // itself use runtime primitives without re-entering that borrow.
            let value = (self.init)();
            ExecutionState::with(move |state| {
                state.current_mut().init_local(self, value);
            });
            // The value was just stored, so this lookup is expected to succeed.
            self.get().unwrap()
        })?;
        Ok(f(value))
    }
    /// Looks up the current thread's value for this key.
    ///
    /// Possible outcomes:
    /// - `None` if the current task's `local` lookup returns `None`.
    /// - `Some(Err(AccessError))` if `local` returns an error.
    /// - Otherwise, the reference with its lifetime extended to the `'static` lifetime of `self`.
    fn get(&'static self) -> Option<Result<&T, AccessError>> {
        /// Reinterprets a reference with lifetime `'a` as one with lifetime `'b`.
        ///
        /// # Safety
        ///
        /// The caller must guarantee that the referent stays valid for all of `'b`: alive, not
        /// moved, and not mutably aliased.
        unsafe fn extend_lt<'a, 'b, T>(t: &'a T) -> &'b T {
            std::mem::transmute(t)
        }
        ExecutionState::with(|state| {
            if let Ok(value) = state.current().local(self)? {
                // SAFETY: `value` borrows from the execution state only for the duration of
                // this closure. Extending it is sound only if the stored slot outlives every use
                // by the caller on this thread. NOTE(review): values are dropped by `thread_fn`
                // via `pop_local` after the thread's closure returns; confirm that no reference
                // handed out here can outlive that, and that slots are never moved meanwhile.
                Some(Ok(unsafe { extend_lt(value) }))
            } else {
                Some(Err(AccessError))
            }
        })
    }
}
/// Error returned by [`LocalKey::try_with`] when the current thread's value for the key has
/// already been destroyed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct AccessError;

impl std::fmt::Display for AccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` is exactly what `<str as Display>::fmt` does, so width, fill and precision
        // flags are honored.
        f.pad("already destroyed")
    }
}

impl std::error::Error for AccessError {}