use crate::types::{Budget, CancelReason, RegionId, TaskId, Time};
use std::task::Waker;
pub const MAX_MASK_DEPTH: u32 = 64;
#[derive(Debug, Clone)]
pub struct CheckpointState {
pub last_checkpoint: Option<Time>,
pub last_message: Option<String>,
pub checkpoint_count: u64,
}
impl Default for CheckpointState {
fn default() -> Self {
Self::new()
}
}
impl CheckpointState {
#[inline]
#[must_use]
pub fn new() -> Self {
Self {
last_checkpoint: None,
last_message: None,
checkpoint_count: 0,
}
}
#[inline]
pub fn record(&mut self) {
self.record_at(crate::time::wall_now());
}
#[inline]
pub fn record_at(&mut self, at: Time) {
self.last_checkpoint = Some(at);
self.last_message = None;
self.checkpoint_count += 1;
}
#[inline]
pub fn record_with_message(&mut self, message: String) {
self.record_with_message_at(message, crate::time::wall_now());
}
#[inline]
pub fn record_with_message_at(&mut self, message: String, at: Time) {
self.last_checkpoint = Some(at);
self.last_message = Some(message);
self.checkpoint_count += 1;
}
}
#[derive(Debug)]
pub struct CxInner {
pub region: RegionId,
pub task: TaskId,
pub task_type: Option<String>,
pub budget: Budget,
pub budget_baseline: Budget,
pub cancel_requested: bool,
pub cancel_reason: Option<CancelReason>,
pub cancel_acknowledged: bool,
pub cancel_waker: Option<Waker>,
pub mask_depth: u32,
pub checkpoint_state: CheckpointState,
pub fast_cancel: std::sync::Arc<std::sync::atomic::AtomicBool>,
}
impl CxInner {
#[must_use]
pub fn new(region: RegionId, task: TaskId, budget: Budget) -> Self {
Self {
region,
task,
task_type: None,
budget,
budget_baseline: budget,
cancel_requested: false,
cancel_reason: None,
cancel_acknowledged: false,
cancel_waker: None,
mask_depth: 0,
checkpoint_state: CheckpointState::new(),
fast_cancel: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
#[test]
fn test_checkpoint_state_default() {
init_test("test_checkpoint_state_default");
let state = CheckpointState::new();
crate::assert_with_log!(
state.last_checkpoint.is_none(),
"last_checkpoint",
true,
state.last_checkpoint.is_none()
);
crate::assert_with_log!(
state.last_message.is_none(),
"last_message",
true,
state.last_message.is_none()
);
crate::assert_with_log!(
state.checkpoint_count == 0,
"checkpoint_count",
0,
state.checkpoint_count
);
crate::test_complete!("test_checkpoint_state_default");
}
#[test]
fn test_checkpoint_state_record() {
init_test("test_checkpoint_state_record");
let mut state = CheckpointState::new();
state.record();
crate::assert_with_log!(
state.last_checkpoint.is_some(),
"last_checkpoint",
true,
state.last_checkpoint.is_some()
);
crate::assert_with_log!(
state.last_message.is_none(),
"last_message",
true,
state.last_message.is_none()
);
crate::assert_with_log!(
state.checkpoint_count == 1,
"checkpoint_count",
1,
state.checkpoint_count
);
state.record();
crate::assert_with_log!(
state.checkpoint_count == 2,
"checkpoint_count 2",
2,
state.checkpoint_count
);
crate::test_complete!("test_checkpoint_state_record");
}
#[test]
fn test_checkpoint_state_record_at() {
init_test("test_checkpoint_state_record_at");
let mut state = CheckpointState::new();
let at = Time::from_nanos(123);
state.record_at(at);
crate::assert_with_log!(
state.last_checkpoint == Some(at),
"explicit checkpoint instant stored",
format!("{at:?}"),
format!("{:?}", state.last_checkpoint)
);
crate::assert_with_log!(
state.last_message.is_none(),
"record_at clears message",
true,
state.last_message.is_none()
);
crate::assert_with_log!(
state.checkpoint_count == 1,
"record_at increments count",
1,
state.checkpoint_count
);
crate::test_complete!("test_checkpoint_state_record_at");
}
#[test]
fn test_checkpoint_state_record_with_message() {
init_test("test_checkpoint_state_record_with_message");
let mut state = CheckpointState::new();
state.record_with_message("hello".to_string());
crate::assert_with_log!(
state.last_checkpoint.is_some(),
"last_checkpoint",
true,
state.last_checkpoint.is_some()
);
crate::assert_with_log!(
state.last_message.as_deref() == Some("hello"),
"last_message",
Some("hello"),
state.last_message.as_deref()
);
crate::assert_with_log!(
state.checkpoint_count == 1,
"checkpoint_count",
1,
state.checkpoint_count
);
state.record();
crate::assert_with_log!(
state.last_message.is_none(),
"last_message cleared",
true,
state.last_message.is_none()
);
crate::test_complete!("test_checkpoint_state_record_with_message");
}
#[test]
fn test_checkpoint_state_record_with_message_at() {
init_test("test_checkpoint_state_record_with_message_at");
let mut state = CheckpointState::new();
let at = Time::from_nanos(456);
state.record_with_message_at("hello".to_string(), at);
crate::assert_with_log!(
state.last_checkpoint == Some(at),
"explicit checkpoint instant stored",
format!("{at:?}"),
format!("{:?}", state.last_checkpoint)
);
crate::assert_with_log!(
state.last_message.as_deref() == Some("hello"),
"record_with_message_at stores message",
Some("hello"),
state.last_message.as_deref()
);
crate::assert_with_log!(
state.checkpoint_count == 1,
"record_with_message_at increments count",
1,
state.checkpoint_count
);
crate::test_complete!("test_checkpoint_state_record_with_message_at");
}
#[test]
fn test_checkpoint_state_message_overwrite() {
init_test("test_checkpoint_state_message_overwrite");
let mut state = CheckpointState::new();
state.record_with_message("first".to_string());
state.record_with_message("second".to_string());
crate::assert_with_log!(
state.last_message.as_deref() == Some("second"),
"last_message overwrite",
Some("second"),
state.last_message.as_deref()
);
crate::assert_with_log!(
state.checkpoint_count == 2,
"checkpoint_count",
2,
state.checkpoint_count
);
crate::test_complete!("test_checkpoint_state_message_overwrite");
}
#[test]
fn test_cx_inner_new() {
init_test("test_cx_inner_new");
let region = RegionId::testing_default();
let task = TaskId::testing_default();
let budget = Budget::new();
let cx = CxInner::new(region, task, budget);
crate::assert_with_log!(cx.region == region, "region", region, cx.region);
crate::assert_with_log!(cx.task == task, "task", task, cx.task);
crate::assert_with_log!(cx.budget == budget, "budget", budget, cx.budget);
crate::assert_with_log!(
cx.budget_baseline == budget,
"budget_baseline",
budget,
cx.budget_baseline
);
crate::assert_with_log!(
!cx.cancel_requested,
"cancel_requested",
false,
cx.cancel_requested
);
crate::assert_with_log!(
cx.cancel_reason.is_none(),
"cancel_reason",
true,
cx.cancel_reason.is_none()
);
crate::assert_with_log!(cx.mask_depth == 0, "mask_depth", 0, cx.mask_depth);
crate::test_complete!("test_cx_inner_new");
}
#[test]
fn checkpoint_state_debug_clone_default() {
let def = CheckpointState::default();
assert!(def.last_checkpoint.is_none());
assert!(def.last_message.is_none());
assert_eq!(def.checkpoint_count, 0);
let dbg = format!("{def:?}");
assert!(dbg.contains("CheckpointState"), "{dbg}");
let mut state = CheckpointState::new();
state.record_with_message("progress".into());
let cloned = state.clone();
assert_eq!(cloned.checkpoint_count, 1);
assert_eq!(cloned.last_message.as_deref(), Some("progress"));
}
#[test]
fn cx_inner_debug() {
let region = RegionId::testing_default();
let task = TaskId::testing_default();
let cx = CxInner::new(region, task, Budget::new());
let dbg = format!("{cx:?}");
assert!(dbg.contains("CxInner"), "{dbg}");
}
}