use std::future::Future;
use std::pin::Pin;
use crate::{Field, FieldMut, Headers};
use super::dispatch::Delivery;
use super::failure::ErrorShutdown;
use super::handler::HandlerResult;
/// A type-erased, heap-pinned future yielding `()`.
///
/// It is `Send` and `'static`, i.e. it borrows nothing, so it may outlive the
/// `Context` it was registered on.
type Continuation = Pin<Box<dyn Future<Output = ()> + 'static + Send>>;
/// Outcome category that a post-settle hook can be gated on.
///
/// Derived from a `HandlerResult` by `OutcomeKind::of`. Payloads such as a
/// retry delay are not part of the category.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OutcomeKind {
    /// `HandlerResult::Ack`.
    Ack,
    /// `HandlerResult::Nack { requeue: false }`.
    Drop,
    /// `HandlerResult::Nack { requeue: true }`.
    Retry,
    /// Any `HandlerResult::NackAfter { .. }`.
    RetryAfter,
}
impl OutcomeKind {
    /// Collapses a handler result into the category used to gate hooks.
    ///
    /// `Nack` is split on its `requeue` flag (`false` -> `Drop`,
    /// `true` -> `Retry`). Every `NackAfter` maps to `RetryAfter`,
    /// whatever its payload.
    fn of(outcome: HandlerResult) -> Self {
        match outcome {
            HandlerResult::Ack => Self::Ack,
            HandlerResult::Nack { requeue } => {
                if requeue {
                    Self::Retry
                } else {
                    Self::Drop
                }
            }
            HandlerResult::NackAfter { .. } => Self::RetryAfter,
        }
    }
}
/// A registered post-settle continuation together with the outcome it is gated on.
struct AfterHook {
    /// `Some(kind)`: run only when the delivery settles with `kind`.
    /// `None`: run on every outcome.
    gate: Option<OutcomeKind>,
    /// The boxed future to run; it is stored without being polled.
    fut: Continuation,
}
/// Per-delivery context.
///
/// Holds the delivery's headers (copied only on first mutation), a shared
/// state `S`, a typed context value `C`, and hooks to run once the delivery
/// settles.
pub struct Context<'a, C = (), S = ()> {
    /// Name given at construction; `fail_fast` passes it to `ErrorShutdown::signal`.
    name: &'a str,
    /// Headers as received; never mutated.
    original: &'a Headers,
    /// Private copy of `original`, created by the first `headers_mut` call.
    modified: Option<Headers>,
    /// Shared state borrowed for the lifetime of the context.
    state: &'a S,
    /// Typed context value, read via `Field` keys and written via `FieldMut` keys.
    cx: C,
    /// Delivery this context belongs to (source of the task tracker).
    delivery: &'a Delivery,
    /// Registered post-settle hooks, in registration order.
    after: Vec<AfterHook>,
    /// Optional shutdown handle signalled by `fail_fast`.
    failfast: Option<&'a ErrorShutdown>,
    /// Test-only flag recording that decoding failed.
    #[cfg(feature = "testing")]
    decode_failed: bool,
}
impl<C, S> std::fmt::Debug for Context<'_, C, S> {
    /// Prints the context name and the number of registered after-hooks.
    ///
    /// The remaining fields are omitted, and the output is marked
    /// non-exhaustive, so neither `C` nor `S` needs to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hook_count = self.after.len();
        let mut builder = f.debug_struct("Context");
        builder.field("name", &self.name);
        builder.field("after_hooks", &hook_count);
        builder.finish_non_exhaustive()
    }
}
impl<'a, C, S> Context<'a, C, S> {
pub(crate) fn new(
name: &'a str,
headers: &'a Headers,
state: &'a S,
cx: C,
delivery: &'a Delivery,
) -> Self {
Self {
name,
original: headers,
modified: None,
state,
cx,
delivery,
after: Vec::new(),
failfast: None,
#[cfg(feature = "testing")]
decode_failed: false,
}
}
#[cfg(feature = "testing")]
pub(crate) fn mark_decode_failed(&mut self) {
self.decode_failed = true;
}
#[cfg(feature = "testing")]
pub(crate) fn took_decode_failed(&mut self) -> bool {
std::mem::take(&mut self.decode_failed)
}
#[must_use]
pub(crate) fn with_failfast(mut self, failfast: &'a ErrorShutdown) -> Self {
self.failfast = Some(failfast);
self
}
pub(crate) fn fail_fast(&self, reason: &str) {
if let Some(failfast) = self.failfast {
failfast.signal(self.name, reason);
}
}
#[must_use]
pub fn name(&self) -> &str {
self.name
}
#[must_use]
pub fn headers(&self) -> &Headers {
self.modified.as_ref().unwrap_or(self.original)
}
pub fn headers_mut(&mut self) -> &mut Headers {
self.modified.get_or_insert_with(|| self.original.clone())
}
#[must_use]
pub fn state(&self) -> &S {
self.state
}
pub fn context<K: Field<C>>(&self, key: K) -> K::Value<'_> {
key.get(&self.cx)
}
pub(crate) fn cx_ref(&self) -> &C {
&self.cx
}
pub fn set<K: FieldMut<C>>(&mut self, key: K, value: K::Owned) {
key.set(&mut self.cx, value);
}
pub fn after(&mut self, outcome: HandlerResult) -> After<'_, 'a, C, S> {
After {
ctx: self,
gate: Some(OutcomeKind::of(outcome)),
}
}
pub fn after_ack(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
self.after.push(AfterHook {
gate: Some(OutcomeKind::Ack),
fut: Box::pin(fut),
});
}
pub fn after_settle(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
self.after.push(AfterHook {
gate: None,
fut: Box::pin(fut),
});
}
pub(crate) fn take_hooks_for(&mut self, outcome: HandlerResult) -> Vec<Continuation> {
let kind = OutcomeKind::of(outcome);
let mut runnable = Vec::new();
let mut kept = Vec::new();
for hook in self.after.drain(..) {
if hook.gate.is_none_or(|gate| gate == kind) {
runnable.push(hook.fut);
} else {
kept.push(hook);
}
}
self.after = kept;
runnable
}
pub(crate) fn take_settle_hooks(&mut self) -> Vec<Continuation> {
let mut runnable = Vec::new();
let mut kept = Vec::new();
for hook in self.after.drain(..) {
if hook.gate.is_none() {
runnable.push(hook.fut);
} else {
kept.push(hook);
}
}
self.after = kept;
runnable
}
pub(crate) fn tasks(&self) -> &tokio_util::task::TaskTracker {
&self.delivery.tasks
}
}
/// Builder returned by [`Context::after`] that remembers the outcome gate for
/// a hook. Nothing is registered until [`After::then`] is called.
#[must_use = "call `.then(fut)` to register the post-settle hook"]
pub struct After<'ctx, 'a, C = (), S = ()> {
    /// Context the hook will be pushed onto.
    ctx: &'ctx mut Context<'a, C, S>,
    /// Outcome kind the hook is restricted to (`None` would mean every outcome).
    gate: Option<OutcomeKind>,
}
impl<C, S> std::fmt::Debug for After<'_, '_, C, S> {
    /// Prints only the gate; the borrowed context is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("After");
        builder.field("gate", &self.gate);
        builder.finish()
    }
}
impl<C, S> After<'_, '_, C, S> {
    /// Registers `fut` on the context, gated on the outcome kind this builder
    /// was created for.
    ///
    /// The future is boxed and stored; it is not polled here.
    pub fn then(self, fut: impl Future<Output = ()> + Send + 'static) {
        let Self { ctx, gate } = self;
        let hook = AfterHook {
            gate,
            fut: Box::pin(fut),
        };
        ctx.after.push(hook);
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;
    use futures::future::join_all;
    use super::Context;
    use crate::Headers;
    use crate::runtime::dispatch::Delivery;
    use crate::runtime::handler::HandlerResult;

    /// Drives every continuation to completion on the current thread.
    fn run_all(continuations: Vec<super::Continuation>) {
        futures::executor::block_on(async {
            join_all(continuations).await;
        });
    }

    #[test]
    fn outcome_kind_distinguishes_drop_retry_and_retry_after() {
        use super::OutcomeKind;
        assert_eq!(OutcomeKind::of(HandlerResult::Ack), OutcomeKind::Ack);
        assert_eq!(OutcomeKind::of(HandlerResult::drop()), OutcomeKind::Drop);
        assert_eq!(OutcomeKind::of(HandlerResult::retry()), OutcomeKind::Retry);
        assert_ne!(
            OutcomeKind::of(HandlerResult::drop()),
            OutcomeKind::of(HandlerResult::retry()),
        );
        assert_eq!(
            OutcomeKind::of(HandlerResult::retry_after(Duration::from_secs(1))),
            OutcomeKind::RetryAfter,
        );
        assert_ne!(
            OutcomeKind::of(HandlerResult::retry_after(Duration::ZERO)),
            OutcomeKind::of(HandlerResult::retry()),
        );
        // The delay payload is not part of the kind.
        assert_eq!(
            OutcomeKind::of(HandlerResult::retry_after(Duration::from_secs(1))),
            OutcomeKind::of(HandlerResult::retry_after(Duration::from_secs(9))),
        );
    }

    #[test]
    fn take_hooks_runs_only_the_matching_gate() {
        let state = ();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("t", &headers, &state, (), &delivery);
        let acked = Arc::new(AtomicU32::new(0));
        let dropped = Arc::new(AtomicU32::new(0));
        let retried = Arc::new(AtomicU32::new(0));
        let settled = Arc::new(AtomicU32::new(0));
        let bump = |c: &Arc<AtomicU32>| {
            let c = Arc::clone(c);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
            }
        };
        ctx.after(HandlerResult::Ack).then(bump(&acked));
        ctx.after(HandlerResult::drop()).then(bump(&dropped));
        ctx.after(HandlerResult::retry()).then(bump(&retried));
        ctx.after_ack(bump(&acked));
        ctx.after_settle(bump(&settled));
        run_all(ctx.take_hooks_for(HandlerResult::Ack));
        assert_eq!(acked.load(Ordering::SeqCst), 2);
        assert_eq!(settled.load(Ordering::SeqCst), 1);
        assert_eq!(dropped.load(Ordering::SeqCst), 0);
        assert_eq!(retried.load(Ordering::SeqCst), 0);
        run_all(ctx.take_hooks_for(HandlerResult::retry()));
        assert_eq!(retried.load(Ordering::SeqCst), 1);
        assert_eq!(dropped.load(Ordering::SeqCst), 0);
        run_all(ctx.take_hooks_for(HandlerResult::drop()));
        assert_eq!(dropped.load(Ordering::SeqCst), 1);
        assert_eq!(retried.load(Ordering::SeqCst), 1);
        // Every hook has now run exactly once and nothing is left registered.
        assert_eq!(acked.load(Ordering::SeqCst), 2);
        assert_eq!(settled.load(Ordering::SeqCst), 1);
        assert!(ctx.take_hooks_for(HandlerResult::Ack).is_empty());
    }

    #[test]
    fn take_settle_hooks_keeps_outcome_gated_ones_registered() {
        let state = ();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("t", &headers, &state, (), &delivery);
        let gated = Arc::new(AtomicU32::new(0));
        let ungated = Arc::new(AtomicU32::new(0));
        let gated_clone = Arc::clone(&gated);
        ctx.after(HandlerResult::Ack).then(async move {
            gated_clone.fetch_add(1, Ordering::SeqCst);
        });
        let ungated_clone = Arc::clone(&ungated);
        ctx.after_settle(async move {
            ungated_clone.fetch_add(1, Ordering::SeqCst);
        });
        run_all(ctx.take_settle_hooks());
        assert_eq!(ungated.load(Ordering::SeqCst), 1);
        assert_eq!(gated.load(Ordering::SeqCst), 0);
        // The gated hook was retained, not discarded: it still runs once the
        // matching outcome is taken, and the settle hook does not run again.
        run_all(ctx.take_hooks_for(HandlerResult::Ack));
        assert_eq!(gated.load(Ordering::SeqCst), 1);
        assert_eq!(ungated.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn context_reads_typed_field_by_key() {
        use crate::Field;
        struct Meta {
            offset: u64,
        }
        #[derive(Clone, Copy)]
        struct Offset;
        impl Field<Meta> for Offset {
            type Value<'a> = u64;
            fn get(self, m: &Meta) -> u64 {
                m.offset
            }
        }
        let state = String::from("app");
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let ctx = Context::new("test", &headers, &state, Meta { offset: 42 }, &delivery);
        assert_eq!(ctx.context(Offset), 42);
        assert_eq!(ctx.state().as_str(), "app");
    }

    #[test]
    fn set_writes_scratch_and_reads_it_back() {
        use crate::{Field, FieldMut};
        #[derive(Default)]
        struct Scratch {
            user: Option<u64>,
        }
        #[derive(Clone, Copy)]
        struct User;
        impl Field<Scratch> for User {
            type Value<'a> = Option<&'a u64>;
            fn get(self, s: &Scratch) -> Option<&u64> {
                s.user.as_ref()
            }
        }
        impl FieldMut<Scratch> for User {
            type Owned = u64;
            fn set(self, s: &mut Scratch, value: u64) {
                s.user = Some(value);
            }
        }
        let state = ();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("test", &headers, &state, Scratch::default(), &delivery);
        assert_eq!(ctx.context(User), None);
        ctx.set(User, 9);
        assert_eq!(ctx.context(User), Some(&9));
    }

    #[test]
    fn headers_clone_only_on_first_mutation() {
        let mut original = Headers::new();
        original.insert("k", "v");
        let state = ();
        let delivery = Delivery::empty();
        let mut ctx = Context::new("test", &original, &state, (), &delivery);
        assert!(std::ptr::eq(ctx.headers(), &raw const original));
        ctx.headers_mut().insert("added", "1");
        ctx.headers_mut().insert("added2", "2");
        assert!(!std::ptr::eq(ctx.headers(), &raw const original));
        assert_eq!(ctx.headers().get("added"), Some(&b"1"[..]));
        assert_eq!(ctx.headers().get("k"), Some(&b"v"[..]));
        assert_eq!(original.get("added"), None);
    }
}