use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use crate::{Extensions, Headers};
use super::dispatch::Delivery;
use super::failure::ErrorShutdown;
use super::handler::HandlerResult;
use super::publish::ScopedPublisher;
/// Boxed, pinned, type-erased future that yields `()`.
///
/// The `Send + 'static` bounds let a continuation be stored inside an
/// `AfterHook` and returned from the context by the `take_*` methods.
type Continuation = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
/// Payload-free classification of a `HandlerResult`.
///
/// Hooks are gated on this value rather than on the full result, so two
/// results that differ only in payload compare equal here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OutcomeKind {
    /// Produced from `HandlerResult::Ack`.
    Ack,
    /// Produced from `HandlerResult::Nack { requeue: false }`.
    Drop,
    /// Produced from `HandlerResult::Nack { requeue: true }`.
    Retry,
    /// Produced from any `HandlerResult::NackAfter`, whatever it carries.
    RetryAfter,
}

impl OutcomeKind {
    /// Collapses a handler result into its kind, discarding any payload.
    fn of(outcome: HandlerResult) -> Self {
        match outcome {
            HandlerResult::Ack => Self::Ack,
            // The requeue flag alone separates a retry from a drop.
            HandlerResult::Nack { requeue } => {
                if requeue {
                    Self::Retry
                } else {
                    Self::Drop
                }
            }
            HandlerResult::NackAfter { .. } => Self::RetryAfter,
        }
    }
}
/// A continuation registered on a `Context`, paired with the gate that
/// decides which outcome releases it.
struct AfterHook {
/// Outcome kind this hook is restricted to; `None` means it is unrestricted.
gate: Option<OutcomeKind>,
/// The boxed future stored for this hook.
fut: Continuation,
}
/// Type-indexed container holding at most one value per concrete type.
///
/// Values are keyed by their [`TypeId`], so inserting a second value of the
/// same type replaces the first.
#[derive(Default)]
pub struct State {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl State {
    /// Stores `value` under its own type, replacing any earlier value of
    /// that type.
    pub fn insert<T: Any + Send + Sync>(&mut self, value: T) {
        let key = TypeId::of::<T>();
        let boxed: Box<dyn Any + Send + Sync> = Box::new(value);
        self.map.insert(key, boxed);
    }

    /// Returns the stored value of type `T`, or `None` if there is none.
    #[must_use]
    pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|slot| slot.downcast_ref::<T>())
    }
}

impl std::fmt::Debug for State {
    /// Prints only the number of entries; the stored values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = self.map.len();
        let mut out = f.debug_struct("State");
        out.field("entries", &entries);
        out.finish_non_exhaustive()
    }
}
/// Bundles borrowed headers, shared state and a delivery handle with an owned
/// extension map and the list of hooks registered for release by outcome.
pub struct Context<'a> {
    /// Name reported by `name()` and passed along when signalling fail-fast.
    name: &'a str,
    /// Headers as borrowed at construction; never written through.
    original: &'a Headers,
    /// Owned copy of the headers, created on the first `headers_mut()` call.
    modified: Option<Headers>,
    /// Borrowed type-indexed state.
    state: &'a State,
    /// Owned extension map backing `get`/`insert`; also lent to publishers.
    extensions: Extensions,
    /// Borrowed delivery that supplies publishers, pipeline and task tracker.
    delivery: &'a Delivery,
    /// Registered hooks, in registration order.
    after: Vec<AfterHook>,
    /// Optional handle signalled by `fail_fast`; `None` makes that a no-op.
    failfast: Option<&'a ErrorShutdown>,
}

impl std::fmt::Debug for Context<'_> {
    /// Prints the name and the number of registered hooks; all other fields
    /// are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pending_hooks = self.after.len();
        let mut out = f.debug_struct("Context");
        out.field("name", &self.name);
        out.field("after_hooks", &pending_hooks);
        out.finish_non_exhaustive()
    }
}
impl<'a> Context<'a> {
pub(crate) fn new(
name: &'a str,
headers: &'a Headers,
state: &'a State,
delivery: &'a Delivery,
) -> Self {
Self::with_extensions(name, headers, state, Extensions::new(), delivery)
}
pub(crate) fn with_extensions(
name: &'a str,
headers: &'a Headers,
state: &'a State,
extensions: Extensions,
delivery: &'a Delivery,
) -> Self {
Self {
name,
original: headers,
modified: None,
state,
extensions,
delivery,
after: Vec::new(),
failfast: None,
}
}
#[must_use]
pub(crate) fn with_failfast(mut self, failfast: &'a ErrorShutdown) -> Self {
self.failfast = Some(failfast);
self
}
pub(crate) fn fail_fast(&self, reason: &str) {
if let Some(failfast) = self.failfast {
failfast.signal(self.name, reason);
}
}
#[must_use]
pub fn name(&self) -> &str {
self.name
}
#[must_use]
pub fn publisher(&self, name: &str) -> Option<ScopedPublisher<'_>> {
let publisher = self.delivery.publishers.get(name)?;
Some(ScopedPublisher::new(
publisher.as_ref(),
&self.delivery.pipeline,
&self.extensions,
))
}
#[must_use]
pub fn headers(&self) -> &Headers {
self.modified.as_ref().unwrap_or(self.original)
}
pub fn headers_mut(&mut self) -> &mut Headers {
self.modified.get_or_insert_with(|| self.original.clone())
}
#[must_use]
pub fn state(&self) -> &State {
self.state
}
#[must_use]
pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
self.extensions.get::<T>()
}
pub fn insert<T: Any + Send + Sync>(&mut self, value: T) {
self.extensions.insert(value);
}
pub(crate) fn extensions(&self) -> &Extensions {
&self.extensions
}
pub fn after(&mut self, outcome: HandlerResult) -> After<'_, 'a> {
After {
ctx: self,
gate: Some(OutcomeKind::of(outcome)),
}
}
pub fn after_ack(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
self.after.push(AfterHook {
gate: Some(OutcomeKind::Ack),
fut: Box::pin(fut),
});
}
pub fn after_settle(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
self.after.push(AfterHook {
gate: None,
fut: Box::pin(fut),
});
}
pub(crate) fn take_hooks_for(&mut self, outcome: HandlerResult) -> Vec<Continuation> {
let kind = OutcomeKind::of(outcome);
let mut runnable = Vec::new();
let mut kept = Vec::new();
for hook in self.after.drain(..) {
if hook.gate.is_none_or(|gate| gate == kind) {
runnable.push(hook.fut);
} else {
kept.push(hook);
}
}
self.after = kept;
runnable
}
pub(crate) fn take_settle_hooks(&mut self) -> Vec<Continuation> {
let mut runnable = Vec::new();
let mut kept = Vec::new();
for hook in self.after.drain(..) {
if hook.gate.is_none() {
runnable.push(hook.fut);
} else {
kept.push(hook);
}
}
self.after = kept;
runnable
}
pub(crate) fn tasks(&self) -> &tokio_util::task::TaskTracker {
&self.delivery.tasks
}
}
/// Pending hook registration: remembers the target context and the gate, and
/// registers nothing until [`After::then`] is called.
#[must_use = "call `.then(fut)` to register the post-settle hook"]
pub struct After<'ctx, 'a> {
    /// Context that will receive the hook.
    ctx: &'ctx mut Context<'a>,
    /// Gate to attach to the hook.
    gate: Option<OutcomeKind>,
}

impl std::fmt::Debug for After<'_, '_> {
    /// Prints only the gate; the borrowed context is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("After");
        out.field("gate", &self.gate);
        out.finish()
    }
}

impl After<'_, '_> {
    /// Boxes `fut` and appends it to the context's hook list under the gate
    /// captured by this builder, consuming the builder.
    pub fn then(self, fut: impl Future<Output = ()> + Send + 'static) {
        let After { ctx, gate } = self;
        let hook = AfterHook {
            gate,
            fut: Box::pin(fut),
        };
        ctx.after.push(hook);
    }
}
// Unit tests for outcome classification, hook gating, extension/state
// separation and the lazy header copy.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use futures::future::join_all;
use super::{Context, Extensions, State};
use crate::Headers;
use crate::runtime::dispatch::Delivery;
use crate::runtime::handler::HandlerResult;
// Drives every continuation to completion on the calling thread.
fn run_all(continuations: Vec<super::Continuation>) {
futures::executor::block_on(async {
join_all(continuations).await;
});
}
// Each handler result maps to its own kind; retry-after is distinct from a
// plain retry even with a zero delay, and ignores the delay's value.
#[test]
fn outcome_kind_distinguishes_drop_retry_and_retry_after() {
use super::OutcomeKind;
assert_eq!(OutcomeKind::of(HandlerResult::Ack), OutcomeKind::Ack);
assert_eq!(OutcomeKind::of(HandlerResult::drop()), OutcomeKind::Drop);
assert_eq!(OutcomeKind::of(HandlerResult::retry()), OutcomeKind::Retry);
assert_ne!(
OutcomeKind::of(HandlerResult::drop()),
OutcomeKind::of(HandlerResult::retry()),
);
assert_eq!(
OutcomeKind::of(HandlerResult::retry_after(Duration::from_secs(1))),
OutcomeKind::RetryAfter,
);
assert_ne!(
OutcomeKind::of(HandlerResult::retry_after(Duration::ZERO)),
OutcomeKind::of(HandlerResult::retry()),
);
assert_eq!(
OutcomeKind::of(HandlerResult::retry_after(Duration::from_secs(1))),
OutcomeKind::of(HandlerResult::retry_after(Duration::from_secs(9))),
);
}
// Taking hooks for one outcome releases the hooks gated on it plus the
// ungated ones; hooks gated on other outcomes stay until their own outcome
// is requested.
#[test]
fn take_hooks_runs_only_the_matching_gate() {
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("t", &headers, &state, &delivery);
let acked = Arc::new(AtomicU32::new(0));
let dropped = Arc::new(AtomicU32::new(0));
let retried = Arc::new(AtomicU32::new(0));
let settled = Arc::new(AtomicU32::new(0));
// Builds a future that increments the given counter once when run.
let bump = |c: &Arc<AtomicU32>| {
let c = Arc::clone(c);
async move {
c.fetch_add(1, Ordering::SeqCst);
}
};
ctx.after(HandlerResult::Ack).then(bump(&acked));
ctx.after(HandlerResult::drop()).then(bump(&dropped));
ctx.after(HandlerResult::retry()).then(bump(&retried));
ctx.after_ack(bump(&acked));
ctx.after_settle(bump(&settled));
// Ack releases both ack-gated hooks and the ungated settle hook.
run_all(ctx.take_hooks_for(HandlerResult::Ack));
assert_eq!(acked.load(Ordering::SeqCst), 2);
assert_eq!(settled.load(Ordering::SeqCst), 1);
assert_eq!(dropped.load(Ordering::SeqCst), 0);
assert_eq!(retried.load(Ordering::SeqCst), 0);
// The retry-gated hook was kept and is released now; drop is still held.
run_all(ctx.take_hooks_for(HandlerResult::retry()));
assert_eq!(retried.load(Ordering::SeqCst), 1);
assert_eq!(dropped.load(Ordering::SeqCst), 0);
// Finally the drop-gated hook; the retry hook does not run a second time.
run_all(ctx.take_hooks_for(HandlerResult::drop()));
assert_eq!(dropped.load(Ordering::SeqCst), 1);
assert_eq!(retried.load(Ordering::SeqCst), 1);
}
// `take_settle_hooks` returns only ungated hooks; the ack-gated hook is
// not among them and therefore never runs here.
#[test]
fn take_settle_hooks_drops_outcome_gated_ones() {
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("t", &headers, &state, &delivery);
let gated = Arc::new(AtomicU32::new(0));
let ungated = Arc::new(AtomicU32::new(0));
let gated_clone = Arc::clone(&gated);
ctx.after(HandlerResult::Ack).then(async move {
gated_clone.fetch_add(1, Ordering::SeqCst);
});
let ungated_clone = Arc::clone(&ungated);
ctx.after_settle(async move {
ungated_clone.fetch_add(1, Ordering::SeqCst);
});
run_all(ctx.take_settle_hooks());
assert_eq!(ungated.load(Ordering::SeqCst), 1);
assert_eq!(gated.load(Ordering::SeqCst), 0);
}
// A second insert of the same type replaces the first; other types are
// independent, and a type never inserted yields `None`.
#[test]
fn extensions_one_value_per_type_and_isolation() {
let mut ext = Extensions::new();
assert!(ext.is_empty());
ext.insert(1u32);
ext.insert(2u32);
assert_eq!(ext.get::<u32>(), Some(&2));
ext.insert("tag");
assert_eq!(ext.get::<&str>(), Some(&"tag"));
assert_eq!(ext.get::<i64>(), None);
}
// `ctx.insert`/`ctx.get` operate on the extensions only: a value held in
// the state is reachable through `ctx.state()` but not through `ctx.get`.
#[test]
fn ctx_get_insert_hit_extensions_state_unaffected() {
let mut state = State::default();
state.insert(String::from("app"));
let headers = Headers::new();
let delivery = Delivery::empty();
let mut ctx = Context::new("test", &headers, &state, &delivery);
assert_eq!(ctx.get::<u32>(), None);
ctx.insert(99u32);
assert_eq!(ctx.get::<u32>(), Some(&99));
assert_eq!(ctx.state().get::<String>().map(String::as_str), Some("app"));
assert_eq!(ctx.get::<String>(), None);
}
// Extensions passed to `with_extensions` are visible through `ctx.get`.
#[test]
fn seeded_extensions_reach_the_context() {
let state = State::default();
let headers = Headers::new();
let delivery = Delivery::empty();
let mut seed = Extensions::new();
seed.insert(7u8);
let ctx = Context::with_extensions("test", &headers, &state, seed, &delivery);
assert_eq!(ctx.get::<u8>(), Some(&7));
}
// Before any mutation `headers()` is the very same object as the borrowed
// original; after `headers_mut()` it is a separate copy holding both old
// and new entries, and the original is left untouched.
#[test]
fn headers_clone_only_on_first_mutation() {
let mut original = Headers::new();
original.insert("k", "v");
let state = State::default();
let delivery = Delivery::empty();
let mut ctx = Context::new("test", &original, &state, &delivery);
assert!(std::ptr::eq(ctx.headers(), &raw const original));
ctx.headers_mut().insert("added", "1");
ctx.headers_mut().insert("added2", "2");
assert!(!std::ptr::eq(ctx.headers(), &raw const original));
assert_eq!(ctx.headers().get("added"), Some(&b"1"[..]));
assert_eq!(ctx.headers().get("k"), Some(&b"v"[..]));
assert_eq!(original.get("added"), None);
}
}