use std::{fmt, marker::PhantomData};
use crate::IncomingMessage;
use crate::codec::Codec;
use serde::de::DeserializeOwned;
use tracing::warn;
use super::context::Context;
use super::handler::{Handler, HandlerResult};
/// Policy applied by a [`Typed`] handler when the codec cannot decode a payload.
///
/// The default policy is [`DecodeFailure::Drop`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum DecodeFailure {
    /// The handler answers with `HandlerResult::drop()`.
    #[default]
    Drop,
    /// The handler answers with `HandlerResult::retry()`.
    Requeue,
}
/// Wraps `inner` in a [`Typed`] adapter that decodes message payloads with `codec`.
///
/// The returned adapter starts with the default [`DecodeFailure`] policy; use
/// [`Typed::on_decode_failure`] to pick a different one.
pub fn typed<M, T, C, H>(codec: C, inner: H) -> Typed<M, T, C, H>
where
    M: IncomingMessage,
    T: DeserializeOwned + Send + Sync,
    C: Codec,
    H: Handler<T>,
{
    // Start from the enum's declared default; callers override it via the builder method.
    let on_decode_failure = DecodeFailure::default();

    Typed {
        codec,
        inner,
        on_decode_failure,
        _phantom: PhantomData,
    }
}
/// Handler adapter that decodes a message payload into `T` before delegating.
///
/// `M` is the incoming message type, `T` the decoded value type, `C` the codec
/// and `H` the wrapped handler.
pub struct Typed<M, T, C, H> {
    /// Codec used to decode each payload.
    codec: C,
    /// Handler that receives the decoded value.
    inner: H,
    /// What to return when decoding fails.
    on_decode_failure: DecodeFailure,
    /// Ties `M` and `T` to the type without storing a value of either.
    _phantom: PhantomData<fn(M, T)>,
}
impl<M, T, C, H> Typed<M, T, C, H> {
    /// Returns this adapter with its decode-failure policy set to `mode`.
    ///
    /// Every other field is carried over unchanged.
    #[must_use]
    pub fn on_decode_failure(self, mode: DecodeFailure) -> Self {
        Self {
            on_decode_failure: mode,
            ..self
        }
    }
}
/// Debug output shows only the decode-failure policy.
///
/// The codec and inner handler are left out, so this impl places no `Debug`
/// bound on any type parameter.
impl<M, T, C, H> fmt::Debug for Typed<M, T, C, H> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Typed");
        out.field("on_decode_failure", &self.on_decode_failure);
        // Marks the output as partial, since `codec` and `inner` are omitted.
        out.finish_non_exhaustive()
    }
}
impl<M, T, C, H> Handler<M> for Typed<M, T, C, H>
where
    M: IncomingMessage,
    T: DeserializeOwned + Send + Sync,
    C: Codec,
    H: Handler<T>,
{
    /// Decodes the payload of `msg` as `T` and passes the value to the inner handler.
    ///
    /// On success the inner handler's result is returned as-is. If decoding
    /// fails, a warning is logged and the result is chosen by the configured
    /// [`DecodeFailure`] policy; the inner handler is not called.
    async fn handle(&self, msg: &M, ctx: &mut Context<'_>) -> HandlerResult {
        let decoded = match self.codec.decode::<T>(msg.payload()) {
            Ok(decoded) => decoded,
            Err(err) => {
                warn!(
                    target: "ruststream::dispatch",
                    error = %err,
                    "codec decode failed",
                );
                // Map the configured policy onto the matching handler result.
                return match self.on_decode_failure {
                    DecodeFailure::Drop => HandlerResult::drop(),
                    DecodeFailure::Requeue => HandlerResult::retry(),
                };
            }
        };

        self.inner.handle(&decoded, ctx).await
    }
}