use crate::tracing_compat::trace;
use crate::types::{Outcome, TaskId};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A type-erased, heap-pinned future resolving to `Outcome<(), ()>`.
///
/// The wrapped future is required to be `Send`, so the wrapper itself is
/// `Send`. [`poll`](Self::poll) forwards to the inner future and, when a task
/// id is attached, emits `trace!` events before and after each poll.
pub struct StoredTask {
    /// The boxed, pinned future being driven.
    future: Pin<Box<dyn Future<Output = Outcome<(), ()>> + Send>>,
    /// Optional id; when `Some`, `poll` emits trace events tagged with it.
    task_id: Option<TaskId>,
    /// Number of times `poll` has been called on this wrapper.
    poll_count: u64,
    /// One-shot budget value reported by the next `poll`'s trace event and
    /// cleared (set back to `None`) by that poll.
    polls_remaining: Option<u32>,
}

impl StoredTask {
    /// Wraps `future` without a task id.
    ///
    /// No trace events are emitted by `poll` until a task id is attached via
    /// [`set_task_id`](Self::set_task_id).
    #[inline]
    pub fn new<F>(future: F) -> Self
    where
        F: Future<Output = Outcome<(), ()>> + Send + 'static,
    {
        Self {
            future: Box::pin(future),
            task_id: None,
            poll_count: 0,
            polls_remaining: None,
        }
    }

    /// Wraps `future` and tags it with `task_id` for trace output.
    #[inline]
    pub fn new_with_id<F>(future: F, task_id: TaskId) -> Self
    where
        F: Future<Output = Outcome<(), ()>> + Send + 'static,
    {
        Self {
            future: Box::pin(future),
            task_id: Some(task_id),
            poll_count: 0,
            polls_remaining: None,
        }
    }

    /// Attaches (or replaces) the task id used to tag trace events.
    #[inline]
    pub fn set_task_id(&mut self, task_id: TaskId) {
        self.task_id = Some(task_id);
    }

    /// Returns the task id attached to this task, if any.
    ///
    /// Provided for parity with `LocalStoredTask::task_id`.
    #[inline]
    #[must_use]
    pub fn task_id(&self) -> Option<TaskId> {
        self.task_id
    }

    /// Records the budget value to report in the next poll's trace event.
    ///
    /// The value is consumed by the next call to [`poll`](Self::poll). This
    /// wrapper only reports it and does not use it to limit polling.
    #[inline]
    pub fn set_polls_remaining(&mut self, remaining: u32) {
        self.polls_remaining = Some(remaining);
    }

    /// Polls the wrapped future once.
    ///
    /// Each call increments the poll counter. It also consumes any value
    /// stored by `set_polls_remaining`, which is reported as `0` when none was
    /// set, so "unset" and "zero" are indistinguishable in the trace output.
    /// When a task id is attached, a `"task poll started"` event is traced
    /// before polling. A `"task poll completed"` event (with `"Ready"` or
    /// `"Pending"`) is traced afterwards.
    ///
    /// Polling again after `Poll::Ready` is governed by the wrapped future's
    /// own contract; this wrapper does not guard against it.
    #[inline]
    // NOTE(review): presumably required because the compat `trace!` macro's
    // expansion trips this lint; confirm against `tracing_compat`.
    #[allow(clippy::used_underscore_binding)]
    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Outcome<(), ()>> {
        self.poll_count += 1;
        let poll_number = self.poll_count;
        // Consumed even when no task id is set, so the budget is one-shot.
        let budget_remaining = self.polls_remaining.take().unwrap_or(0);
        if let Some(task_id) = self.task_id {
            trace!(
                task_id = ?task_id,
                poll_number = poll_number,
                budget_remaining = budget_remaining,
                "task poll started"
            );
            // Keeps the bindings "used" if `trace!` expands to nothing
            // (assumed: tracing can be compiled out).
            let _ = (task_id, poll_number, budget_remaining);
        }
        let result = self.future.as_mut().poll(cx);
        if let Some(task_id) = self.task_id {
            let poll_result = match &result {
                Poll::Ready(_) => "Ready",
                Poll::Pending => "Pending",
            };
            trace!(
                task_id = ?task_id,
                poll_number = poll_number,
                poll_result = poll_result,
                "task poll completed"
            );
            let _ = (task_id, poll_number, poll_result);
        }
        result
    }

    /// Returns how many times [`poll`](Self::poll) has been called.
    #[inline]
    #[must_use]
    pub fn poll_count(&self) -> u64 {
        self.poll_count
    }
}

// The boxed future is opaque, so only the type name is printed.
impl std::fmt::Debug for StoredTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StoredTask").finish_non_exhaustive()
    }
}
/// A type-erased, heap-pinned future resolving to `Outcome<(), ()>` that is
/// NOT required to be `Send`.
///
/// This is the counterpart of `StoredTask` for futures holding `!Send` state.
/// [`poll`](Self::poll) forwards to the inner future and, when a task id is
/// attached, emits `trace!` events before and after each poll.
pub struct LocalStoredTask {
    /// The boxed, pinned future being driven (no `Send` bound).
    future: Pin<Box<dyn Future<Output = Outcome<(), ()>> + 'static>>,
    /// Optional id; when `Some`, `poll` emits trace events tagged with it.
    task_id: Option<TaskId>,
    /// Number of times `poll` has been called on this wrapper.
    poll_count: u64,
    /// One-shot budget value reported by the next `poll`'s trace event and
    /// cleared (set back to `None`) by that poll.
    polls_remaining: Option<u32>,
}

impl LocalStoredTask {
    /// Wraps `future` without a task id.
    ///
    /// No trace events are emitted by `poll` until a task id is attached via
    /// [`set_task_id`](Self::set_task_id).
    #[inline]
    pub fn new<F>(future: F) -> Self
    where
        F: Future<Output = Outcome<(), ()>> + 'static,
    {
        Self {
            future: Box::pin(future),
            task_id: None,
            poll_count: 0,
            polls_remaining: None,
        }
    }

    /// Wraps `future` and tags it with `task_id` for trace output.
    #[inline]
    pub fn new_with_id<F>(future: F, task_id: TaskId) -> Self
    where
        F: Future<Output = Outcome<(), ()>> + 'static,
    {
        Self {
            future: Box::pin(future),
            task_id: Some(task_id),
            poll_count: 0,
            polls_remaining: None,
        }
    }

    /// Attaches (or replaces) the task id used to tag trace events.
    #[inline]
    pub fn set_task_id(&mut self, task_id: TaskId) {
        self.task_id = Some(task_id);
    }

    /// Returns the task id attached to this task, if any.
    #[inline]
    #[must_use]
    pub fn task_id(&self) -> Option<TaskId> {
        self.task_id
    }

    /// Records the budget value to report in the next poll's trace event.
    ///
    /// The value is consumed by the next call to [`poll`](Self::poll). This
    /// wrapper only reports it and does not use it to limit polling.
    #[inline]
    pub fn set_polls_remaining(&mut self, remaining: u32) {
        self.polls_remaining = Some(remaining);
    }

    /// Polls the wrapped future once.
    ///
    /// Each call increments the poll counter. It also consumes any value
    /// stored by `set_polls_remaining`, which is reported as `0` when none was
    /// set. When a task id is attached, a `"local task poll started"` event is
    /// traced before polling. A `"local task poll completed"` event (with
    /// `"Ready"` or `"Pending"`) is traced afterwards.
    ///
    /// Polling again after `Poll::Ready` is governed by the wrapped future's
    /// own contract; this wrapper does not guard against it.
    #[inline]
    // NOTE(review): presumably required because the compat `trace!` macro's
    // expansion trips this lint; confirm against `tracing_compat`.
    #[allow(clippy::used_underscore_binding)]
    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Outcome<(), ()>> {
        self.poll_count += 1;
        let poll_number = self.poll_count;
        // Consumed even when no task id is set, so the budget is one-shot.
        let budget_remaining = self.polls_remaining.take().unwrap_or(0);
        if let Some(task_id) = self.task_id {
            trace!(
                task_id = ?task_id,
                poll_number = poll_number,
                budget_remaining = budget_remaining,
                "local task poll started"
            );
            // Keeps the bindings "used" if `trace!` expands to nothing
            // (assumed: tracing can be compiled out).
            let _ = (task_id, poll_number, budget_remaining);
        }
        let result = self.future.as_mut().poll(cx);
        if let Some(task_id) = self.task_id {
            let poll_result = match &result {
                Poll::Ready(_) => "Ready",
                Poll::Pending => "Pending",
            };
            trace!(
                task_id = ?task_id,
                poll_number = poll_number,
                poll_result = poll_result,
                "local task poll completed"
            );
            let _ = (task_id, poll_number, poll_result);
        }
        result
    }

    /// Returns how many times [`poll`](Self::poll) has been called.
    ///
    /// Provided for parity with `StoredTask::poll_count`.
    #[inline]
    #[must_use]
    pub fn poll_count(&self) -> u64 {
        self.poll_count
    }
}

// The boxed future is opaque, so only the type name is printed.
impl std::fmt::Debug for LocalStoredTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalStoredTask").finish_non_exhaustive()
    }
}
/// A stored task of either flavour behind one type: a sendable
/// [`StoredTask`] or a [`LocalStoredTask`] whose future need not be `Send`.
#[derive(Debug)]
pub enum AnyStoredTask {
    /// Wraps a task whose future is `Send`.
    Global(StoredTask),
    /// Wraps a task whose future is not required to be `Send`.
    Local(LocalStoredTask),
}

impl AnyStoredTask {
    /// Polls the wrapped task once by delegating to that variant's `poll`.
    #[inline]
    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Outcome<(), ()>> {
        let outcome = match self {
            AnyStoredTask::Global(inner) => inner.poll(cx),
            AnyStoredTask::Local(inner) => inner.poll(cx),
        };
        outcome
    }

    /// Returns `true` for the `Local` variant and `false` for `Global`.
    #[inline]
    #[must_use]
    pub fn is_local(&self) -> bool {
        match self {
            AnyStoredTask::Local(_) => true,
            AnyStoredTask::Global(_) => false,
        }
    }

    /// Forwards `remaining` to the wrapped task's `set_polls_remaining`.
    #[inline]
    pub fn set_polls_remaining(&mut self, remaining: u32) {
        match self {
            AnyStoredTask::Global(inner) => inner.set_polls_remaining(remaining),
            AnyStoredTask::Local(inner) => inner.set_polls_remaining(remaining),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::init_test_logging;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::{Context, Poll, Waker};

    /// Returns an owned waker that does nothing when woken.
    fn noop_waker() -> Waker {
        std::task::Waker::noop().clone()
    }

    /// Initialises test logging and announces the test phase.
    fn init_test(test_name: &str) {
        init_test_logging();
        crate::test_phase!(test_name);
    }

    #[test]
    fn stored_task_polls_to_completion() {
        init_test("stored_task_polls_to_completion");
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();
        let mut task = StoredTask::new(async move {
            completed_clone.store(true, Ordering::SeqCst);
            Outcome::Ok(())
        });
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        crate::test_section!("poll");
        let result = task.poll(&mut cx);
        let ready = matches!(result, Poll::Ready(Outcome::Ok(())));
        crate::assert_with_log!(ready, "poll should complete immediately", true, ready);
        let completed_value = completed.load(Ordering::SeqCst);
        crate::assert_with_log!(
            completed_value,
            "completion flag should be set",
            true,
            completed_value
        );
        crate::test_complete!("stored_task_polls_to_completion");
    }

    #[test]
    fn stored_task_pending_future_counts_polls() {
        init_test("stored_task_pending_future_counts_polls");
        let mut task = StoredTask::new(std::future::pending::<Outcome<(), ()>>());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        crate::test_section!("poll twice");
        let first_pending = task.poll(&mut cx).is_pending();
        crate::assert_with_log!(
            first_pending,
            "pending future should report Pending on first poll",
            true,
            first_pending
        );
        let second_pending = task.poll(&mut cx).is_pending();
        crate::assert_with_log!(
            second_pending,
            "pending future should report Pending on second poll",
            true,
            second_pending
        );
        let count = task.poll_count();
        crate::assert_with_log!(
            count == 2,
            "poll_count should track every poll",
            2_u64,
            count
        );
        crate::test_complete!("stored_task_pending_future_counts_polls");
    }

    #[test]
    fn local_stored_task_pending_future_counts_polls() {
        init_test("local_stored_task_pending_future_counts_polls");
        let mut task = LocalStoredTask::new(std::future::pending::<Outcome<(), ()>>());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let first_pending = task.poll(&mut cx).is_pending();
        let second_pending = task.poll(&mut cx).is_pending();
        let both_pending = first_pending && second_pending;
        crate::assert_with_log!(
            both_pending,
            "pending local future should report Pending on every poll",
            true,
            both_pending
        );
        let count = task.poll_count;
        crate::assert_with_log!(
            count == 2,
            "local poll_count should track every poll",
            2_u64,
            count
        );
        crate::test_complete!("local_stored_task_pending_future_counts_polls");
    }

    #[test]
    fn stored_task_debug() {
        init_test("stored_task_debug");
        let task = StoredTask::new(async { Outcome::Ok(()) });
        let debug = format!("{task:?}");
        let contains = debug.contains("StoredTask");
        crate::assert_with_log!(
            contains,
            "debug output should mention StoredTask",
            true,
            contains
        );
        crate::test_complete!("stored_task_debug");
    }

    #[test]
    fn any_stored_task_is_local_global() {
        init_test("any_stored_task_is_local_global");
        let task = AnyStoredTask::Global(StoredTask::new(async { Outcome::Ok(()) }));
        let local = task.is_local();
        crate::assert_with_log!(!local, "Global variant must not be local", false, local);
        crate::test_complete!("any_stored_task_is_local_global");
    }

    #[test]
    fn any_stored_task_is_local_local() {
        init_test("any_stored_task_is_local_local");
        let task = AnyStoredTask::Local(LocalStoredTask::new(async { Outcome::Ok(()) }));
        let local = task.is_local();
        crate::assert_with_log!(local, "Local variant must be local", true, local);
        crate::test_complete!("any_stored_task_is_local_local");
    }

    #[test]
    fn any_stored_task_is_local_stable_after_poll() {
        init_test("any_stored_task_is_local_stable_after_poll");
        let mut task = AnyStoredTask::Local(LocalStoredTask::new(async { Outcome::Ok(()) }));
        let before = task.is_local();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let _ = task.poll(&mut cx);
        let after = task.is_local();
        crate::assert_with_log!(
            before == after,
            "is_local must be stable across poll",
            true,
            before == after
        );
        crate::test_complete!("any_stored_task_is_local_stable_after_poll");
    }

    #[test]
    fn any_stored_task_forwards_polls_remaining() {
        init_test("any_stored_task_forwards_polls_remaining");
        let mut global = AnyStoredTask::Global(StoredTask::new(async { Outcome::Ok(()) }));
        global.set_polls_remaining(3);
        let global_forwarded = match &global {
            AnyStoredTask::Global(inner) => inner.polls_remaining == Some(3),
            AnyStoredTask::Local(_) => false,
        };
        crate::assert_with_log!(
            global_forwarded,
            "set_polls_remaining should reach the wrapped global task",
            true,
            global_forwarded
        );
        let mut local = AnyStoredTask::Local(LocalStoredTask::new(async { Outcome::Ok(()) }));
        local.set_polls_remaining(5);
        let local_forwarded = match &local {
            AnyStoredTask::Local(inner) => inner.polls_remaining == Some(5),
            AnyStoredTask::Global(_) => false,
        };
        crate::assert_with_log!(
            local_forwarded,
            "set_polls_remaining should reach the wrapped local task",
            true,
            local_forwarded
        );
        crate::test_complete!("any_stored_task_forwards_polls_remaining");
    }

    #[test]
    fn stored_task_consumes_polls_remaining_after_poll() {
        init_test("stored_task_consumes_polls_remaining_after_poll");
        let mut task = StoredTask::new(async { Outcome::Ok(()) });
        task.set_polls_remaining(7);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let _ = task.poll(&mut cx);
        crate::assert_with_log!(
            task.polls_remaining.is_none(),
            "polls_remaining should be consumed by poll",
            true,
            task.polls_remaining.is_none()
        );
        crate::test_complete!("stored_task_consumes_polls_remaining_after_poll");
    }

    #[test]
    fn local_stored_task_consumes_polls_remaining_after_poll() {
        init_test("local_stored_task_consumes_polls_remaining_after_poll");
        let mut task = LocalStoredTask::new(async { Outcome::Ok(()) });
        task.set_polls_remaining(11);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let _ = task.poll(&mut cx);
        crate::assert_with_log!(
            task.polls_remaining.is_none(),
            "polls_remaining should be consumed by poll for local tasks",
            true,
            task.polls_remaining.is_none()
        );
        crate::test_complete!("local_stored_task_consumes_polls_remaining_after_poll");
    }
}