use std::sync::{
Arc,
atomic::{
AtomicU8,
Ordering,
},
};
use crate::tokio_executor_service_state::TokioExecutorServiceState;
/// Initial state: the task has been submitted but not yet marked as started.
const TASK_STATE_QUEUED: u8 = 0;
/// The task has been marked as started via `TokioServiceTaskTracker::mark_started`.
const TASK_STATE_RUNNING: u8 = TASK_STATE_QUEUED + 1;
/// Terminal state: once stored, no further start/finish notifications are issued.
const TASK_STATE_FINISHED: u8 = TASK_STATE_RUNNING + 1;
/// Shared lifecycle record for a single task tracked by the executor service.
///
/// `task_state` only ever moves forward: `QUEUED -> RUNNING -> FINISHED`, or
/// directly `QUEUED -> FINISHED`. Every transition is an atomic CAS or swap.
/// As a result, the service state sees at most one "started" notification and
/// at most one "finished" notification per task.
struct TokioServiceTaskTracker {
    /// Service-wide state that is notified when this task starts or finishes.
    state: Arc<TokioExecutorServiceState>,
    /// Token passed to `remove_abort_handle` when the owning guard is dropped.
    /// NOTE(review): presumably used as an identity key for the task's abort
    /// handle — confirm against `TokioExecutorServiceState`.
    marker: Arc<()>,
    /// Current lifecycle state; always one of the `TASK_STATE_*` constants.
    task_state: AtomicU8,
}
impl TokioServiceTaskTracker {
    /// Builds a tracker for a task that starts out in the queued state.
    pub(crate) fn new(state: Arc<TokioExecutorServiceState>, marker: Arc<()>) -> Self {
        let task_state = AtomicU8::new(TASK_STATE_QUEUED);
        Self {
            state,
            marker,
            task_state,
        }
    }

    /// Borrows the marker this tracker was constructed with.
    #[inline]
    pub(crate) fn marker(&self) -> &Arc<()> {
        &self.marker
    }

    /// Attempts the `QUEUED -> RUNNING` transition.
    ///
    /// Returns `true` when the task is running after the call. That is either
    /// because this call won the transition, in which case `mark_task_started`
    /// is reported to the service state, or because the task was already
    /// running. Returns `false` for any other observed state, which in
    /// practice means the task was already finished.
    pub(crate) fn mark_started(&self) -> bool {
        let outcome = self.task_state.compare_exchange(
            TASK_STATE_QUEUED,
            TASK_STATE_RUNNING,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
        match outcome {
            Ok(_) => {
                // Only the CAS winner reports the start, so it is counted once.
                self.state.mark_task_started();
                true
            }
            // A repeated call on an already-running task is a silent success.
            Err(observed) => observed == TASK_STATE_RUNNING,
        }
    }

    /// Finishes the task only if it has not started yet (`QUEUED -> FINISHED`).
    ///
    /// On success, reports `finish_task(false)` to the service state and
    /// returns `true`. Returns `false` without side effects if the task was
    /// already running or finished.
    pub(crate) fn finish_queued(&self) -> bool {
        let claimed = self
            .task_state
            .compare_exchange(
                TASK_STATE_QUEUED,
                TASK_STATE_FINISHED,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok();
        if claimed {
            self.state.finish_task(false);
        }
        claimed
    }

    /// Unconditionally moves the task to the finished state.
    ///
    /// If this call performs the transition, it reports `finish_task` to the
    /// service state. The argument is `true` when the task had been marked
    /// running and `false` when it was still queued. If the task was already
    /// finished, nothing is reported.
    pub(crate) fn finish(&self) {
        let previous = self.task_state.swap(TASK_STATE_FINISHED, Ordering::AcqRel);
        let was_running = match previous {
            TASK_STATE_QUEUED => false,
            TASK_STATE_RUNNING => true,
            // Already finished: whichever call got there first already reported it.
            _ => return,
        };
        self.state.finish_task(was_running);
    }
}
/// RAII handle for a task tracked by the executor service.
///
/// Dropping the guard does two things, in this order. First, it removes the
/// task's abort handle from the service state. Second, it moves the task to
/// the finished state if nothing has done so already.
///
/// Callbacks created from the guard hold their own reference to the same
/// tracker, so they remain valid after the guard is dropped.
pub(crate) struct TokioServiceTaskGuard {
    /// Tracker shared with any callbacks produced by this guard.
    tracker: Arc<TokioServiceTaskTracker>,
}
impl TokioServiceTaskGuard {
    /// Creates a guard whose task starts out in the queued state.
    pub(crate) fn new(state: Arc<TokioExecutorServiceState>, marker: Arc<()>) -> Self {
        let tracker = TokioServiceTaskTracker::new(state, marker);
        Self {
            tracker: Arc::new(tracker),
        }
    }

    /// Forwards to the tracker's `QUEUED -> RUNNING` transition.
    ///
    /// Returns `true` if the task is (now) running and `false` if it has
    /// already been finished.
    pub(crate) fn mark_started(&self) -> bool {
        self.tracker.mark_started()
    }

    /// Single-use variant of [`Self::finish_queued_callback`].
    pub(crate) fn finish_queued_once_callback(&self) -> impl FnOnce() + Send + 'static {
        // Any `Fn` closure is also `FnOnce`, so the reusable callback serves here too.
        self.finish_queued_callback()
    }

    /// Returns a reusable closure that finishes the task only if it is still queued.
    ///
    /// The closure owns its own handle to the tracker. It can therefore be
    /// invoked from another thread, or after the guard has been dropped; once
    /// the task has left the queued state, invoking it has no effect.
    pub(crate) fn finish_queued_callback(&self) -> impl Fn() + Send + Sync + 'static {
        let shared = Arc::clone(&self.tracker);
        move || {
            // The result is deliberately ignored: losing the race just means
            // the task was already started or finished.
            let _ = shared.finish_queued();
        }
    }
}
impl Drop for TokioServiceTaskGuard {
    // Deregister the abort handle first, then settle the lifecycle state.
    // `finish` is a no-op if a callback already finished the still-queued task.
    fn drop(&mut self) {
        let tracker = &*self.tracker;
        tracker.state.remove_abort_handle(tracker.marker());
        tracker.finish();
    }
}