use std::{any::Any, marker::PhantomData};
use bytes::Bytes;
use crate::{backend::IndexKey, entry::PayloadType};
/// Failure produced while turning a message into an [`EncodedPayload`].
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
    /// The encoder received a message whose concrete type is not the one it was built
    /// for. `expected` holds the `std::any::type_name` of the type it wanted.
    #[error("encoder type mismatch: expected {expected}")]
    TypeMismatch { expected: &'static str },

    /// Free-form encoding failure carrying a human-readable reason.
    /// NOTE(review): not constructed in this module; presumably returned by encoding
    /// closures. Confirm against callers.
    #[error("encode failure: {0}")]
    Serialize(String),
}
/// Result of an [`Encode`] call: the encoded bytes plus their associated metadata.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EncodedPayload {
    /// The encoded message bytes.
    pub payload: Bytes,
    /// Index keys to associate with this payload; may be empty.
    pub index_keys: Vec<IndexKey>,
    /// Optional payload type tag.
    /// Among the constructors in this module, only `with_payload_type` sets it.
    pub payload_type: Option<PayloadType>,
}
impl EncodedPayload {
    /// Builds a payload carrying `index_keys`, with no payload type set.
    #[must_use]
    pub const fn new(payload: Bytes, index_keys: Vec<IndexKey>) -> Self {
        Self::assemble(payload, index_keys, None)
    }

    /// Builds a payload with no index keys and no payload type set.
    #[must_use]
    pub const fn without_indices(payload: Bytes) -> Self {
        Self::new(payload, Vec::new())
    }

    /// Builds a payload carrying `index_keys`, tagged with `payload_type`.
    #[must_use]
    pub const fn with_payload_type(
        payload_type: PayloadType,
        payload: Bytes,
        index_keys: Vec<IndexKey>,
    ) -> Self {
        Self::assemble(payload, index_keys, Some(payload_type))
    }

    /// Single field-wise constructor that every public builder above funnels into.
    const fn assemble(
        payload: Bytes,
        index_keys: Vec<IndexKey>,
        payload_type: Option<PayloadType>,
    ) -> Self {
        Self {
            payload,
            index_keys,
            payload_type,
        }
    }
}
/// Type-erased encoder that turns an arbitrary message into an [`EncodedPayload`].
///
/// Implementors must be `Send + Sync`.
pub trait Encode
where
    Self: Send + Sync,
{
    /// Encodes `message`.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] when `message` cannot be encoded by this
    /// implementation.
    fn encode(&self, message: &dyn Any) -> Result<EncodedPayload, EncodeError>;
}
/// Adapter that exposes a strongly typed encoding closure through the type-erased
/// [`Encode`] trait.
///
/// `F` is called with a `&T` once the incoming message has been downcast. The
/// `'static` requirement that downcasting places on `T` is declared on the `impl`
/// blocks that need it, not on the type definition itself.
pub struct TypedEncoder<T, F> {
    /// Closure invoked with the downcast message.
    func: F,
    // `fn(&T)` records the target type without storing a `T`. Function pointers are
    // always `Send + Sync`, so the encoder's thread-safety depends only on `F`.
    _phantom: PhantomData<fn(&T)>,
}
impl<T: 'static, F> TypedEncoder<T, F>
where
F: Fn(&T) -> Result<EncodedPayload, EncodeError> + Send + Sync,
{
#[must_use]
pub const fn new(func: F) -> Self {
Self {
func,
_phantom: PhantomData,
}
}
}
impl<T, F> std::fmt::Debug for TypedEncoder<T, F>
where
    T: 'static,
{
    /// Renders as `TypedEncoder { type: "<name of T>", .. }`. The wrapped closure is
    /// opaque and is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let target_type = std::any::type_name::<T>();
        f.debug_struct("TypedEncoder")
            .field("type", &target_type)
            .finish_non_exhaustive()
    }
}
impl<T, F> Encode for TypedEncoder<T, F>
where
    T: 'static,
    F: Fn(&T) -> Result<EncodedPayload, EncodeError> + Send + Sync,
{
    /// Downcasts `message` to `T` and hands it to the wrapped closure.
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::TypeMismatch`] naming `T` when `message` is not a `T`.
    /// Otherwise returns whatever the closure returns.
    fn encode(&self, message: &dyn Any) -> Result<EncodedPayload, EncodeError> {
        match message.downcast_ref::<T>() {
            Some(value) => (self.func)(value),
            None => Err(EncodeError::TypeMismatch {
                expected: std::any::type_name::<T>(),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;

    use super::*;
    use crate::backend::IndexKind;

    /// Message type the sample encoder is built for.
    #[derive(Debug)]
    struct Sample(u8);

    /// Unrelated message type, used to trigger a downcast failure.
    #[derive(Debug)]
    struct Other;

    /// Encoder that emits the sample's byte as the payload, plus one `ClientOrderId`
    /// index key of the form `CLI-<n>`.
    fn sample_encoder()
    -> TypedEncoder<Sample, impl Fn(&Sample) -> Result<EncodedPayload, EncodeError> + Send + Sync>
    {
        TypedEncoder::<Sample, _>::new(|s: &Sample| {
            Ok(EncodedPayload::new(
                Bytes::copy_from_slice(&[s.0]),
                vec![IndexKey::new(
                    IndexKind::ClientOrderId,
                    format!("CLI-{}", s.0),
                )],
            ))
        })
    }

    #[rstest]
    fn typed_encoder_encodes_matching_value() {
        let encoder = sample_encoder();
        let encoded = encoder.encode(&Sample(7)).expect("encode");
        assert_eq!(encoded.payload.as_ref(), &[7]);
        assert_eq!(encoded.index_keys.len(), 1);
        assert_eq!(encoded.index_keys[0].kind, IndexKind::ClientOrderId);
        assert_eq!(encoded.index_keys[0].key, "CLI-7");
        // `EncodedPayload::new` must leave the payload type unset.
        assert!(encoded.payload_type.is_none());
    }

    #[rstest]
    fn typed_encoder_rejects_other_type() {
        let encoder = sample_encoder();
        let err = encoder.encode(&Other).expect_err("type mismatch");
        let EncodeError::TypeMismatch { expected } = &err else {
            panic!("expected TypeMismatch, was {err:?}");
        };
        assert!(expected.ends_with("Sample"), "expected was: {expected}");
        // The rendered message should also name the expected type.
        let rendered = err.to_string();
        assert!(
            rendered.starts_with("encoder type mismatch: expected "),
            "message was: {rendered}"
        );
        assert!(rendered.ends_with("Sample"), "message was: {rendered}");
    }

    #[rstest]
    fn serialize_error_renders_its_reason() {
        let err = EncodeError::Serialize("boom".to_owned());
        assert_eq!(err.to_string(), "encode failure: boom");
    }

    #[rstest]
    fn encoded_payload_new_has_no_payload_type() {
        let payload = EncodedPayload::new(Bytes::from_static(b"xyz"), Vec::new());
        assert_eq!(payload.payload.as_ref(), b"xyz");
        assert!(payload.index_keys.is_empty());
        assert!(payload.payload_type.is_none());
    }

    #[rstest]
    fn encoded_payload_without_indices_has_empty_indices() {
        let payload = EncodedPayload::without_indices(Bytes::from_static(b"abc"));
        assert_eq!(payload.payload.as_ref(), b"abc");
        assert!(payload.index_keys.is_empty());
        assert!(payload.payload_type.is_none());
    }
}