use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
use crate::loom::sync::atomic::Ordering;
use crate::runtime::scheduler;
use crate::sync::AtomicWaker;
use crate::time::Instant;
use crate::util::linked_list;
use pin_project_lite::pin_project;
use std::task::{Context, Poll, Waker};
use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};
/// Outcome stored in a [`StateCell`] once its timer has completed.
type TimerResult = Result<(), crate::time::error::Error>;

// The `state` word of a `StateCell` doubles as the expiration tick: any value
// strictly below `STATE_MIN_VALUE` is a deadline, and the two largest `u64`
// values are reserved as sentinels.

/// Sentinel: the entry is not in the timer wheel (freshly created, or fired).
pub(in crate::runtime::time) const STATE_DEREGISTERED: u64 = u64::MAX;

/// Sentinel written by `mark_pending`; `fire` later replaces it with
/// `STATE_DEREGISTERED`.
const STATE_PENDING_FIRE: u64 = u64::MAX - 1;

/// Smallest reserved sentinel; every valid deadline tick is below this.
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;

/// Largest tick that can be stored as a deadline without colliding with a
/// sentinel (units presumably milliseconds, per the name — confirm).
pub(super) const MAX_SAFE_MILLIS_DURATION: u64 = STATE_MIN_VALUE - 1;
/// Atomic state shared between a timer entry and the time driver.
///
/// `state` holds either the expiration tick or one of the `STATE_*`
/// sentinels. `result` is read by `read_state` whenever `state` is
/// `STATE_DEREGISTERED`, including before the first registration, when it is
/// still the initial `Ok(())`. `waker` is registered by `poll` and handed back
/// (not woken) by `fire`.
pub(super) struct StateCell {
    // Deadline tick, or `STATE_PENDING_FIRE` / `STATE_DEREGISTERED`.
    state: AtomicU64,
    // Written by `fire` before its Release store of `STATE_DEREGISTERED`;
    // read by `read_state` after an Acquire load observes that value.
    result: UnsafeCell<TimerResult>,
    // Waker registered in `poll`, taken in `fire`.
    waker: AtomicWaker,
}
impl Default for StateCell {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for StateCell {
    /// Renders the cell as `StateCell(<poll snapshot>)`, for example
    /// `StateCell(Pending)` while the timer is still armed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshot = self.read_state();
        write!(f, "StateCell({:?})", snapshot)
    }
}
impl StateCell {
    /// Creates a cell in the `STATE_DEREGISTERED` state, with a placeholder
    /// `Ok(())` result and no registered waker.
    fn new() -> Self {
        Self {
            state: AtomicU64::new(STATE_DEREGISTERED),
            result: UnsafeCell::new(Ok(())),
            waker: AtomicWaker::new(),
        }
    }

    /// Returns `true` if the state is currently `STATE_PENDING_FIRE`, i.e.
    /// `mark_pending` succeeded and `fire` has not run yet.
    fn is_pending(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
    }

    /// Returns the raw state unless it is `STATE_DEREGISTERED`.
    ///
    /// `STATE_PENDING_FIRE` is also returned as `Some(..)`. Callers therefore
    /// receive either a deadline tick or the pending-fire sentinel.
    fn when(&self) -> Option<u64> {
        let cur_state = self.state.load(Ordering::Relaxed);

        if cur_state == STATE_DEREGISTERED {
            None
        } else {
            Some(cur_state)
        }
    }

    /// Registers `waker` (so that `fire` can hand it back) and then reports
    /// the current state via `read_state`.
    ///
    /// The waker is registered before the state is read. This is presumably
    /// so that a concurrent `fire`, which stores the state and then takes the
    /// waker, cannot be missed.
    fn poll(&self, waker: &Waker) -> Poll<TimerResult> {
        self.waker.register_by_ref(waker);

        self.read_state()
    }

    /// Returns `Ready(result)` if the state is `STATE_DEREGISTERED`, and
    /// `Pending` otherwise.
    ///
    /// A freshly created cell is also `STATE_DEREGISTERED`, so this returns
    /// `Ready(Ok(()))` before the timer has ever been registered.
    fn read_state(&self) -> Poll<TimerResult> {
        // Acquire pairs with the Release store in `fire`, so the write to
        // `result` made there is visible to the read below.
        let cur_state = self.state.load(Ordering::Acquire);

        if cur_state == STATE_DEREGISTERED {
            // SAFETY: the Acquire load observed `STATE_DEREGISTERED`, and
            // `fire` writes `result` before its Release store of that value.
            // `fire` also skips the write when it sees `STATE_DEREGISTERED`.
            // NOTE(review): soundness further relies on the entry not being
            // re-armed and fired while this read is in progress — confirm
            // against the driver's locking.
            Poll::Ready(unsafe { self.result.with(|p| *p) })
        } else {
            Poll::Pending
        }
    }

    /// Attempts to move the state to `STATE_PENDING_FIRE` if the stored
    /// deadline is at or before `not_after`.
    ///
    /// Returns `Ok(())` on success. If the stored deadline is later than
    /// `not_after`, returns `Err(tick)` with that deadline. The deadline may
    /// have been moved later concurrently by `extend_expiration`.
    ///
    /// # Panics
    ///
    /// Panics if the state is already a sentinel (`STATE_PENDING_FIRE` or
    /// `STATE_DEREGISTERED`).
    ///
    /// # Safety
    ///
    /// The body itself performs no unsafe operation; `unsafe` marks a caller
    /// contract. NOTE(review): presumably the caller must hold the driver
    /// lock — confirm against the driver.
    unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
        let mut cur_state = self.state.load(Ordering::Relaxed);

        loop {
            // Sentinels are not deadlines; reaching here with one is a bug.
            assert!(
                cur_state < STATE_MIN_VALUE,
                "mark_pending called when the timer entry is in an invalid state"
            );

            // The deadline is later than the caller's cut-off: report it.
            if cur_state > not_after {
                break Err(cur_state);
            }

            // `_weak` may fail spuriously. On any failure, retry with the
            // value actually observed.
            match self.state.compare_exchange_weak(
                cur_state,
                STATE_PENDING_FIRE,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break Ok(()),
                Err(actual_state) => cur_state = actual_state,
            }
        }
    }

    /// Completes the timer with `result` and takes the registered waker.
    ///
    /// If the state is already `STATE_DEREGISTERED`, returns `None` without
    /// touching `result`. Otherwise it writes `result`, publishes
    /// `STATE_DEREGISTERED` with Release ordering, and returns whatever
    /// `take_waker` yields. The waker is returned to the caller, not woken.
    ///
    /// # Safety
    ///
    /// No other thread may access `result` concurrently. NOTE(review):
    /// presumably guaranteed by the caller holding the driver lock — confirm.
    unsafe fn fire(&self, result: TimerResult) -> Option<Waker> {
        // Already completed: nothing to write, no waker to hand back.
        let cur_state = self.state.load(Ordering::Relaxed);
        if cur_state == STATE_DEREGISTERED {
            return None;
        }

        // SAFETY: exclusive access to `result` is part of this function's
        // contract. The Release store below publishes the write before
        // `read_state` can observe `STATE_DEREGISTERED`.
        unsafe { self.result.with_mut(|p| *p = result) };

        self.state.store(STATE_DEREGISTERED, Ordering::Release);

        self.waker.take_waker()
    }

    /// Unconditionally stores `timestamp` as the deadline (Relaxed).
    ///
    /// Unlike `extend_expiration`, this overwrites any sentinel. Debug builds
    /// assert that `timestamp` is not itself a sentinel value.
    fn set_expiration(&self, timestamp: u64) {
        debug_assert!(timestamp < STATE_MIN_VALUE);

        self.state.store(timestamp, Ordering::Relaxed);
    }

    /// Moves the deadline to `new_timestamp` if that is at or after the
    /// stored deadline.
    ///
    /// Fails with `Err(())` in two cases:
    /// - `new_timestamp` is earlier than the stored deadline.
    /// - The state is a sentinel (pending fire or deregistered).
    ///
    /// On failure, `TimerEntry::reset` falls back to `reregister` when asked
    /// to reregister.
    fn extend_expiration(&self, new_timestamp: u64) -> Result<(), ()> {
        let mut prior = self.state.load(Ordering::Relaxed);
        loop {
            if new_timestamp < prior || prior >= STATE_MIN_VALUE {
                return Err(());
            }

            // Retry on spurious failure or a concurrent change, re-checking
            // the guard above against the freshly observed value.
            match self.state.compare_exchange_weak(
                prior,
                new_timestamp,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return Ok(()),
                Err(true_prior) => prior = true_prior,
            }
        }
    }

    /// Returns `true` unless the state is `STATE_DEREGISTERED`.
    ///
    /// An entry in `STATE_PENDING_FIRE` still counts as possibly registered.
    pub(super) fn might_be_registered(&self) -> bool {
        self.state.load(Ordering::Relaxed) != STATE_DEREGISTERED
    }
}
pin_project! {
    // Handle to a timer, held by whoever requested the timer.
    //
    // `inner` is created lazily by `reset`, which `poll_elapsed` calls on the
    // first poll. It is structurally pinned because `TimerShared` is `!Unpin`
    // and is linked by address into the driver's `EntryList`s.
    #[derive(Debug)]
    pub(crate) struct TimerEntry {
        // Runtime handle, used to reach the time driver.
        driver: scheduler::Handle,
        // Shared state linked into the driver; `None` until first `reset`.
        #[pin]
        inner: Option<TimerShared>,
        // Most recently requested deadline; registered on the first poll.
        deadline: Instant,
        // Whether the last `reset` asked for (re)registration.
        registered: bool,
    }

    impl PinnedDrop for TimerEntry {
        // Detaches the entry from the driver (if `inner` was ever initialized)
        // before the pinned memory is released.
        fn drop(this: Pin<&mut Self>) {
            this.cancel();
        }
    }
}
// SAFETY: NOTE(review): the justification is not visible in this file
// section. `TimerShared` already has manual `Send`/`Sync` impls, so these may
// be redundant if `scheduler::Handle` is itself `Send + Sync`. Confirm before
// relying on them or adding non-thread-safe fields to `TimerEntry`.
unsafe impl Send for TimerEntry {}
unsafe impl Sync for TimerEntry {}
/// Raw, non-owning pointer to a [`TimerShared`]; the element handle type of
/// `EntryList`.
///
/// Nothing in this type keeps the pointee alive. Every accessor is therefore
/// `unsafe` and relies on the caller guaranteeing the entry is still live.
#[derive(Debug)]
pub(crate) struct TimerHandle {
    inner: NonNull<TimerShared>,
}
pub(super) type EntryList = crate::util::linked_list::LinkedList<TimerShared, TimerShared>;
/// State shared between a [`TimerEntry`] and the time driver.
///
/// Marked `!Unpin` via `PhantomPinned` because it is referred to by address,
/// both through `TimerHandle` and through the intrusive `pointers` links.
pub(crate) struct TimerShared {
    // Intrusive list links, reached through `Link::pointers`.
    pointers: linked_list::Pointers<TimerShared>,
    // Tick the entry was last filed under, written by `sync_when`,
    // `set_expiration` and `TimerHandle::mark_pending` (which stores
    // `STATE_DEREGISTERED` on success). It can lag `state`, because
    // `extend_expiration` moves the deadline later without touching this.
    registered_when: AtomicU64,
    // Current deadline (`true_when`) and completion result.
    state: StateCell,
    _p: PhantomPinned,
}
// SAFETY: NOTE(review): these impls assert that all cross-thread access to
// the non-atomic parts (`pointers`, and `StateCell::result` behind its
// `UnsafeCell`) is externally synchronized. That synchronization is not
// visible in this file section; presumably it is the driver lock. Confirm
// before changing field types.
unsafe impl Send for TimerShared {}
unsafe impl Sync for TimerShared {}
impl std::fmt::Debug for TimerShared {
    /// Shows the last registered tick and the current `StateCell` snapshot.
    /// The list pointers are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered_when = self.registered_when.load(Ordering::Relaxed);

        let mut out = f.debug_struct("TimerShared");
        out.field("registered_when", &registered_when);
        out.field("state", &self.state);
        out.finish()
    }
}
// Generates `TimerShared::addr_of_pointers`, which projects a
// `NonNull<TimerShared>` to a pointer to its `pointers` field.
// NOTE(review): the macro is defined elsewhere. Presumably it computes the
// field address without materializing a `&TimerShared` — confirm in its
// definition.
generate_addr_of_methods! {
    impl<> TimerShared {
        unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<TimerShared>> {
            &self.pointers
        }
    }
}
impl TimerShared {
    /// Creates an unlinked, deregistered entry whose `registered_when` is 0.
    pub(super) fn new() -> Self {
        let registered_when = AtomicU64::new(0);
        let pointers = linked_list::Pointers::new();
        let state = StateCell::default();

        Self {
            pointers,
            registered_when,
            state,
            _p: PhantomPinned,
        }
    }

    /// Returns the tick this entry was last filed under (Relaxed load).
    pub(super) fn registered_when(&self) -> u64 {
        self.registered_when.load(Ordering::Relaxed)
    }

    /// Copies the current deadline from `state` into `registered_when` and
    /// returns it.
    ///
    /// # Panics
    ///
    /// Panics with "Timer already fired" if the state is `STATE_DEREGISTERED`.
    ///
    /// # Safety
    ///
    /// The body performs no unsafe operation. NOTE(review): presumably only
    /// driver-side code, under its lock, may call this — confirm.
    pub(super) unsafe fn sync_when(&self) -> u64 {
        let current = self.true_when();
        self.registered_when.store(current, Ordering::Relaxed);
        current
    }

    /// Overwrites `registered_when` with `when` (Relaxed store).
    ///
    /// # Safety
    ///
    /// NOTE(review): caller contract not visible here; presumably driver-only.
    unsafe fn set_registered_when(&self, when: u64) {
        self.registered_when.store(when, Ordering::Relaxed);
    }

    /// Returns the deadline (or the pending-fire sentinel) held in `state`.
    ///
    /// # Panics
    ///
    /// Panics with "Timer already fired" if the state is `STATE_DEREGISTERED`.
    pub(super) fn true_when(&self) -> u64 {
        match self.state.when() {
            Some(tick) => tick,
            None => panic!("Timer already fired"),
        }
    }

    /// Sets the deadline to `t`, then records `t` in `registered_when`.
    ///
    /// # Safety
    ///
    /// NOTE(review): caller contract not visible here; presumably driver-only.
    pub(super) unsafe fn set_expiration(&self, t: u64) {
        // Update the state word first, then mirror it into the cached tick.
        self.state.set_expiration(t);
        self.registered_when.store(t, Ordering::Relaxed);
    }

    /// Moves the deadline later without re-registering; see
    /// `StateCell::extend_expiration` for the failure conditions.
    pub(super) fn extend_expiration(&self, t: u64) -> Result<(), ()> {
        self.state.extend_expiration(t)
    }

    /// Returns a raw, non-owning handle pointing at this entry.
    pub(super) fn handle(&self) -> TimerHandle {
        let inner = NonNull::from(self);
        TimerHandle { inner }
    }

    /// Returns `true` unless the entry is known to be deregistered.
    pub(super) fn might_be_registered(&self) -> bool {
        self.state.might_be_registered()
    }
}
// SAFETY: NOTE(review): the `Link` contract is defined in `linked_list`.
// This impl relies on `TimerShared` being `!Unpin` (via `PhantomPinned`),
// so the address handed out through `TimerHandle` stays stable while the
// entry is linked. Confirm against the trait's documented requirements.
unsafe impl linked_list::Link for TimerShared {
    type Handle = TimerHandle;
    type Target = TimerShared;

    /// Extracts the raw entry pointer carried by a handle.
    fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> {
        let TimerHandle { inner } = handle;
        *inner
    }

    /// Rewraps a raw entry pointer as a handle.
    unsafe fn from_raw(raw: NonNull<Self::Target>) -> Self::Handle {
        TimerHandle { inner: raw }
    }

    /// Returns a pointer to the entry's intrusive `pointers` field.
    unsafe fn pointers(
        target: NonNull<Self::Target>,
    ) -> NonNull<linked_list::Pointers<Self::Target>> {
        // SAFETY: the caller's obligations under `Link::pointers` are
        // forwarded unchanged to `addr_of_pointers`.
        unsafe { TimerShared::addr_of_pointers(target) }
    }
}
impl TimerEntry {
    /// Creates an entry for `deadline` without touching the driver.
    ///
    /// No `TimerShared` is created and nothing is registered until `reset`
    /// runs; `poll_elapsed` calls it on the first poll.
    ///
    /// # Panics
    ///
    /// NOTE(review): `handle.driver().time()` is evaluated only for its side
    /// effect. Presumably it panics when the time driver is disabled (hence
    /// `#[track_caller]`) — confirm in `driver().time()`.
    #[track_caller]
    pub(crate) fn new(handle: scheduler::Handle, deadline: Instant) -> Self {
        // Evaluated only for the check performed by `time()`.
        let _ = handle.driver().time();

        Self {
            driver: handle,
            inner: None,
            deadline,
            registered: false,
        }
    }

    /// Returns the shared state, or `None` if it has not been initialized.
    fn inner(&self) -> Option<&TimerShared> {
        self.inner.as_ref()
    }

    /// Initializes `inner` if it is still `None`; does nothing otherwise.
    ///
    /// Writes through the pin projection with `Pin::set`. The new
    /// `TimerShared` has not been handed to the driver yet, so moving it into
    /// place is fine.
    fn init_inner(self: Pin<&mut Self>) {
        match self.inner {
            Some(_) => {}
            None => self.project().inner.set(Some(TimerShared::new())),
        }
    }

    /// Returns the most recently requested deadline.
    pub(crate) fn deadline(&self) -> Instant {
        self.deadline
    }

    /// Returns `true` if the last `reset` requested registration and the
    /// shared state has since become `STATE_DEREGISTERED` (fired or removed).
    ///
    /// Returns `false` if the shared state was never initialized.
    pub(crate) fn is_elapsed(&self) -> bool {
        let Some(inner) = self.inner() else {
            return false;
        };

        let deregistered = !inner.might_be_registered();

        // A fresh `TimerShared` is also deregistered, so `registered` is what
        // distinguishes "fired" from "never armed".
        deregistered && self.registered
    }

    /// Removes this entry from the driver, if its shared state was ever
    /// initialized. Called from `PinnedDrop`.
    pub(crate) fn cancel(self: Pin<&mut Self>) {
        // `inner` is only created in `reset`, before any `reregister`, so an
        // uninitialized entry has never been seen by the driver.
        let Some(inner) = self.inner() else {
            return;
        };

        // SAFETY: NOTE(review): `inner` is pinned inside `self` and is alive
        // for this call. The remaining requirements of `clear_entry` are
        // defined in the driver — confirm there.
        unsafe { self.driver().clear_entry(NonNull::from(inner)) };
    }

    /// Sets a new deadline and, if `reregister` is `true`, (re)registers the
    /// entry with the driver.
    ///
    /// Fast path: if the stored deadline is at or before the new tick, it is
    /// extended in place without involving the driver. Otherwise the entry
    /// is reregistered when `reregister` is `true`. When `reregister` is
    /// `false`, only `deadline`/`registered` change (plus lazy `inner`
    /// creation).
    pub(crate) fn reset(mut self: Pin<&mut Self>, new_time: Instant, reregister: bool) {
        // Record the request; `poll_elapsed` consults `registered` and
        // `deadline` on the next poll.
        let this = self.as_mut().project();
        *this.deadline = new_time;
        *this.registered = reregister;

        // Convert the deadline into the driver's tick units.
        let tick = self.driver().time_source().deadline_to_tick(new_time);
        // Lazily create the shared state on first use.
        let inner = match self.inner() {
            Some(inner) => inner,
            None => {
                self.as_mut().init_inner();
                self.inner()
                    .expect("inner should already be initialized by `this.init_inner()`")
            }
        };

        // Moving an armed deadline later needs no driver interaction.
        if inner.extend_expiration(tick).is_ok() {
            return;
        }

        if reregister {
            // SAFETY: NOTE(review): `inner` lives in pinned memory owned by
            // `self`, and `cancel` (run from `PinnedDrop`) removes it from
            // the driver before that memory is freed. The remaining
            // requirements of `reregister` are defined in the driver —
            // confirm there.
            unsafe {
                self.driver()
                    .reregister(&self.driver.driver().io, tick, inner.into());
            }
        }
    }

    /// Polls the timer, registering it with the driver on the first poll or
    /// after a `reset(.., false)`.
    ///
    /// # Panics
    ///
    /// Panics with `RUNTIME_SHUTTING_DOWN_ERROR` if the time driver reports
    /// that it is shut down.
    pub(crate) fn poll_elapsed(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), super::Error>> {
        assert!(
            !self.driver().is_shutdown(),
            "{}",
            crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR
        );

        // Registration is deferred to here because the entry must be pinned
        // before the driver can hold a pointer to it.
        if !self.registered {
            let deadline = self.deadline;
            self.as_mut().reset(deadline, true);
        }

        let inner = self
            .inner()
            .expect("inner should already be initialized by `self.reset()`");
        inner.state.poll(cx.waker())
    }

    /// Returns the time driver handle reached through the runtime handle.
    pub(crate) fn driver(&self) -> &super::Handle {
        self.driver.driver().time()
    }

    /// Returns the driver's clock (only with `tokio_unstable` + `tracing`).
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(crate) fn clock(&self) -> &super::Clock {
        self.driver.driver().clock()
    }
}
impl TimerHandle {
    /// Returns the tick the entry was last filed under.
    ///
    /// # Safety
    ///
    /// The pointed-to `TimerShared` must still be alive; the handle does not
    /// keep it alive.
    pub(super) unsafe fn registered_when(&self) -> u64 {
        // SAFETY: the caller guarantees the entry is alive.
        unsafe { self.inner.as_ref().registered_when() }
    }

    /// Copies the entry's current deadline into `registered_when` and
    /// returns it.
    ///
    /// # Panics
    ///
    /// Panics if the entry is already deregistered (see
    /// `TimerShared::true_when`).
    ///
    /// # Safety
    ///
    /// The entry must be alive, and the caller must uphold
    /// `TimerShared::sync_when`'s contract.
    pub(super) unsafe fn sync_when(&self) -> u64 {
        // SAFETY: forwarded from this function's contract.
        unsafe { self.inner.as_ref().sync_when() }
    }

    /// Returns `true` if the entry has been claimed by `mark_pending` and not
    /// yet fired.
    ///
    /// # Safety
    ///
    /// The pointed-to `TimerShared` must still be alive.
    pub(super) unsafe fn is_pending(&self) -> bool {
        // SAFETY: the caller guarantees the entry is alive.
        unsafe { self.inner.as_ref().state.is_pending() }
    }

    /// Stores `tick` as both the entry's deadline and its `registered_when`.
    ///
    /// # Safety
    ///
    /// The entry must be alive, and the caller must uphold
    /// `TimerShared::set_expiration`'s contract.
    pub(super) unsafe fn set_expiration(&self, tick: u64) {
        // SAFETY: forwarded from this function's contract.
        unsafe { self.inner.as_ref().set_expiration(tick) };
    }

    /// Tries to claim the entry for firing if its deadline is at or before
    /// `not_after`.
    ///
    /// On success, `registered_when` is set to the `STATE_DEREGISTERED`
    /// sentinel and `Ok(())` is returned. Otherwise `registered_when` is
    /// refreshed with the entry's actual, later deadline, which is also
    /// returned as `Err(tick)`.
    ///
    /// # Panics
    ///
    /// Panics if the entry's state is already a sentinel (see
    /// `StateCell::mark_pending`).
    ///
    /// # Safety
    ///
    /// The entry must be alive, and the caller must uphold the contracts of
    /// `StateCell::mark_pending` and `TimerShared::set_registered_when`.
    /// NOTE(review): presumably this means holding the driver lock — confirm.
    pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
        // SAFETY: the caller guarantees the entry is alive for this call.
        let shared = unsafe { self.inner.as_ref() };

        // SAFETY: forwarded from this function's contract.
        match unsafe { shared.state.mark_pending(not_after) } {
            Ok(()) => {
                // SAFETY: forwarded from this function's contract.
                unsafe { shared.set_registered_when(STATE_DEREGISTERED) };
                Ok(())
            }
            Err(tick) => {
                // SAFETY: forwarded from this function's contract.
                unsafe { shared.set_registered_when(tick) };
                Err(tick)
            }
        }
    }

    /// Completes the entry with `completed_state`, consuming the handle.
    ///
    /// Returns the registered waker for the caller to wake. Returns `None` if
    /// the entry was already deregistered, or if `take_waker` yields none.
    ///
    /// # Safety
    ///
    /// The entry must be alive, and the caller must uphold
    /// `StateCell::fire`'s contract (no concurrent access to its result).
    pub(super) unsafe fn fire(self, completed_state: TimerResult) -> Option<Waker> {
        // SAFETY: forwarded from this function's contract.
        unsafe { self.inner.as_ref().state.fire(completed_state) }
    }
}