use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use bytes::Bytes;
use crate::runtime::HandlerResult;
use super::TestError;
/// Final disposition of a handled message, as classified by `Record::outcome`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Outcome {
    /// The handler settled with `HandlerResult::Ack`.
    Ack,
    /// The handler asked for redelivery: `Nack { requeue: true }` or `NackAfter { .. }`.
    Nack,
    /// The message was discarded: `Nack { requeue: false }`, or no settle result was recorded.
    Drop,
    /// The record was flagged as failing to decode.
    DecodeFailed,
    /// The record was flagged as having panicked; this takes precedence over every other state.
    Panicked,
}
/// One handled message, captured for later inspection through `Coordinator::with_records`.
pub(crate) struct Record {
    /// Scope the message belongs to; `with_records` filters on this together with `name`.
    pub(crate) scope_id: usize,
    /// Name used (together with `scope_id`) to select records in `with_records`.
    pub(crate) name: String,
    /// The message's raw bytes (presumably the undecoded payload — confirm at the call site).
    pub(crate) raw: Bytes,
    /// How the handler settled the message; `None` is classified as `Outcome::Drop`.
    pub(crate) settle: Option<HandlerResult>,
    /// Set when handling panicked; overrides every other field in `outcome()`.
    pub(crate) panicked: bool,
    /// Set when decoding failed; overrides `settle` in `outcome()`.
    pub(crate) decode_failed: bool,
}
impl Record {
    /// Classifies this record into an [`Outcome`].
    ///
    /// Precedence: a panic wins over a decode failure, and a decode failure
    /// wins over whatever the handler settled with.
    pub(crate) fn outcome(&self) -> Outcome {
        if self.panicked {
            return Outcome::Panicked;
        }
        if self.decode_failed {
            return Outcome::DecodeFailed;
        }
        match self.settle.as_ref() {
            Some(HandlerResult::Ack) => Outcome::Ack,
            // Both an immediate requeue and a delayed redelivery count as a nack.
            Some(HandlerResult::Nack { requeue: true }) | Some(HandlerResult::NackAfter { .. }) => {
                Outcome::Nack
            }
            // A nack without requeue, or no settle decision at all, drops the message.
            Some(HandlerResult::Nack { requeue: false }) | None => Outcome::Drop,
        }
    }
}
/// Attachment point for an optional [`Coordinator`].
///
/// Backed by a `OnceLock`, so at most one coordinator can ever be installed.
pub(crate) struct TestHooks {
    /// The installed coordinator, if any.
    coordinator: OnceLock<Coordinator>,
}
impl TestHooks {
    /// Creates hooks with no coordinator attached.
    pub(crate) fn detached() -> Self {
        let coordinator = OnceLock::new();
        Self { coordinator }
    }

    /// Attaches `coordinator` to these hooks.
    ///
    /// Only the first call takes effect. If a coordinator is already
    /// installed, the new one is dropped and the existing one is kept.
    pub(crate) fn install(&self, coordinator: Coordinator) {
        if let Err(rejected) = self.coordinator.set(coordinator) {
            // First install wins; discard the late arrival.
            drop(rejected);
        }
    }

    /// Returns the installed coordinator, or `None` while still detached.
    pub(crate) fn coordinator(&self) -> Option<&Coordinator> {
        self.coordinator.get()
    }
}
/// Shared bookkeeping for driving a test to quiescence.
///
/// Cloning is cheap: every clone points at the same `Inner` through an `Arc`,
/// so all clones observe and update the same counters, records and timers.
#[derive(Clone)]
pub struct Coordinator {
    /// State shared by every clone of this handle.
    inner: std::sync::Arc<Inner>,
}
/// Prints a snapshot of the two counters; records and timers are omitted
/// (signalled by `finish_non_exhaustive`).
impl std::fmt::Debug for Coordinator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let in_flight = self.inner.in_flight.load(Ordering::SeqCst);
        let processed = self.inner.processed.load(Ordering::SeqCst);
        let mut builder = f.debug_struct("Coordinator");
        builder.field("in_flight", &in_flight);
        builder.field("processed", &processed);
        builder.finish_non_exhaustive()
    }
}
/// State shared between all clones of a [`Coordinator`].
struct Inner {
    /// Number of `enqueued` calls not yet matched by a `consumed` call.
    in_flight: AtomicUsize,
    /// Total number of `consumed` calls so far; only ever increases.
    processed: AtomicUsize,
    /// Upper bound on `processed` before `drive` reports `NotQuiescent`.
    max_steps: usize,
    /// Woken (via `notify_waiters`) whenever either counter changes.
    notify: tokio::sync::Notify,
    /// Every record pushed through `Coordinator::record`, in insertion order.
    records: Mutex<Vec<Record>>,
    /// Redelivery timers that `fire_due_timers` has not yet reaped.
    timers: Mutex<Vec<Timer>>,
}
/// A pending redelivery created by `Coordinator::schedule_redelivery`.
struct Timer {
    /// Instant at which the redelivery becomes due.
    deadline: tokio::time::Instant,
    /// Spawned task that runs the redelivery callback.
    handle: tokio::task::JoinHandle<()>,
}
impl Coordinator {
pub(crate) fn new(max_steps: usize) -> Self {
Self {
inner: std::sync::Arc::new(Inner {
in_flight: AtomicUsize::new(0),
processed: AtomicUsize::new(0),
max_steps,
notify: tokio::sync::Notify::new(),
records: Mutex::new(Vec::new()),
timers: Mutex::new(Vec::new()),
}),
}
}
pub fn enqueued(&self) {
self.inner.in_flight.fetch_add(1, Ordering::SeqCst);
self.inner.notify.notify_waiters();
}
pub fn consumed(&self) {
self.inner.processed.fetch_add(1, Ordering::SeqCst);
self.inner.in_flight.fetch_sub(1, Ordering::SeqCst);
self.inner.notify.notify_waiters();
}
pub(crate) fn record(&self, record: Record) {
self.inner
.records
.lock()
.expect("coordinator records mutex poisoned")
.push(record);
}
pub fn schedule_redelivery<F>(&self, delay: Duration, redeliver: F)
where
F: FnOnce() + Send + 'static,
{
let deadline = tokio::time::Instant::now() + delay;
let handle = tokio::spawn(async move {
tokio::time::sleep(delay).await;
redeliver();
});
self.inner
.timers
.lock()
.expect("coordinator timers mutex poisoned")
.push(Timer { deadline, handle });
}
#[allow(clippy::significant_drop_tightening)]
pub(crate) async fn fire_due_timers(&self) {
let now = tokio::time::Instant::now();
let due: Vec<tokio::task::JoinHandle<()>> = {
let mut timers = self
.inner
.timers
.lock()
.expect("coordinator timers mutex poisoned");
let mut due = Vec::new();
let mut i = 0;
while i < timers.len() {
if timers[i].deadline <= now {
due.push(timers.swap_remove(i).handle);
} else {
i += 1;
}
}
due
};
for handle in due {
let _ = handle.await;
}
}
pub(crate) async fn drive(&self) -> Result<(), TestError> {
loop {
let notified = self.inner.notify.notified();
tokio::pin!(notified);
notified.as_mut().enable();
if self.inner.in_flight.load(Ordering::SeqCst) == 0 {
return Ok(());
}
if self.inner.processed.load(Ordering::SeqCst) >= self.inner.max_steps {
return Err(TestError::NotQuiescent {
processed: self.inner.processed.load(Ordering::SeqCst),
});
}
notified.await;
}
}
#[allow(clippy::significant_drop_tightening)]
pub(crate) fn with_records<R>(
&self,
scope_id: usize,
name: &str,
f: impl FnOnce(&[&Record]) -> R,
) -> R {
let guard = self
.inner
.records
.lock()
.expect("coordinator records mutex poisoned");
let matching: Vec<&Record> = guard
.iter()
.filter(|r| r.scope_id == scope_id && r.name == name)
.collect();
f(&matching)
}
}