#![allow(unsafe_code)]
use core::future::Future;
use core::pin::Pin;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use crate::memory_management::{FiberContext, FiberStatus};
/// Waker vtable shared by every fiber waker.
///
/// The waker's data pointer is the fiber's `FiberContext` pointer. No reference
/// count is kept: `clone` copies the pointer and `drop` does nothing.
static DTACT_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone_waker,      // clone
    wake_impl,        // wake (by value)
    wake_by_ref_impl, // wake_by_ref
    drop_waker,       // drop
);
#[inline(always)]
unsafe fn clone_waker(data: *const ()) -> RawWaker {
RawWaker::new(data, &DTACT_WAKER_VTABLE)
}
/// Vtable `wake`: consumes the waker and wakes its fiber.
///
/// These wakers own nothing (see `drop_waker`), so consuming one is exactly
/// the same as waking by reference.
///
/// # Safety
///
/// Same contract as `wake_by_ref_impl`: `ptr` must point at a live `FiberContext`.
#[inline(always)]
unsafe fn wake_impl(ptr: *const ()) {
    // SAFETY: forwarded unchanged; the caller upholds `wake_by_ref_impl`'s contract.
    unsafe { wake_by_ref_impl(ptr) };
}
/// Vtable `wake_by_ref`: marks the fiber `Notified` and, if it was parked in the
/// `Yielded` state, hands it back to its origin core via `crate::wake_fiber`.
///
/// Any other previous state (e.g. the fiber is still running or mid-suspend)
/// only leaves the `Notified` mark behind. `wait_pinned`'s
/// `Running -> Suspending` compare-exchange then fails and the future is re-polled.
///
/// # Safety
///
/// `data` must be a pointer to a live `FiberContext`, i.e. the pointer
/// `wait_pinned` installs as the waker data.
#[inline(always)]
unsafe fn wake_by_ref_impl(data: *const ()) {
    let ctx = data.cast::<FiberContext>();
    // SAFETY: the caller guarantees `ctx` is live. Only the atomic `state` field is
    // borrowed. A reference to the whole context would alias the owning fiber's
    // concurrent writes to its non-atomic fields, since this may run on another thread.
    let prev = unsafe { (*ctx).state.swap(FiberStatus::Notified as u32, Ordering::AcqRel) };
    if prev == FiberStatus::Yielded as u32 {
        // SAFETY: as above. These fields are copied out by value. NOTE(review):
        // assumes they are not mutated while the fiber is alive — confirm.
        let (origin, index) = unsafe { ((*ctx).origin_core, (*ctx).fiber_index) };
        crate::wake_fiber(origin as usize, index);
    }
}
/// Vtable `drop`: releases nothing.
///
/// Wakers carry a bare `FiberContext` pointer with no ownership or reference
/// count, so there is no resource to free when one is dropped.
#[inline(always)]
const unsafe fn drop_waker(_: *const ()) {}
/// Suspends the current fiber by invoking its `switch_fn`. The fiber's register
/// state is saved into `regs`, and execution continues from `executor_regs`.
///
/// The call returns only once something switches back into `regs`.
///
/// # Safety
///
/// `ctx` must point at the live `FiberContext` of the fiber running on this
/// thread, and its `switch_fn` must be a valid context-switch routine for it.
#[inline(always)]
unsafe fn dtact_asm_fiber_suspend(ctx: *mut FiberContext) {
    // SAFETY: upheld by the caller. Raw field pointers are taken so that no
    // reference to the context is live across the context switch.
    unsafe {
        let save_into = &raw mut (*ctx).regs;
        let resume_from = &raw const (*ctx).executor_regs;
        ((*ctx).switch_fn)(save_into, resume_from);
    }
}
thread_local! {
    /// Context of the fiber executing on this OS thread, or null when none is.
    /// `wait`/`wait_pinned` treat null as "called outside a fiber".
    pub(crate) static CURRENT_FIBER: core::cell::Cell<*mut FiberContext> =
        const { core::cell::Cell::new(core::ptr::null_mut()) };

    /// Worker identifier for this OS thread. `usize::MAX` is the initial,
    /// presumably "unassigned", value.
    pub(crate) static CURRENT_WORKER_ID: core::cell::Cell<usize> =
        const { core::cell::Cell::new(usize::MAX) };
}
/// Drives `fut` to completion from inside the current fiber. While the future
/// is pending, the fiber suspends instead of blocking the OS thread (see
/// `wait_pinned`).
///
/// The first call for a fiber records the calling OS thread id in
/// `last_os_thread_id` (0 means "not yet recorded"). Every later call verifies
/// that the fiber is still on that thread.
///
/// # Panics
///
/// - If called outside a fiber execution context (`CURRENT_FIBER` is null).
/// - If the fiber is observed on a different OS thread than the one recorded
///   by an earlier call.
#[inline(always)]
pub fn wait<F: Future>(fut: F) -> F::Output {
    let ctx_ptr = CURRENT_FIBER.with(std::cell::Cell::get);
    assert!(
        !ctx_ptr.is_null(),
        "dtact::wait() invoked outside of a DTA-V3 Fiber Execution Context. Thread migration forbidden."
    );
    {
        // SAFETY: `ctx_ptr` is non-null (asserted above). Assumes the executor keeps
        // CURRENT_FIBER pointing at the live context of the fiber running on this
        // thread. The exclusive borrow is confined to this block, so it has ended
        // before the future is polled.
        let ctx = unsafe { &mut *ctx_ptr };
        let tid = crate::utils::get_thread_id();
        if ctx.last_os_thread_id == 0 {
            ctx.last_os_thread_id = tid;
        } else if ctx.last_os_thread_id != tid {
            panic!(
                "DTA-V3 Critical: Illegal OS Thread Migration detected for Fiber {}. Stack-pinned invariants violated.",
                ctx.fiber_index
            );
        }
    }
    // `pin!` pins the future to this stack frame without any `unsafe`.
    let pinned = core::pin::pin!(fut);
    wait_pinned(pinned)
}
#[doc(hidden)]
#[inline(always)]
pub fn wait_pinned<F: Future>(mut fut_pinned: Pin<&mut F>) -> F::Output {
let ctx_ptr = CURRENT_FIBER.with(std::cell::Cell::get);
let ctx = unsafe { &mut *ctx_ptr };
let raw_waker = RawWaker::new(ctx_ptr as *const (), &DTACT_WAKER_VTABLE);
let waker = unsafe { Waker::from_raw(raw_waker) };
let mut cx = Context::from_waker(&waker);
loop {
let _ = ctx
.state
.swap(FiberStatus::Running as u32, Ordering::AcqRel);
match fut_pinned.as_mut().poll(&mut cx) {
Poll::Ready(output) => {
ctx.adaptive_spin_count = (ctx.adaptive_spin_count + 1).min(200);
ctx.spin_failure_count = ctx.spin_failure_count.saturating_sub(1);
return output;
}
Poll::Pending => {
let current_spin = ctx.adaptive_spin_count;
let failure_count = ctx.spin_failure_count;
if failure_count < 10 {
for i in 0..current_spin {
core::hint::spin_loop();
if (i.trailing_zeros() >= 3 || i == current_spin - 1)
&& let Poll::Ready(output) = fut_pinned.as_mut().poll(&mut cx)
{
ctx.adaptive_spin_count = (current_spin + 2).min(200);
ctx.spin_failure_count = failure_count.saturating_sub(1);
return output;
}
}
}
ctx.spin_failure_count = failure_count.saturating_add(1);
ctx.adaptive_spin_count = current_spin.saturating_sub(5).max(5);
if ctx
.state
.compare_exchange(
FiberStatus::Running as u32,
FiberStatus::Suspending as u32,
Ordering::Release,
Ordering::Acquire,
)
.is_ok()
{
unsafe { dtact_asm_fiber_suspend(ctx_ptr) };
}
}
}
}
}