extern crate alloc;
use alloc::{string::String, sync::Arc};
use core::{cell::UnsafeCell, num::NonZeroU64};
use arceos_api::task::{self as api, AxTaskHandle};
use axerrno::ax_err_type;
use crate::io;
/// Identifier of a thread, stored as a non-zero 64-bit integer.
///
/// The type is `Copy` and compares by the wrapped value.
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
pub struct ThreadId(NonZeroU64);
/// A handle describing a thread; it carries only the thread's [`ThreadId`].
pub struct Thread {
/// Identifier this handle was created for.
id: ThreadId,
}
impl ThreadId {
pub fn as_u64(&self) -> NonZeroU64 {
self.0
}
}
impl Thread {
fn from_id(id: u64) -> Self {
Self {
id: ThreadId(NonZeroU64::new(id).unwrap()),
}
}
pub fn id(&self) -> ThreadId {
self.id
}
}
/// Collects the optional settings used when spawning a new thread.
///
/// Any setting left as `None` is replaced by a default at spawn time.
#[derive(Debug)]
pub struct Builder {
/// Name passed to the spawned task; an empty string is used when unset.
name: Option<String>,
/// Stack size passed to the spawned task; when unset,
/// `arceos_api::config::TASK_STACK_SIZE` is used.
stack_size: Option<usize>,
}
impl Builder {
pub const fn new() -> Builder {
Builder {
name: None,
stack_size: None,
}
}
pub fn name(mut self, name: String) -> Builder {
self.name = Some(name);
self
}
pub fn stack_size(mut self, size: usize) -> Builder {
self.stack_size = Some(size);
self
}
pub fn spawn<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send + 'static,
{
unsafe { self.spawn_unchecked(f) }
}
unsafe fn spawn_unchecked<F, T>(self, f: F) -> io::Result<JoinHandle<T>>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send + 'static,
{
let name = self.name.unwrap_or_default();
let stack_size = self
.stack_size
.unwrap_or(arceos_api::config::TASK_STACK_SIZE);
let my_packet = Arc::new(Packet {
result: UnsafeCell::new(None),
});
let their_packet = my_packet.clone();
let main = move || {
let ret = f();
unsafe { *their_packet.result.get() = Some(ret) };
drop(their_packet);
};
let task = api::ax_spawn(main, name, stack_size);
Ok(JoinHandle {
thread: Thread::from_id(task.id()),
native: task,
packet: my_packet,
})
}
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
/// Returns a [`Thread`] handle built from the id reported by
/// `api::ax_current_task_id`.
///
/// # Panics
///
/// Panics if the reported id is zero, because `Thread::from_id` requires a
/// non-zero id.
pub fn current() -> Thread {
    Thread::from_id(api::ax_current_task_id())
}
/// Spawns a thread running `f` using a default-configured [`Builder`] and
/// returns its [`JoinHandle`].
///
/// # Panics
///
/// Panics with the message "failed to spawn thread" if [`Builder::spawn`]
/// returns an error.
pub fn spawn<T, F>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let builder = Builder::new();
    let outcome = builder.spawn(f);
    outcome.expect("failed to spawn thread")
}
/// Shared slot through which a spawned closure hands its return value back
/// to the `JoinHandle`.
///
/// The slot starts out as `None`; the spawned closure stores `Some(value)`
/// once it has finished, and `JoinHandle::join` takes the value out again.
struct Packet<T> {
    /// Result of the spawned closure. Interior mutability is needed because
    /// the writer only holds a shared `Arc` reference to the packet.
    result: UnsafeCell<Option<T>>,
}

// SAFETY: a shared `&Packet<T>` lets one thread store a `T` that another
// thread later takes out, i.e. the value is moved across threads, so `T` must
// be `Send`. Accesses themselves are never concurrent: the spawned closure
// writes exactly once and the joining side reads only after waiting for the
// task to exit. Without the `T: Send` bound this impl would claim that even
// thread-bound payloads can be shared, which is unsound.
unsafe impl<T: Send> Sync for Packet<T> {}
/// Handle returned when a thread is spawned; it is consumed by
/// `JoinHandle::join` to obtain the value returned by the thread's closure.
pub struct JoinHandle<T> {
/// Task handle returned by `api::ax_spawn`, passed to
/// `api::ax_wait_for_exit` when joining.
native: AxTaskHandle,
/// Thread handle built from the spawned task's id.
thread: Thread,
/// Result slot shared with the spawned closure.
packet: Arc<Packet<T>>,
}
// Marks `JoinHandle<T>` as `Send` and `Sync` for every `T`, with no bounds.
// NOTE(review): presumably sound because the handle exposes the packet's
// contents only through `join`, which consumes the handle by value; the
// thread-safety of `AxTaskHandle` is not visible from this file -- confirm.
unsafe impl<T> Send for JoinHandle<T> {}
unsafe impl<T> Sync for JoinHandle<T> {}
impl<T> JoinHandle<T> {
pub fn thread(&self) -> &Thread {
&self.thread
}
pub fn join(mut self) -> io::Result<T> {
api::ax_wait_for_exit(self.native);
Arc::get_mut(&mut self.packet)
.unwrap()
.result
.get_mut()
.take()
.ok_or_else(|| ax_err_type!(BadState))
}
}