use alloc::{boxed::Box, string::String, sync::Arc};
#[cfg(feature = "preempt")]
use core::sync::atomic::AtomicUsize;
use core::{
alloc::Layout,
cell::{Cell, UnsafeCell},
fmt,
mem::ManuallyDrop,
ops::Deref,
sync::atomic::{AtomicBool, AtomicI32, AtomicU8, AtomicU32, AtomicU64, Ordering},
task::{Context, Poll},
};
use ax_hal::context::TaskContext;
#[cfg(feature = "tls")]
use ax_hal::tls::TlsArea;
use ax_kspin::SpinNoIrq;
use ax_memory_addr::{VirtAddr, align_up_4k};
use futures_util::task::AtomicWaker;
#[cfg(feature = "lockdep")]
use crate::lockdep::HeldLockStack;
use crate::{AxCpuMask, AxTask, AxTaskRef, WaitQueue};
/// Magic word stored in the lowest machine word of every task stack when the
/// `stack-canary` feature is enabled. If it no longer reads back unchanged, the
/// stack has overflowed (or been corrupted). 64-bit variant.
#[cfg(all(feature = "stack-canary", target_pointer_width = "64"))]
const STACK_END_MAGIC: usize = 0x57AC_CE11_57AC_CE11;
/// 32-bit variant of the stack canary word.
#[cfg(all(feature = "stack-canary", target_pointer_width = "32"))]
const STACK_END_MAGIC: usize = 0x57AC_CE11;

/// Alignment, in bytes, of every stack allocated by `TaskStack::alloc`.
pub(crate) const TASK_STACK_ALIGN: usize = 16;
/// Unique numeric identifier of a task.
///
/// Values are handed out by `TaskId::new` from a global counter; use
/// `as_u64` to read the raw number.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TaskId(u64);
/// Scheduling state of a task.
///
/// `TaskInner` stores it atomically as its `u8` discriminant and decodes it
/// back through `From<u8>`.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum TaskState {
    /// Reported by `is_running`.
    Running = 1,
    /// Initial state of every newly built task; reported by `is_ready`.
    Ready = 2,
    /// Blocked state.
    Blocked = 3,
    /// Set by `notify_exit`; `join` waits for this state.
    Exited = 4,
}
/// Optional per-task extension hooks (feature `task-ext`).
///
/// The trait is processed by `extern_trait`, which exposes `AxTaskExt` as the
/// proxy type stored in `TaskInner::task_ext`.
/// NOTE(review): the concrete implementation is presumably supplied by another
/// crate; confirm where it lives.
#[cfg(feature = "task-ext")]
#[extern_trait::extern_trait(
    /// The impl proxy type for [`TaskExt`].
    pub AxTaskExt
)]
pub trait TaskExt {
    /// Hook for when the task is entered. The default is a no-op.
    /// NOTE(review): the call site is outside this file; presumably it runs
    /// when the task is switched in. Confirm.
    fn on_enter(&self) {}
    /// Hook for when the task is left. The default is a no-op.
    /// NOTE(review): presumably it runs when the task is switched out. Confirm.
    fn on_leave(&self) {}
}
/// Task control block: identity, scheduling state, kernel stack and saved CPU
/// context of one task.
///
/// Fields are dropped in declaration order.
pub struct TaskInner {
    /// Unique id allocated by `TaskId::new`.
    id: TaskId,
    /// Task name; replaceable at runtime via `set_name`.
    name: SpinNoIrq<String>,
    /// Set when the task is created with the name `"idle"`.
    is_idle: bool,
    /// Set for tasks built through `new_init`.
    is_init: bool,
    /// Entry closure, taken and run exactly once by `task_entry`.
    entry: Cell<Option<Box<dyn FnOnce()>>>,
    /// Current `TaskState`, stored as its `u8` discriminant.
    state: AtomicU8,
    /// CPU mask, initialized to `cpu_mask_full()`.
    /// NOTE(review): presumably the set of CPUs this task may run on.
    cpumask: SpinNoIrq<AxCpuMask>,
    /// Flag maintained via `set_in_wait_queue`; presumably true while the task
    /// is queued on a `WaitQueue`.
    in_wait_queue: AtomicBool,
    /// CPU id recorded for this task (updated by `set_cpu_id` under `smp`).
    cpu_id: AtomicU32,
    /// On-CPU flag (smp); `new_init` sets it to true.
    #[cfg(feature = "smp")]
    on_cpu: AtomicBool,
    /// Current timer ticket; 0 means "no ticket" (see `timer_ticket_expired`).
    #[cfg(feature = "irq")]
    timer_ticket_id: AtomicU64,
    /// Set when a reschedule of this task is pending.
    #[cfg(feature = "preempt")]
    need_resched: AtomicBool,
    /// Nesting depth of `disable_preempt`; preemption is allowed at 0.
    #[cfg(feature = "preempt")]
    preempt_disable_count: AtomicUsize,
    /// Pending-interrupt flag, consumed by `poll_interrupt`.
    interrupted: AtomicBool,
    /// Waker registered by `poll_interrupt`, woken by `interrupt`.
    interrupt_waker: AtomicWaker,
    /// Exit code recorded by `notify_exit` and returned by `join`.
    exit_code: AtomicI32,
    /// Waiters blocked in `join` until the task is `Exited`.
    wait_for_exit: WaitQueue,
    /// Kernel stack this task runs on.
    kstack: TaskStack,
    /// Saved CPU context; mutated through `ctx_mut`/`ctx_mut_ptr`.
    ctx: UnsafeCell<TaskContext>,
    /// Locks currently held by this task, for lock-dependency checking.
    #[cfg(feature = "lockdep")]
    held_locks: UnsafeCell<HeldLockStack>,
    /// Optional extension object (see `TaskExt`).
    #[cfg(feature = "task-ext")]
    task_ext: Option<AxTaskExt>,
    /// TLS area; its `tls_ptr()` is used as the task's thread pointer.
    #[cfg(feature = "tls")]
    tls: TlsArea,
}
impl TaskId {
fn new() -> Self {
static ID_COUNTER: AtomicU64 = AtomicU64::new(1);
Self(ID_COUNTER.fetch_add(1, Ordering::Relaxed))
}
pub const fn as_u64(&self) -> u64 {
self.0
}
}
impl From<u8> for TaskState {
    /// Decodes a discriminant previously produced by `state as u8`.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) on any value that is not a `TaskState`
    /// discriminant.
    #[inline]
    fn from(state: u8) -> Self {
        // Compare against the enum's own discriminants so the decoding can
        // never drift from the `#[repr(u8)]` declaration.
        match state {
            raw if raw == Self::Running as u8 => Self::Running,
            raw if raw == Self::Ready as u8 => Self::Ready,
            raw if raw == Self::Blocked as u8 => Self::Blocked,
            raw if raw == Self::Exited as u8 => Self::Exited,
            _ => unreachable!(),
        }
    }
}
// SAFETY: `TaskInner` contains `Cell`/`UnsafeCell` fields (`entry`, `ctx`, and
// `held_locks` under `lockdep`), and `entry` stores a `Box<dyn FnOnce()>`
// without a `Send` bound (`TaskInner::new` does require `F: Send`). So neither
// auto trait would be derived. These impls assert that the crate never
// accesses those fields concurrently from different CPUs.
// NOTE(review): the scheduler-side invariants that justify this live outside
// this file; confirm before relying on them.
unsafe impl Send for TaskInner {}
unsafe impl Sync for TaskInner {}
impl TaskInner {
    /// Creates a task that will run `entry` on a freshly allocated kernel
    /// stack of `stack_size` bytes, rounded up to a 4 KiB multiple.
    ///
    /// The saved context is set up to start in `task_entry` at the top of the
    /// new stack. A task named `"idle"` is marked as the idle task.
    pub fn new<F>(entry: F, name: String, stack_size: usize) -> Self
    where
        F: FnOnce() + Send + 'static,
    {
        let stack = TaskStack::alloc(align_up_4k(stack_size));
        let mut task = Self::new_common(TaskId::new(), name, stack);
        debug!("new task: {}", task.id_name());

        // Thread pointer for the new context (0 when TLS is compiled out).
        #[cfg(feature = "tls")]
        let thread_ptr = VirtAddr::from(task.tls.tls_ptr() as usize);
        #[cfg(not(feature = "tls"))]
        let thread_ptr = VirtAddr::from(0);

        let stack_top = task.kstack.top();
        task.entry.set(Some(Box::new(entry)));
        let entry_addr = task_entry as *const () as usize;
        task.ctx_mut().init(entry_addr, stack_top, thread_ptr);

        if task.name() == "idle" {
            task.is_idle = true;
        }
        task
    }

    /// Returns this task's identifier.
    pub const fn id(&self) -> TaskId {
        self.id
    }

    /// Returns a copy of the task's current name.
    pub fn name(&self) -> String {
        let guard = self.name.lock();
        String::clone(&guard)
    }

    /// Replaces the task's name.
    pub fn set_name(&self, name: &str) {
        // Build the owned string before taking the lock.
        let owned = String::from(name);
        *self.name.lock() = owned;
    }

    /// Returns a human-readable `Task(<id>, "<name>")` label for logging.
    pub fn id_name(&self) -> alloc::string::String {
        let raw_id = self.id().as_u64();
        alloc::format!("Task({}, {:?})", raw_id, self.name())
    }

    /// Blocks until this task reaches `TaskState::Exited` and returns its
    /// exit code.
    pub fn join(&self) -> i32 {
        crate::api::might_sleep();
        let has_exited = || self.state() == TaskState::Exited;
        self.wait_for_exit.wait_until(has_exited);
        self.exit_code.load(Ordering::Acquire)
    }

    /// Borrows the task extension object, if one is installed.
    #[cfg(feature = "task-ext")]
    pub fn task_ext(&self) -> Option<&AxTaskExt> {
        self.task_ext.as_ref()
    }

    /// Mutably borrows the task extension slot.
    #[cfg(feature = "task-ext")]
    pub fn task_ext_mut(&mut self) -> &mut Option<AxTaskExt> {
        &mut self.task_ext
    }

    /// Mutable access to the saved CPU context. Exclusive `&mut self` access
    /// makes this safe.
    #[inline]
    pub const fn ctx_mut(&mut self) -> &mut TaskContext {
        self.ctx.get_mut()
    }

    /// Records `root` as this task's page-table root in its saved context,
    /// installs it as the current user page table, and flushes the TLB
    /// (`flush_tlb(None)`).
    #[cfg(feature = "uspace")]
    pub fn switch_page_table(&self, root: ax_memory_addr::PhysAddr) {
        // SAFETY: NOTE(review): assumes the caller is the task itself (or
        // otherwise has exclusive access to its context) and that `root` is a
        // valid page-table root.
        unsafe {
            (*self.ctx.get()).set_page_table_root(root);
            ax_hal::asm::write_user_page_table(root);
        }
        ax_hal::asm::flush_tlb(None);
    }

    /// Runs `f` with mutable access to this task's held-lock stack.
    #[cfg(feature = "lockdep")]
    pub(crate) fn with_held_locks<R>(&self, f: impl FnOnce(&mut HeldLockStack) -> R) -> R {
        // SAFETY: NOTE(review): assumes only the owning task touches its
        // held-lock stack, so no aliasing `&mut` can exist.
        let held = unsafe { &mut *self.held_locks.get() };
        f(held)
    }

    /// Returns the CPU id recorded for this task.
    #[inline]
    pub fn cpu_id(&self) -> u32 {
        self.cpu_id.load(Ordering::Acquire)
    }

    /// Returns a copy of the task's CPU mask.
    #[inline]
    pub fn cpumask(&self) -> AxCpuMask {
        let guard = self.cpumask.lock();
        *guard
    }

    /// Replaces the task's CPU mask.
    #[inline]
    pub fn set_cpumask(&self, cpumask: AxCpuMask) {
        let mut guard = self.cpumask.lock();
        *guard = cpumask;
    }

    /// Registers `cx`'s waker, then consumes a pending interrupt.
    ///
    /// Returns `Ready` (clearing the flag) if an interrupt was pending,
    /// otherwise `Pending`. `interrupt` will wake the registered waker.
    #[inline]
    pub fn poll_interrupt(&self, cx: &Context) -> Poll<()> {
        self.interrupt_waker.register(cx.waker());
        let was_interrupted = self.interrupted.swap(false, Ordering::AcqRel);
        if !was_interrupted {
            return Poll::Pending;
        }
        Poll::Ready(())
    }

    /// Drops any pending interrupt without waking anyone.
    #[inline]
    pub fn clear_interrupt(&self) {
        self.interrupted.store(false, Ordering::Release);
    }

    /// Marks the task as interrupted and wakes the waker registered by
    /// `poll_interrupt`.
    #[inline]
    pub fn interrupt(&self) {
        self.interrupted.store(true, Ordering::Release);
        self.interrupt_waker.wake();
    }
}
impl TaskInner {
    /// Builds a control block with default bookkeeping.
    ///
    /// Defaults: state `Ready`, no entry closure, a CPU mask from
    /// `cpu_mask_full()`, cleared flags and counters, exit code 0, and a fresh
    /// `TaskContext::new()`. Callers then install an entry point and context
    /// (see `new`) or mark the task as an init task (see `new_init`).
    fn new_common(id: TaskId, name: String, kstack: TaskStack) -> Self {
        Self {
            id,
            name: SpinNoIrq::new(name),
            is_idle: false,
            is_init: false,
            entry: Cell::new(None),
            state: AtomicU8::new(TaskState::Ready as u8),
            cpumask: SpinNoIrq::new(crate::api::cpu_mask_full()),
            in_wait_queue: AtomicBool::new(false),
            #[cfg(feature = "irq")]
            timer_ticket_id: AtomicU64::new(0),
            cpu_id: AtomicU32::new(0),
            #[cfg(feature = "smp")]
            on_cpu: AtomicBool::new(false),
            #[cfg(feature = "preempt")]
            need_resched: AtomicBool::new(false),
            #[cfg(feature = "preempt")]
            preempt_disable_count: AtomicUsize::new(0),
            interrupted: AtomicBool::new(false),
            interrupt_waker: AtomicWaker::new(),
            exit_code: AtomicI32::new(0),
            wait_for_exit: WaitQueue::new(),
            kstack,
            ctx: UnsafeCell::new(TaskContext::new()),
            #[cfg(feature = "lockdep")]
            held_locks: UnsafeCell::new(HeldLockStack::new()),
            #[cfg(feature = "task-ext")]
            task_ext: None,
            #[cfg(feature = "tls")]
            tls: TlsArea::alloc(),
        }
    }

    /// Creates the control block for an init task that uses `kstack`.
    ///
    /// No entry closure or context is installed here. With `smp` the task is
    /// flagged as on a CPU, and a task named `"idle"` is also marked idle.
    pub(crate) fn new_init(name: String, kstack: TaskStack) -> Self {
        let mut t = Self::new_common(TaskId::new(), name, kstack);
        t.is_init = true;
        #[cfg(feature = "smp")]
        t.set_on_cpu(true);
        if t.name() == "idle" {
            t.is_idle = true;
        }
        t
    }

    /// Wraps this control block in an `AxTask` behind a shared `Arc`.
    pub(crate) fn into_arc(self) -> AxTaskRef {
        Arc::new(AxTask::new(self))
    }

    /// Returns the current scheduling state.
    #[inline]
    pub fn state(&self) -> TaskState {
        self.state.load(Ordering::Acquire).into()
    }

    /// Unconditionally sets the scheduling state.
    #[inline]
    pub(crate) fn set_state(&self, state: TaskState) {
        self.state.store(state as u8, Ordering::Release)
    }

    /// Atomically moves from `current_state` to `new_state`.
    ///
    /// Returns `false` (leaving the state untouched) if the task was not in
    /// `current_state`.
    #[inline]
    pub(crate) fn transition_state(&self, current_state: TaskState, new_state: TaskState) -> bool {
        self.state
            .compare_exchange(
                current_state as u8,
                new_state as u8,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
    }

    /// Whether the task is in `TaskState::Running`.
    #[inline]
    pub(crate) fn is_running(&self) -> bool {
        matches!(self.state(), TaskState::Running)
    }

    /// Whether the task is in `TaskState::Ready`.
    #[inline]
    pub(crate) fn is_ready(&self) -> bool {
        matches!(self.state(), TaskState::Ready)
    }

    /// Whether the task was created through `new_init`.
    #[inline]
    pub(crate) const fn is_init(&self) -> bool {
        self.is_init
    }

    /// Whether the task is the idle task (named `"idle"`).
    #[inline]
    pub(crate) const fn is_idle(&self) -> bool {
        self.is_idle
    }

    /// Reads the in-wait-queue flag.
    #[inline]
    pub(crate) fn in_wait_queue(&self) -> bool {
        self.in_wait_queue.load(Ordering::Acquire)
    }

    /// Updates the in-wait-queue flag.
    #[inline]
    pub(crate) fn set_in_wait_queue(&self, in_wait_queue: bool) {
        self.in_wait_queue.store(in_wait_queue, Ordering::Release);
    }

    /// Returns the current timer ticket; 0 means no ticket is outstanding.
    #[inline]
    #[cfg(feature = "irq")]
    pub(crate) fn timer_ticket(&self) -> u64 {
        self.timer_ticket_id.load(Ordering::Acquire)
    }

    /// Records a new timer ticket.
    ///
    /// # Panics
    ///
    /// Panics if `timer_ticket_id` is 0, which is reserved for "no ticket".
    #[inline]
    #[cfg(feature = "irq")]
    pub(crate) fn set_timer_ticket(&self, timer_ticket_id: u64) {
        assert!(timer_ticket_id != 0);
        self.timer_ticket_id
            .store(timer_ticket_id, Ordering::Release);
    }

    /// Invalidates the outstanding timer ticket (resets it to 0).
    #[inline]
    #[cfg(feature = "irq")]
    pub(crate) fn timer_ticket_expired(&self) {
        self.timer_ticket_id.store(0, Ordering::Release);
    }

    /// Sets or clears the pending-reschedule flag.
    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn set_preempt_pending(&self, pending: bool) {
        self.need_resched.store(pending, Ordering::Release)
    }

    /// Returns the current preempt-disable nesting depth.
    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn preempt_count(&self) -> usize {
        self.preempt_disable_count.load(Ordering::Acquire)
    }

    /// Whether the preempt-disable depth equals `current_disable_count`.
    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn can_preempt(&self, current_disable_count: usize) -> bool {
        self.preempt_disable_count.load(Ordering::Acquire) == current_disable_count
    }

    /// Increments the preempt-disable nesting depth.
    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn disable_preempt(&self) {
        self.preempt_disable_count.fetch_add(1, Ordering::Release);
    }

    /// Decrements the preempt-disable nesting depth.
    ///
    /// When the depth drops to 0 and `resched` is true, a pending reschedule
    /// of the current task is honoured immediately.
    #[inline]
    #[cfg(feature = "preempt")]
    pub(crate) fn enable_preempt(&self, resched: bool) {
        let prev = self.preempt_disable_count.fetch_sub(1, Ordering::Release);
        // An unbalanced call would wrap the counter to `usize::MAX` and
        // silently disable preemption forever.
        debug_assert!(prev > 0, "enable_preempt without matching disable_preempt");
        if prev == 1 && resched {
            Self::current_check_preempt_pending();
        }
    }

    /// Reschedules the current task if a reschedule is pending and preemption
    /// is enabled.
    ///
    /// The flag is re-checked after taking the run queue with the
    /// `NoPreemptIrqSave` guard, because it may have been cleared meanwhile.
    #[cfg(feature = "preempt")]
    fn current_check_preempt_pending() {
        use ax_kernel_guard::NoPreemptIrqSave;
        let curr = crate::current();
        if curr.need_resched.load(Ordering::Acquire) && curr.can_preempt(0) {
            let mut rq = crate::current_run_queue::<NoPreemptIrqSave>();
            if curr.need_resched.load(Ordering::Acquire) {
                rq.preempt_resched()
            }
        }
    }

    /// Records `exit_code`, marks the task `Exited`, and wakes every `join`er.
    pub(crate) fn notify_exit(&self, exit_code: i32) {
        // The exit code must be written *before* the state is released:
        // `join` waits for `Exited` with an Acquire load and then reads the
        // code. Publishing the state first lets a joiner see the stale default.
        self.exit_code.store(exit_code, Ordering::Release);
        self.set_state(TaskState::Exited);
        self.wait_for_exit.notify_all(false);
    }

    /// Returns a raw pointer to the saved CPU context.
    ///
    /// # Safety
    ///
    /// The pointer aliases the task's context. Callers must ensure it is not
    /// read or written concurrently with any other access to that context.
    /// NOTE(review): presumably only used during context switches; confirm.
    #[inline]
    pub(crate) const unsafe fn ctx_mut_ptr(&self) -> *mut TaskContext {
        self.ctx.get()
    }

    /// Verifies that the kernel-stack canary is still intact.
    ///
    /// # Panics
    ///
    /// Panics with the task name and stack bounds if the canary was
    /// overwritten.
    #[cfg(feature = "stack-canary")]
    #[inline]
    pub(crate) fn check_stack_canary(&self) {
        if self.kstack.is_canary_intact() {
            return;
        }
        panic!(
            "stack overflow/corruption detected for {}: stack=[{:#x}..{:#x}), expected magic={:#x}",
            self.id_name(),
            self.kstack.bottom().as_usize(),
            self.kstack.top().as_usize(),
            STACK_END_MAGIC
        );
    }

    /// Records the CPU id for this task.
    #[cfg(feature = "smp")]
    #[inline]
    pub(crate) fn set_cpu_id(&self, cpu_id: u32) {
        self.cpu_id.store(cpu_id, Ordering::Release);
    }

    /// Reads the on-CPU flag.
    #[cfg(feature = "smp")]
    #[inline]
    pub(crate) fn on_cpu(&self) -> bool {
        self.on_cpu.load(Ordering::Acquire)
    }

    /// Updates the on-CPU flag.
    #[cfg(feature = "smp")]
    #[inline]
    pub(crate) fn set_on_cpu(&self, on_cpu: bool) {
        self.on_cpu.store(on_cpu, Ordering::Release)
    }
}
impl fmt::Debug for TaskInner {
    /// Formats the task's id, name and current state.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let current_state = self.state();
        let mut out = f.debug_struct("TaskInner");
        out.field("id", &self.id);
        out.field("name", &self.name);
        out.field("state", &current_state);
        out.finish()
    }
}
impl Drop for TaskInner {
    /// Logs the task's identity when its control block is destroyed. The
    /// fields (stack, context, ...) are released afterwards in declaration
    /// order.
    fn drop(&mut self) {
        debug!("task drop: {}", Self::id_name(self));
    }
}
/// A task's kernel stack, i.e. the byte range `[ptr, ptr + size)`.
///
/// Stacks created by `TaskStack::alloc` are owned and returned to the global
/// allocator on drop. Stacks created by `TaskStack::borrowed` are never freed
/// here.
pub(crate) struct TaskStack {
    /// Lowest address of the stack (its bottom; the canary lives here).
    ptr: usize,
    /// Size of the stack in bytes.
    size: usize,
    /// Alignment used for the allocation layout.
    align: usize,
    /// Whether `Drop` must deallocate the memory.
    owned: bool,
}
impl TaskStack {
    /// Allocates an owned stack of `size` bytes, aligned to
    /// `TASK_STACK_ALIGN`, from the global allocator.
    ///
    /// With `stack-canary` enabled, the lowest word is set to
    /// `STACK_END_MAGIC`.
    ///
    /// # Panics
    ///
    /// Panics in three cases:
    /// - `size` is smaller than one machine word. A zero-sized layout is
    ///   undefined behaviour for the global allocator, and the canary needs
    ///   a full word.
    /// - The layout is invalid.
    /// - The allocation fails.
    pub fn alloc(size: usize) -> Self {
        assert!(
            size >= core::mem::size_of::<usize>(),
            "task stack size {size:#x} is too small"
        );
        let align = TASK_STACK_ALIGN;
        let layout = Layout::from_size_align(size, align).unwrap();
        // SAFETY: `layout` has a non-zero size, as asserted above.
        let ptr = unsafe { alloc::alloc::alloc(layout) as usize };
        assert_ne!(ptr, 0, "task stack allocation failed");
        let stack = Self {
            ptr,
            size,
            align,
            owned: true,
        };
        // SAFETY: `[ptr, ptr + size)` was just allocated, is at least one
        // word long, and is `TASK_STACK_ALIGN`-aligned.
        #[cfg(feature = "stack-canary")]
        unsafe {
            stack.write_canary()
        };
        stack
    }

    /// Wraps an externally provided stack starting at `bottom` without taking
    /// ownership. It is never deallocated by `Drop`.
    ///
    /// With `stack-canary` enabled, the lowest word is overwritten with
    /// `STACK_END_MAGIC`.
    ///
    /// NOTE(review): this safe function writes through `bottom` when the canary
    /// is enabled. It relies on the caller passing a valid, word-aligned region
    /// of at least one word that lives as long as this value.
    ///
    /// # Panics
    ///
    /// Panics if `bottom` is null.
    pub fn borrowed(bottom: VirtAddr, size: usize, align: usize) -> Self {
        assert_ne!(bottom.as_usize(), 0, "static task stack pointer is null");
        let stack = Self {
            ptr: bottom.as_usize(),
            size,
            align,
            owned: false,
        };
        // SAFETY: relies on the caller contract documented above.
        #[cfg(feature = "stack-canary")]
        unsafe {
            stack.write_canary()
        };
        stack
    }

    /// Lowest address of the stack.
    #[cfg(feature = "stack-canary")]
    #[inline]
    pub fn bottom(&self) -> VirtAddr {
        VirtAddr::from(self.ptr)
    }

    /// One-past-the-end address of the stack, i.e. its initial stack pointer.
    #[inline]
    pub fn top(&self) -> VirtAddr {
        VirtAddr::from(self.ptr + self.size)
    }

    /// Address of the canary word (the lowest word of the stack).
    #[inline]
    #[cfg(feature = "stack-canary")]
    fn canary_ptr(&self) -> *mut usize {
        self.ptr as *mut usize
    }

    /// Stores `STACK_END_MAGIC` at the bottom of the stack.
    ///
    /// # Safety
    ///
    /// The stack must be writable, word-aligned and at least one word long.
    #[inline]
    #[cfg(feature = "stack-canary")]
    unsafe fn write_canary(&self) {
        unsafe { self.canary_ptr().write(STACK_END_MAGIC) };
    }

    /// Whether the canary word still holds `STACK_END_MAGIC`.
    #[inline]
    #[cfg(feature = "stack-canary")]
    pub fn is_canary_intact(&self) -> bool {
        unsafe { self.canary_ptr().read() == STACK_END_MAGIC }
    }

    /// Test helper: clobbers the canary word so corruption can be observed.
    #[cfg(test)]
    #[cfg(feature = "stack-canary")]
    fn corrupt_canary_for_test(&self) {
        unsafe { self.canary_ptr().write(0) };
    }
}
impl Drop for TaskStack {
    /// Returns an owned stack to the global allocator. Borrowed stacks are
    /// left untouched.
    fn drop(&mut self) {
        if !self.owned {
            return;
        }
        // Rebuild the exact layout `TaskStack::alloc` used for this block.
        let layout = Layout::from_size_align(self.size, self.align).unwrap();
        // SAFETY: owned stacks are only produced by `alloc`, which allocated
        // `self.ptr` with this same size and alignment.
        unsafe { alloc::alloc::dealloc(self.ptr as *mut u8, layout) };
    }
}
#[cfg(test)]
mod stack_tests {
    use super::{TASK_STACK_ALIGN, TaskStack};

    /// A fresh stack carries an intact canary, and clobbering that word is
    /// reported as corruption.
    #[cfg(feature = "stack-canary")]
    #[test]
    fn task_stack_canary_detects_corruption() {
        let kstack = TaskStack::alloc(0x1000);
        assert!(kstack.is_canary_intact());

        kstack.corrupt_canary_for_test();
        assert!(!kstack.is_canary_intact());
    }

    /// The stack top is a multiple of `TASK_STACK_ALIGN` on x86_64.
    #[cfg(feature = "stack-canary")]
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn task_stack_top_stays_16_byte_aligned() {
        let kstack = TaskStack::alloc(0x1000);
        let top = kstack.top().as_usize();
        assert_eq!(top % TASK_STACK_ALIGN, 0);
    }
}
/// Handle to the task running on the current CPU.
///
/// It wraps the `Arc` whose raw pointer is stored in the per-CPU current-task
/// slot. `ManuallyDrop` ensures that dropping the handle does not decrement
/// the reference count owned by that slot.
pub struct CurrentTask(ManuallyDrop<AxTaskRef>);
impl CurrentTask {
pub(crate) fn try_get() -> Option<Self> {
let ptr: *const super::AxTask = ax_hal::percpu::current_task_ptr();
if !ptr.is_null() {
Some(Self(unsafe { ManuallyDrop::new(AxTaskRef::from_raw(ptr)) }))
} else {
None
}
}
pub(crate) fn get() -> Self {
Self::try_get().expect("current task is uninitialized")
}
#[allow(clippy::should_implement_trait)]
pub fn clone(&self) -> AxTaskRef {
self.0.deref().clone()
}
pub fn ptr_eq(&self, other: &AxTaskRef) -> bool {
Arc::ptr_eq(&self.0, other)
}
pub(crate) unsafe fn init_current(init_task: AxTaskRef) {
assert!(init_task.is_init());
#[cfg(feature = "tls")]
unsafe {
ax_hal::asm::write_thread_pointer(init_task.tls.tls_ptr() as usize)
};
let ptr = Arc::into_raw(init_task);
unsafe {
ax_hal::percpu::set_current_task_ptr(ptr);
}
}
pub(crate) unsafe fn set_current(prev: Self, next: AxTaskRef) {
let Self(arc) = prev;
ManuallyDrop::into_inner(arc); let ptr = Arc::into_raw(next);
unsafe {
ax_hal::percpu::set_current_task_ptr(ptr);
}
}
}
impl Deref for CurrentTask {
    type Target = AxTaskRef;

    /// Borrows the shared task handle without touching its reference count.
    fn deref(&self) -> &AxTaskRef {
        self.0.deref()
    }
}
/// First code executed by every task created with `TaskInner::new`; its
/// address is the entry point installed in the task's saved context.
///
/// The function first performs the post-switch bookkeeping (`smp`) and
/// enables IRQs (`irq`). It then runs the task's entry closure once and exits
/// the task with code 0.
extern "C" fn task_entry() -> ! {
    // SAFETY: NOTE(review): presumably this is only valid right after a
    // context switch into a new task, which is the only way to get here.
    #[cfg(feature = "smp")]
    unsafe {
        crate::run_queue::clear_prev_task_on_cpu();
    }
    #[cfg(feature = "irq")]
    ax_hal::asm::enable_irqs();

    let curr = crate::current();
    if let Some(body) = curr.entry.take() {
        body();
    }
    crate::exit(0);
}