use std::cmp::Ordering;
use std::convert::TryFrom;
use std::fmt::{Display, Formatter};
use std::marker::PhantomData;
use std::time::Duration;
use serde::{Deserialize, Serialize, Serializer};
use crate::event_label::{End, TJoin};
use crate::msg::Message;
use crate::must::Must;
use crate::runtime::execution::ExecutionState;
use crate::runtime::task::TaskId;
use crate::runtime::thread::{self, switch};
use crate::Val;
/// Identifier of a thread, wrapping a numeric id.
///
/// `Deserialize` is derived with `#[serde(try_from = "String")]`, so deserialization
/// is routed through the `TryFrom<String>` conversion implemented for this type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Deserialize)]
#[serde(try_from = "String")]
pub struct ThreadId {
/// Raw numeric value of the id; the field is private to this module.
opaque_id: u32,
}
/// Serializes a [`ThreadId`] as the string `t<number>` (for example `"t123"`).
impl Serialize for ThreadId {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `collect_str` lets serializers that can stream text skip the intermediate
        // `String`; its default implementation falls back to `serialize_str`, so the
        // emitted value is unchanged.
        serializer.collect_str(&format_args!("t{}", self.opaque_id))
    }
}
/// Formats a [`ThreadId`] as the letter `t` followed by its numeric id (e.g. `t123`).
impl Display for ThreadId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Write straight into the formatter instead of building a temporary `String`.
        write!(f, "t{}", self.opaque_id)
    }
}
/// Error returned when a string cannot be converted into a [`ThreadId`].
///
/// Implements `Debug` and `std::error::Error` so that results carrying it can be
/// unwrapped, boxed as `dyn Error`, or propagated with `?`.
#[derive(Debug)]
pub struct ThreadIdFromStrError {
    /// Human-readable description of why the conversion failed.
    msg: String,
}

/// Displays the stored failure description verbatim.
impl Display for ThreadIdFromStrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.msg)
    }
}

impl std::error::Error for ThreadIdFromStrError {}
/// Parses the textual form `t<number>` (as produced by the `Display` and `Serialize`
/// implementations) back into a [`ThreadId`].
///
/// # Errors
///
/// Returns a [`ThreadIdFromStrError`] when the string does not start with `t`, or when
/// the remainder cannot be parsed as a `u32`.
impl TryFrom<String> for ThreadId {
    type Error = ThreadIdFromStrError;

    fn try_from(s: String) -> Result<Self, Self::Error> {
        // `strip_prefix` borrows the digits in place; no copy of the input is needed.
        match s.strip_prefix('t') {
            // Note: `u32`'s parser also accepts a leading `+`, so `t+5` yields id 5.
            Some(digits) => digits
                .parse::<u32>()
                .map(|tid| ThreadId { opaque_id: tid })
                .map_err(|_| ThreadIdFromStrError {
                    msg: format!("Can't parse {} as a number", &s),
                }),
            None => Err(ThreadIdFromStrError {
                msg: format!("`{}` should begin with `t`", &s),
            }),
        }
    }
}
/// Builds a [`ThreadId`] that wraps the given numeric id.
pub fn construct_thread_id(numeric_id: u32) -> ThreadId {
    let opaque_id = numeric_id;
    ThreadId { opaque_id }
}
impl From<ThreadId> for u32 {
fn from(tid: ThreadId) -> Self {
tid.opaque_id
}
}
/// Extracts the raw numeric id of a [`ThreadId`], cast to `usize`.
impl From<ThreadId> for usize {
    fn from(tid: ThreadId) -> Self {
        let raw: u32 = tid.opaque_id;
        raw as usize
    }
}
impl ThreadId {
    /// Returns the raw numeric id wrapped by this `ThreadId`.
    pub(crate) fn to_number(self) -> u32 {
        let Self { opaque_id } = self;
        opaque_id
    }
}
/// Returns the [`ThreadId`] whose numeric id is `0`, used here for the main thread.
pub fn main_thread_id() -> ThreadId {
    const MAIN_NUMERIC_ID: u32 = 0;
    ThreadId {
        opaque_id: MAIN_NUMERIC_ID,
    }
}
/// Partial ordering of [`ThreadId`]s; always defined, and delegates to the total order.
impl PartialOrd for ThreadId {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}
/// Total ordering of [`ThreadId`]s by their numeric id.
impl Ord for ThreadId {
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.opaque_id, other.opaque_id);
        lhs.cmp(&rhs)
    }
}
/// Description of a thread: its optional name and its [`ThreadId`].
#[derive(Debug, Clone)]
pub struct Thread {
/// Name given to the thread, if any.
pub(crate) name: Option<String>,
/// Identifier of the thread.
pub(crate) id: ThreadId,
}
impl Thread {
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
pub fn id(&self) -> ThreadId {
self.id
}
}
/// Spawns a new, unnamed, non-daemon thread running `f` and returns its [`JoinHandle`].
///
/// `switch()` is invoked immediately before and immediately after the thread is
/// created. No explicit stack size is passed, so `spawn_without_switch` picks its
/// default.
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Message + 'static,
{
    const IS_DAEMON: bool = false;

    switch();
    let handle = spawn_without_switch(f, None, IS_DAEMON, None, None);
    switch();
    handle
}
/// Spawns a new, unnamed thread running `f` with the daemon flag set, and returns its
/// [`JoinHandle`].
///
/// `switch()` is invoked immediately before and immediately after the thread is
/// created. No explicit stack size is passed, so `spawn_without_switch` picks its
/// default.
pub fn spawn_daemon<F, T>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Message + 'static,
{
    const IS_DAEMON: bool = true;

    switch();
    let handle = spawn_without_switch(f, None, IS_DAEMON, None, None);
    switch();
    handle
}
/// Creates a new thread running `f` and registers its creation, without calling
/// `switch()` itself (the public spawn functions in this file wrap this call in
/// `switch()` calls).
///
/// * `f` - closure to run on the new thread; its return value is wrapped in a `Val`
///   and reported through `handle_tend` once the closure finishes.
/// * `name` - optional thread name; passed to `ExecutionState::spawn_thread` and
///   `handle_tcreate`, and stored in the returned handle's `Thread`.
/// * `is_daemon` - forwarded to `handle_tcreate`.
/// * `stack_size` - stack size handed to `spawn_thread`; when `None`, the value of
///   `config().stack_size` is used instead.
/// * `sym_cid` - forwarded to `handle_tcreate`. NOTE(review): its meaning is not
///   visible in this file — confirm with the `Must` implementation.
///
/// Returns a `JoinHandle` holding the task id returned by `spawn_thread` together
/// with the new thread's `Thread` description.
pub(crate) fn spawn_without_switch<F, T>(
f: F,
name: Option<String>,
is_daemon: bool,
stack_size: Option<usize>,
sym_cid: Option<ThreadId>,
) -> JoinHandle<T>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Message + 'static,
{
// Fall back to the configured stack size when the caller did not supply one.
let stack_size =
stack_size.unwrap_or_else(|| ExecutionState::with(|s| s.must.borrow().config().stack_size));
let (task_id, tid) = {
// Wrap the user closure so that, after it returns, the result is reported as an
// `End` event and `unstuck_joiners` is called for the finishing thread.
let f = move || {
let ret = f();
ExecutionState::with(|state| {
let pos = state.next_pos();
state
.must
.borrow_mut()
.handle_tend(End::new(pos, Val::new(ret)));
// presumably releases tasks that were marked stuck while joining this thread — confirm
Must::unstuck_joiners(state, pos.thread);
});
};
// The task is spawned first; the creation event recorded below refers to its id (`cid`).
let cid = ExecutionState::spawn_thread(f, stack_size, name.clone());
let tid = ExecutionState::with(|state| {
let pos = state.next_pos();
// The shared borrow ends with this statement, before `borrow_mut()` is taken below.
let tid = state.must.borrow().next_thread_id(&pos);
state
.must
.borrow_mut()
.handle_tcreate(tid, cid, sym_cid, pos, name.clone(), is_daemon);
tid
});
(cid, tid)
};
let thread = Thread { id: tid, name };
JoinHandle {
task_id,
thread,
_p: PhantomData::<T>,
}
}
/// Handle to a spawned thread; consumed by `join` to obtain the thread's result.
#[derive(Debug)]
pub struct JoinHandle<T> {
/// Task id returned by `ExecutionState::spawn_thread` for the spawned thread.
task_id: TaskId,
/// Name and id of the spawned thread.
thread: Thread,
/// Marker tying the handle to the result type `T`; no `T` value is stored here.
_p: PhantomData<T>,
}
impl<T: 'static> JoinHandle<T> {
/// Waits for the associated thread's result and returns it.
///
/// Each loop iteration calls `thread::switch()` and then reports a `TJoin` event for
/// the target thread through `handle_tjoin`. A returned message that is not pending
/// ends the loop; a pending message marks the current task as stuck. Whenever the
/// loop does not end, `prev_pos()` is called before the next attempt.
///
/// In the code shown here the return value is always `Ok`.
///
/// # Panics
///
/// Panics if the received value cannot be downcast to `T`.
pub fn join(self) -> std::thread::Result<T> {
let ret = loop {
thread::switch();
let val = ExecutionState::with(|s| {
// Translate the stored task id into the thread id used by the `TJoin` event.
let target_task_id = s.get(self.task_id).id();
let target_id = s.must.borrow().to_thread_id(target_task_id);
let pos = s.next_pos();
s.must.borrow_mut().handle_tjoin(TJoin::new(pos, target_id))
});
if let Some(message) = val {
if message.is_pending() {
// Result not available yet: mark the current task as stuck.
ExecutionState::with(|s| s.current_mut().stuck());
} else {
break message;
}
}
// NOTE(review): presumably undoes the `next_pos()` above so the retried join
// reuses the same position — confirm against `ExecutionState`.
ExecutionState::with(|s| s.prev_pos());
};
// Kept for the panic message below, which reports the type actually received.
let actual_type = &ret.type_name;
Ok(*(ret.as_any().downcast().unwrap_or_else(|_| {
panic!(
"Expected a thread result of {}, but got {}",
std::any::type_name::<T>(),
actual_type
);
})))
}
/// Returns the `Thread` (name and id) associated with this handle.
pub fn thread(&self) -> &Thread {
&self.thread
}
}
/// Stand-in for a timed sleep: performs a single `thread::switch()` call.
///
/// The `_dur` argument is accepted for signature compatibility but is never read.
pub fn sleep(_dur: Duration) {
    // No time-based waiting happens here; only the switch call is made.
    thread::switch();
}
/// Returns a [`Thread`] describing the currently running task.
///
/// The thread id is obtained by translating the current task's id with
/// `to_thread_id`; the name is taken from the current task.
pub fn current() -> Thread {
    let (tid, name) = ExecutionState::with(|s| {
        let me = s.current();
        // `to_thread_id` is a read-only lookup (it is called through `borrow()` in
        // `JoinHandle::join`), so a shared borrow is sufficient and cannot conflict
        // with other shared borrows of `must`.
        let tid = s.must.borrow().to_thread_id(me.id());
        (tid, me.name())
    });
    Thread { id: tid, name }
}
/// Returns the [`ThreadId`] of the currently running thread, as reported by `current()`.
pub fn current_id() -> ThreadId {
    let me = current();
    me.id()
}
/// Configuration for a thread about to be spawned: an optional name and an optional
/// stack size. Both start out as `None`.
#[derive(Debug, Default)]
pub struct Builder {
/// Name to give the new thread, if any.
name: Option<String>,
/// Stack size to pass when spawning, if any; when `None`, `spawn_without_switch`
/// falls back to the configured value.
stack_size: Option<usize>,
}
impl Builder {
pub fn new() -> Self {
Self {
name: None,
stack_size: None,
}
}
pub fn name(mut self, name: String) -> Self {
self.name = Some(name);
self
}
pub fn stack_size(mut self, stack_size: usize) -> Self {
self.stack_size = Some(stack_size);
self
}
pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Message + 'static,
{
switch();
let jh = Ok(spawn_without_switch(
f,
self.name,
false,
self.stack_size,
None,
));
switch();
jh
}
pub fn spawn_daemon<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Message + 'static,
{
switch();
let jh = Ok(spawn_without_switch(
f,
self.name,
true,
self.stack_size,
None,
));
switch();
jh
}
}
/// Key for a thread-local value of type `T`.
///
/// NOTE(review): the shape resembles `std::thread::LocalKey`; the public-but-hidden
/// fields suggest instances are built by a macro elsewhere — confirm.
pub struct LocalKey<T: 'static> {
/// Function returning a `T`; presumably produces the initial value. It is not
/// called anywhere in the code visible here.
#[doc(hidden)]
pub init: fn() -> T,
/// Marker for the value type `T`; carries no data.
#[doc(hidden)]
pub _p: PhantomData<T>,
}
// SAFETY: NOTE(review) — rationale to be confirmed. `LocalKey<T>` contains only a
// `fn() -> T` pointer and a `PhantomData<T>` marker, never a `T` value, so the key
// itself holds no `T` data that could cross threads. These impls are unconditional:
// they place no `Send`/`Sync` requirement on `T`.
unsafe impl<T> Send for LocalKey<T> {}
unsafe impl<T> Sync for LocalKey<T> {}
/// Debug output names the type only; no field is printed.
impl<T: 'static> std::fmt::Debug for LocalKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LocalKey");
        builder.finish_non_exhaustive()
    }
}
impl<T: 'static> LocalKey<T> {
    /// Runs `f` with a reference to the thread-local value and returns its result.
    ///
    /// # Panics
    ///
    /// Panics when `try_with` reports an [`AccessError`]. With `get` as currently
    /// written (it always yields an error), that is every call.
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        let outcome = self.try_with(f);
        outcome.expect("cannot access a Thread Local Storage value during or after destruction")
    }

    /// Runs `f` with a reference to the thread-local value, or returns an
    /// [`AccessError`] when the value cannot be accessed.
    ///
    /// If the first lookup yields `None`, the lookup is repeated once and unwrapped.
    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
    where
        F: FnOnce(&T) -> R,
    {
        let slot = match self.get() {
            Some(found) => found,
            None => self.get().unwrap(),
        };
        slot.map(f)
    }

    /// Looks up the value for the current execution.
    ///
    /// As written, no storage is consulted: the result is always
    /// `Some(Err(AccessError))`.
    fn get(&'static self) -> Option<Result<&'static T, AccessError>> {
        ExecutionState::with(|_state| {
            let unavailable = Err(AccessError);
            Some(unavailable)
        })
    }
}
/// Error returned by `LocalKey::try_with` when the thread-local value cannot be accessed.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[non_exhaustive]
pub struct AccessError;

/// Displays the fixed message `already destroyed`, honouring width and precision flags.
impl std::fmt::Display for AccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Formatter::pad` is what `Display` for `str` uses, so padding behaves the same.
        f.pad("already destroyed")
    }
}

impl std::error::Error for AccessError {}
#[cfg(test)]
mod test {
    use super::*;

    /// A `ThreadId` serializes to its quoted `t<number>` form and deserializes back
    /// to an equal value.
    #[test]
    fn threadid_is_serializable() {
        let original = ThreadId { opaque_id: 123 };

        let encoded = serde_json::to_string_pretty(&original).unwrap();
        assert_eq!("\"t123\"", encoded);

        let decoded: ThreadId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }
}