use std::cell::RefCell;
use std::future::Future;
use std::time::Duration;
use tracing::{debug, warn};
use super::slot_token::SlotToken;
use crate::dal::DAL;
use crate::database::universal_types::UniversalUuid;
use crate::error::ExecutorError;
// Task-local storage through which the executor hands a `TaskHandle` to task code.
tokio::task_local! {
    /// Slot holding the current task's `TaskHandle`.
    ///
    /// Installed by `with_task_handle` for the duration of the wrapped future.
    /// Task code moves the handle out with `take_task_handle` (leaving `None`)
    /// and puts it back with `return_task_handle`.
    static TASK_HANDLE_SLOT: RefCell<Option<TaskHandle>>;
}
/// Moves the current task's `TaskHandle` out of task-local storage.
///
/// The slot is left empty. Hand the handle back with `return_task_handle`;
/// otherwise `with_task_handle` reports `None` for it when the task finishes.
///
/// # Panics
///
/// Panics if the handle has already been taken. It also panics (via tokio's
/// `LocalKey::with`) when called outside a `with_task_handle` scope.
pub fn take_task_handle() -> TaskHandle {
    let taken = TASK_HANDLE_SLOT.with(|slot| slot.borrow_mut().take());
    taken.expect("TaskHandle not set in task-local storage — executor bug")
}
/// Puts a `TaskHandle` back into task-local storage.
///
/// Any handle already in the slot is replaced and dropped.
///
/// # Panics
///
/// Panics (via tokio's `LocalKey::with`) when called outside a
/// `with_task_handle` scope.
pub fn return_task_handle(handle: TaskHandle) {
    TASK_HANDLE_SLOT.with(move |slot| {
        slot.borrow_mut().replace(handle);
    });
}
/// Runs `f` with `handle` installed in task-local storage.
///
/// Inside `f`, task code can move the handle out with `take_task_handle` and
/// put it back with `return_task_handle`. The result pairs `f`'s output with
/// whatever is in the slot once `f` completes:
/// - `Some(handle)` if the handle was left in place or returned;
/// - `None` if it was taken and never put back.
pub async fn with_task_handle<F, T>(handle: TaskHandle, f: F) -> (T, Option<TaskHandle>)
where
    F: Future<Output = T>,
{
    let initial = RefCell::new(Some(handle));
    let scoped = async move {
        let output = f.await;
        // Empty the slot so the handle is moved out to the caller
        // instead of being dropped along with the scope.
        let leftover = TASK_HANDLE_SLOT.with(|cell| cell.take());
        (output, leftover)
    };
    TASK_HANDLE_SLOT.scope(initial, scoped).await
}
/// Per-task handle giving task code access to its executor concurrency slot
/// and to its cancellation signal.
///
/// Delivered to task code through task-local storage; see `with_task_handle`,
/// `take_task_handle` and `return_task_handle`.
pub struct TaskHandle {
    // Concurrency slot this task occupies. `defer_until` releases it while
    // waiting and reclaims it afterwards.
    slot_token: SlotToken,
    // Task execution this handle belongs to. Used in log fields and as the key
    // for `sub_status` updates.
    task_execution_id: UniversalUuid,
    // Data-access layer for best-effort `sub_status` updates. `None` for handles
    // built by the test-only constructor, in which case updates are skipped.
    dal: Option<DAL>,
    // Cancellation channel; a `true` value means cancellation was requested.
    // `None` means this handle can never observe cancellation.
    cancel_rx: Option<tokio::sync::watch::Receiver<bool>>,
}
impl TaskHandle {
#[cfg(test)]
fn new(slot_token: SlotToken, task_execution_id: UniversalUuid) -> Self {
Self {
slot_token,
task_execution_id,
dal: None,
cancel_rx: None,
}
}
pub(crate) fn with_dal_and_cancel(
slot_token: SlotToken,
task_execution_id: UniversalUuid,
dal: DAL,
cancel_rx: tokio::sync::watch::Receiver<bool>,
) -> Self {
Self {
slot_token,
task_execution_id,
dal: Some(dal),
cancel_rx: Some(cancel_rx),
}
}
pub async fn defer_until<F, Fut>(
&mut self,
condition: F,
poll_interval: Duration,
) -> Result<(), ExecutorError>
where
F: Fn() -> Fut,
Fut: Future<Output = bool>,
{
debug!(
task_execution_id = %self.task_execution_id,
poll_interval_ms = poll_interval.as_millis(),
"Task entering deferred state — releasing concurrency slot"
);
if let Some(ref dal) = self.dal {
if let Err(e) = dal
.task_execution()
.set_sub_status(self.task_execution_id, Some("Deferred"))
.await
{
warn!(
task_execution_id = %self.task_execution_id,
error = %e,
"Failed to set sub_status to Deferred"
);
}
}
self.slot_token.release();
loop {
tokio::time::sleep(poll_interval).await;
if condition().await {
break;
}
}
self.slot_token.reclaim().await?;
if let Some(ref dal) = self.dal {
if let Err(e) = dal
.task_execution()
.set_sub_status(self.task_execution_id, Some("Active"))
.await
{
warn!(
task_execution_id = %self.task_execution_id,
error = %e,
"Failed to set sub_status back to Active"
);
}
}
debug!(
task_execution_id = %self.task_execution_id,
"Task resumed — concurrency slot reclaimed"
);
Ok(())
}
pub fn task_execution_id(&self) -> UniversalUuid {
self.task_execution_id
}
pub fn is_slot_held(&self) -> bool {
self.slot_token.is_held()
}
pub fn is_cancelled(&self) -> bool {
self.cancel_rx
.as_ref()
.map(|rx| *rx.borrow())
.unwrap_or(false)
}
pub async fn cancelled(&self) {
match self.cancel_rx.as_ref() {
Some(rx) => {
let mut rx = rx.clone();
let _ = rx.wait_for(|&v| v).await;
}
None => std::future::pending::<()>().await,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    /// Builds a DAL-less, channel-less handle holding one permit from `semaphore`.
    fn make_handle(semaphore: &Arc<Semaphore>) -> TaskHandle {
        let permit = semaphore
            .clone()
            .try_acquire_owned()
            .expect("permit should be available");
        let slot_token = SlotToken::new(permit, semaphore.clone());
        TaskHandle::new(slot_token, UniversalUuid::new_v4())
    }

    #[tokio::test]
    async fn test_defer_until_releases_and_reclaims_slot() {
        let semaphore = Arc::new(Semaphore::new(1));
        let mut handle = make_handle(&semaphore);
        assert_eq!(semaphore.available_permits(), 0);
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = call_count.clone();
        handle
            .defer_until(
                move || {
                    let cc = cc.clone();
                    async move {
                        let count = cc.fetch_add(1, Ordering::SeqCst);
                        // false, false, then true on the third poll
                        count >= 2
                    }
                },
                Duration::from_millis(1),
            )
            .await
            .unwrap();
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert!(handle.is_slot_held());
        assert_eq!(semaphore.available_permits(), 0);
    }

    #[tokio::test]
    async fn test_defer_until_immediate_condition() {
        let semaphore = Arc::new(Semaphore::new(1));
        let mut handle = make_handle(&semaphore);
        handle
            .defer_until(|| async { true }, Duration::from_millis(1))
            .await
            .unwrap();
        assert!(handle.is_slot_held());
    }

    #[tokio::test]
    async fn test_defer_until_frees_slot_for_other_tasks() {
        let semaphore = Arc::new(Semaphore::new(1));
        let mut handle = make_handle(&semaphore);
        assert_eq!(semaphore.available_permits(), 0);
        let sem_clone = semaphore.clone();
        let slot_was_available = Arc::new(AtomicBool::new(false));
        let swa = slot_was_available.clone();
        handle
            .defer_until(
                move || {
                    let swa = swa.clone();
                    let sem = sem_clone.clone();
                    async move {
                        // While deferred, the permit must be back in the pool.
                        if sem.available_permits() > 0 {
                            swa.store(true, Ordering::SeqCst);
                        }
                        true
                    }
                },
                Duration::from_millis(1),
            )
            .await
            .unwrap();
        assert!(slot_was_available.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_task_local_round_trip() {
        let semaphore = Arc::new(Semaphore::new(1));
        let handle = make_handle(&semaphore);
        let original_id = handle.task_execution_id();
        let (result, returned_handle) = with_task_handle(handle, async {
            let taken = take_task_handle();
            assert_eq!(taken.task_execution_id(), original_id);
            assert!(taken.is_slot_held());
            return_task_handle(taken);
            42
        })
        .await;
        assert_eq!(result, 42);
        let rh = returned_handle.expect("handle should be returned");
        assert_eq!(rh.task_execution_id(), original_id);
        assert!(rh.is_slot_held());
    }

    #[tokio::test]
    async fn test_task_local_not_returned_yields_none() {
        let semaphore = Arc::new(Semaphore::new(1));
        let handle = make_handle(&semaphore);
        let (_result, returned_handle) = with_task_handle(handle, async {
            let _taken = take_task_handle();
        })
        .await;
        assert!(
            returned_handle.is_none(),
            "handle should be None when not returned"
        );
    }

    #[tokio::test]
    async fn test_is_cancelled_default_false_without_channel() {
        let semaphore = Arc::new(Semaphore::new(1));
        let handle = make_handle(&semaphore);
        assert!(!handle.is_cancelled());
        let cancelled_fires =
            tokio::time::timeout(Duration::from_millis(20), handle.cancelled()).await;
        assert!(
            cancelled_fires.is_err(),
            "cancelled() must never resolve without a cancel channel"
        );
    }

    #[tokio::test]
    async fn test_is_cancelled_reflects_watch_value() {
        let semaphore = Arc::new(Semaphore::new(1));
        let permit = semaphore
            .clone()
            .try_acquire_owned()
            .expect("permit available");
        let slot_token = SlotToken::new(permit, semaphore.clone());
        let (tx, rx) = tokio::sync::watch::channel(false);
        let handle = TaskHandle {
            slot_token,
            task_execution_id: UniversalUuid::new_v4(),
            dal: None,
            cancel_rx: Some(rx),
        };
        assert!(!handle.is_cancelled(), "no signal → not cancelled");
        tx.send(true).expect("send cancellation");
        assert!(handle.is_cancelled(), "after send(true) → cancelled");
    }

    #[tokio::test]
    async fn test_cancelled_future_resolves_after_signal() {
        let semaphore = Arc::new(Semaphore::new(1));
        let permit = semaphore
            .clone()
            .try_acquire_owned()
            .expect("permit available");
        let slot_token = SlotToken::new(permit, semaphore.clone());
        let (tx, rx) = tokio::sync::watch::channel(false);
        let handle = TaskHandle {
            slot_token,
            task_execution_id: UniversalUuid::new_v4(),
            dal: None,
            cancel_rx: Some(rx),
        };
        let early = tokio::time::timeout(Duration::from_millis(10), handle.cancelled()).await;
        assert!(early.is_err(), "cancelled() must not resolve before send");
        tx.send(true).expect("send cancellation");
        let after = tokio::time::timeout(Duration::from_millis(200), handle.cancelled()).await;
        assert!(after.is_ok(), "cancelled() must resolve after send(true)");
    }

    // Dropping the sender makes `cancelled()` resolve (wait_for errors out),
    // while `is_cancelled()` still reports `false`.
    #[tokio::test]
    async fn test_cancelled_future_resolves_when_sender_dropped() {
        let semaphore = Arc::new(Semaphore::new(1));
        let permit = semaphore
            .clone()
            .try_acquire_owned()
            .expect("permit available");
        let slot_token = SlotToken::new(permit, semaphore.clone());
        let (tx, rx) = tokio::sync::watch::channel(false);
        let handle = TaskHandle {
            slot_token,
            task_execution_id: UniversalUuid::new_v4(),
            dal: None,
            cancel_rx: Some(rx),
        };
        drop(tx);
        let elapsed = tokio::time::timeout(Duration::from_millis(20), handle.cancelled()).await;
        assert!(
            elapsed.is_ok(),
            "cancelled() resolves when sender drops (documented behavior)"
        );
        assert!(
            !handle.is_cancelled(),
            "is_cancelled() is the source of truth — stays false on sender drop"
        );
    }

    #[tokio::test]
    async fn test_with_task_handle_preserves_handle_through_defer() {
        let semaphore = Arc::new(Semaphore::new(1));
        let handle = make_handle(&semaphore);
        let original_id = handle.task_execution_id();
        let (_result, returned_handle) = with_task_handle(handle, async {
            let mut taken = take_task_handle();
            taken
                .defer_until(|| async { true }, Duration::from_millis(1))
                .await
                .unwrap();
            assert!(taken.is_slot_held());
            return_task_handle(taken);
        })
        .await;
        let rh = returned_handle.expect("handle should survive defer_until");
        assert_eq!(rh.task_execution_id(), original_id);
        assert!(rh.is_slot_held());
    }
}