use std::{
any::{Any, TypeId},
collections::HashMap,
fmt::Debug,
marker::PhantomData,
sync::Arc,
};
use crate::{
capture::encoder::{Encode, EncodeError, EncodedPayload, TypedEncoder},
entry::PayloadType,
headers::Headers,
};
/// Derives message [`Headers`] from a type-erased message.
///
/// Implementors receive the message as `&dyn Any` and decide for themselves how (or whether) to
/// inspect it. Extractors must be `Send + Sync`; the registry in this module stores them as
/// `Arc<dyn HeadersExtractor>`.
pub trait HeadersExtractor
where
    Self: Send + Sync,
{
    /// Produces the headers for `message`.
    fn extract(&self, message: &dyn Any) -> Headers;
}
pub struct TypedHeadersExtractor<T: 'static, F> {
func: F,
_phantom: PhantomData<fn(&T)>,
}
impl<T: 'static, F> TypedHeadersExtractor<T, F>
where
F: Fn(&T) -> Headers + Send + Sync,
{
#[must_use]
pub const fn new(func: F) -> Self {
Self {
func,
_phantom: PhantomData,
}
}
}
impl<T: 'static, F> Debug for TypedHeadersExtractor<T, F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(stringify!(TypedHeadersExtractor))
.field("type", &std::any::type_name::<T>())
.finish_non_exhaustive()
}
}
impl<T: 'static, F> HeadersExtractor for TypedHeadersExtractor<T, F>
where
F: Fn(&T) -> Headers + Send + Sync,
{
fn extract(&self, message: &dyn Any) -> Headers {
message
.downcast_ref::<T>()
.map(&self.func)
.unwrap_or_default()
}
}
/// Fallback extractor installed when a type is registered without a headers closure.
///
/// Ignores the message and always reports [`Headers::empty`].
#[derive(Debug, Default)]
struct EmptyHeadersExtractor;

impl HeadersExtractor for EmptyHeadersExtractor {
    fn extract(&self, _message: &dyn Any) -> Headers {
        Headers::empty()
    }
}
/// Everything the registry stores for one concrete message type.
#[derive(Clone)]
struct Registered {
    /// Payload type reported when the encoder does not supply its own.
    payload_type: PayloadType,
    /// Type-erased encoder for the message type.
    encoder: Arc<dyn Encode>,
    /// Type-erased headers extractor for the message type.
    headers: Arc<dyn HeadersExtractor>,
}

impl Debug for Registered {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the payload type is rendered; the encoder and headers extractor are trait
        // objects and are elided.
        let mut out = f.debug_struct(stringify!(Registered));
        out.field("payload_type", &self.payload_type.as_str());
        out.finish_non_exhaustive()
    }
}
/// Registry mapping concrete message types to their encoders and headers extractors.
///
/// Entries are keyed by [`TypeId`], so every lookup dispatches on the exact concrete type of the
/// message. Each entry holds a default [`PayloadType`], an [`Encode`] implementation and a
/// [`HeadersExtractor`].
#[derive(Clone, Debug, Default)]
pub struct EncoderRegistry {
    by_type: HashMap<TypeId, Registered>,
}

impl EncoderRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `func` as the encoder for messages of type `T`.
    ///
    /// `payload_type` is reported for encoded messages unless the encoder sets its own payload
    /// type. An existing registration for `T` is replaced, but its headers extractor is kept. A
    /// type registered for the first time gets an extractor that returns [`Headers::empty`].
    pub fn register<T, F>(&mut self, payload_type: PayloadType, func: F)
    where
        T: 'static,
        F: Fn(&T) -> Result<EncodedPayload, EncodeError> + Send + Sync + 'static,
    {
        let encoder: Arc<dyn Encode> = Arc::new(TypedEncoder::<T, F>::new(func));
        self.register_encoder::<T>(payload_type, encoder);
    }

    /// Registers `func` as the encoder and `headers_fn` as the headers extractor for `T`.
    ///
    /// Unlike [`Self::register`], this replaces both the encoder and the headers extractor of
    /// any existing registration for `T`.
    pub fn register_with_headers<T, F, H>(
        &mut self,
        payload_type: PayloadType,
        func: F,
        headers_fn: H,
    ) where
        T: 'static,
        F: Fn(&T) -> Result<EncodedPayload, EncodeError> + Send + Sync + 'static,
        H: Fn(&T) -> Headers + Send + Sync + 'static,
    {
        let encoder: Arc<dyn Encode> = Arc::new(TypedEncoder::<T, F>::new(func));
        let headers: Arc<dyn HeadersExtractor> =
            Arc::new(TypedHeadersExtractor::<T, H>::new(headers_fn));
        self.insert_entry::<T>(payload_type, encoder, headers);
    }

    /// Registers a pre-built, type-erased `encoder` for messages of type `T`.
    ///
    /// Nothing here checks that `encoder` actually handles `T`; that is the caller's
    /// responsibility. As with [`Self::register`], an existing headers extractor for `T` is
    /// preserved; otherwise an extractor returning [`Headers::empty`] is installed.
    pub fn register_encoder<T: 'static>(
        &mut self,
        payload_type: PayloadType,
        encoder: Arc<dyn Encode>,
    ) {
        let headers = self
            .preserved_headers::<T>()
            .unwrap_or_else(|| Arc::new(EmptyHeadersExtractor) as Arc<dyn HeadersExtractor>);
        self.insert_entry::<T>(payload_type, encoder, headers);
    }

    /// Replaces the headers extractor of an already-registered type `T`.
    ///
    /// This is a silent no-op when `T` has not been registered yet. Register an encoder first,
    /// or use [`Self::register_with_headers`].
    pub fn register_headers<T, H>(&mut self, headers_fn: H)
    where
        T: 'static,
        H: Fn(&T) -> Headers + Send + Sync + 'static,
    {
        if let Some(reg) = self.by_type.get_mut(&TypeId::of::<T>()) {
            reg.headers = Arc::new(TypedHeadersExtractor::<T, H>::new(headers_fn));
        }
    }

    /// Returns the headers extractor currently registered for `T`, if any.
    fn preserved_headers<T: 'static>(&self) -> Option<Arc<dyn HeadersExtractor>> {
        self.by_type
            .get(&TypeId::of::<T>())
            .map(|reg| Arc::clone(&reg.headers))
    }

    /// Stores the entry for `T`, replacing any previous entry wholesale.
    fn insert_entry<T: 'static>(
        &mut self,
        payload_type: PayloadType,
        encoder: Arc<dyn Encode>,
        headers: Arc<dyn HeadersExtractor>,
    ) {
        self.by_type.insert(
            TypeId::of::<T>(),
            Registered {
                payload_type,
                encoder,
                headers,
            },
        );
    }

    /// Returns the number of registered message types.
    #[must_use]
    pub fn len(&self) -> usize {
        self.by_type.len()
    }

    /// Returns `true` if no message types are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.by_type.is_empty()
    }

    /// Returns `true` if an encoder is registered for `T`.
    #[must_use]
    pub fn contains<T: 'static>(&self) -> bool {
        self.by_type.contains_key(&TypeId::of::<T>())
    }

    /// Encodes `message` with the encoder registered for `T`.
    ///
    /// Returns `Ok(None)` if `T` is not registered. Otherwise, the payload type is the one set by
    /// the encoder if present, else the one supplied at registration.
    ///
    /// # Errors
    ///
    /// Propagates any [`EncodeError`] returned by the registered encoder.
    pub fn encode<T: 'static>(
        &self,
        message: &T,
    ) -> Result<Option<(PayloadType, EncodedPayload)>, EncodeError> {
        // `&T` coerces to `&dyn Any`, whose `type_id()` is `TypeId::of::<T>()`, so this performs
        // the same lookup as keying on `T` directly.
        self.encode_any(message)
    }

    /// Encodes a type-erased `message` with the encoder registered for its concrete type.
    ///
    /// Returns `Ok(None)` if the concrete type is not registered. Otherwise, the payload type is
    /// the one set by the encoder if present, else the one supplied at registration.
    ///
    /// # Errors
    ///
    /// Propagates any [`EncodeError`] returned by the registered encoder.
    pub fn encode_any(
        &self,
        message: &dyn Any,
    ) -> Result<Option<(PayloadType, EncodedPayload)>, EncodeError> {
        let Some(reg) = self.by_type.get(&message.type_id()) else {
            return Ok(None);
        };
        let encoded = reg.encoder.encode(message)?;
        // An encoder-supplied payload type takes precedence over the registered default.
        let payload_type = encoded.payload_type.unwrap_or(reg.payload_type);
        Ok(Some((payload_type, encoded)))
    }

    /// Extracts headers for `message` with the extractor registered for its concrete type.
    ///
    /// Returns `None` if the concrete type is not registered. A type registered without a headers
    /// closure yields [`Headers::empty`].
    #[must_use]
    pub fn headers_for_any(&self, message: &dyn Any) -> Option<Headers> {
        self.by_type
            .get(&message.type_id())
            .map(|reg| reg.headers.extract(message))
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    use ustr::Ustr;

    use super::*;

    #[derive(Debug)]
    struct Sample(u8);

    #[derive(Debug)]
    struct Other;

    /// Encodes a `Sample` as a single byte holding its value.
    fn one_byte(s: &Sample) -> Result<EncodedPayload, EncodeError> {
        Ok(EncodedPayload::without_indices(Bytes::copy_from_slice(&[
            s.0,
        ])))
    }

    /// Encodes a `Sample` as its value repeated twice, distinguishable from `one_byte`.
    fn two_bytes(s: &Sample) -> Result<EncodedPayload, EncodeError> {
        Ok(EncodedPayload::without_indices(Bytes::copy_from_slice(&[
            s.0, s.0,
        ])))
    }

    #[rstest]
    fn unknown_type_returns_none() {
        let registry = EncoderRegistry::new();
        assert!(registry.encode(&Sample(1)).expect("encode").is_none());
        assert!(!registry.contains::<Sample>());
    }

    #[rstest]
    fn registered_type_returns_payload_type_and_payload() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Sample"), one_byte);
        let (tag, encoded) = registry.encode(&Sample(9)).expect("encode").expect("hit");
        assert_eq!(tag.as_str(), "Sample");
        assert_eq!(encoded.payload.as_ref(), &[9]);
        assert!(registry.contains::<Sample>());
        assert_eq!(registry.len(), 1);
    }

    #[rstest]
    fn re_registering_replaces_prior_encoder() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Old"), one_byte);
        registry.register::<Sample, _>(Ustr::from("New"), two_bytes);
        let (tag, encoded) = registry.encode(&Sample(3)).expect("encode").expect("hit");
        assert_eq!(tag.as_str(), "New");
        assert_eq!(encoded.payload.as_ref(), &[3, 3]);
        assert_eq!(registry.len(), 1);
    }

    #[rstest]
    fn registry_is_empty_by_default() {
        let registry = EncoderRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains::<Other>());
    }

    #[rstest]
    fn encode_any_dispatches_by_concrete_type_id() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Sample"), one_byte);
        let sample = Sample(5);
        let (tag, encoded) = registry
            .encode_any(&sample as &dyn Any)
            .expect("encode_any")
            .expect("hit");
        assert_eq!(tag.as_str(), "Sample");
        assert_eq!(encoded.payload.as_ref(), &[5]);
    }

    #[rstest]
    fn encode_any_returns_none_for_unregistered_type() {
        let registry = EncoderRegistry::new();
        let unregistered = Other;
        let outcome = registry
            .encode_any(&unregistered as &dyn Any)
            .expect("encode_any");
        assert!(outcome.is_none());
    }

    #[rstest]
    fn encoder_payload_type_override_overrides_registered_tag() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Wrapper"), |s| {
            Ok(EncodedPayload::with_payload_type(
                Ustr::from("Inner"),
                Bytes::copy_from_slice(&[s.0]),
                Vec::new(),
            ))
        });
        let (tag, _) = registry.encode(&Sample(1)).expect("encode").expect("hit");
        assert_eq!(tag.as_str(), "Inner");
        let (any_tag, _) = registry
            .encode_any(&Sample(1) as &dyn Any)
            .expect("encode_any")
            .expect("hit");
        assert_eq!(any_tag.as_str(), "Inner");
    }

    #[rstest]
    fn registered_type_without_headers_extractor_returns_empty_headers() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Sample"), one_byte);
        let headers = registry
            .headers_for_any(&Sample(1) as &dyn Any)
            .expect("hit");
        assert_eq!(headers, Headers::empty());
    }

    #[rstest]
    fn headers_for_any_returns_none_for_unregistered_type() {
        let registry = EncoderRegistry::new();
        let outcome = registry.headers_for_any(&Other as &dyn Any);
        assert!(outcome.is_none());
    }

    #[rstest]
    fn register_with_headers_uses_extractor() {
        let mut registry = EncoderRegistry::new();
        let causation = nautilus_core::UUID4::new();
        let causation_captured = causation;
        registry.register_with_headers::<Sample, _, _>(
            Ustr::from("Sample"),
            one_byte,
            move |_| Headers {
                correlation_id: None,
                causation_id: Some(causation_captured),
            },
        );
        let headers = registry
            .headers_for_any(&Sample(1) as &dyn Any)
            .expect("hit");
        assert_eq!(headers.causation_id, Some(causation));
    }

    #[rstest]
    fn register_with_headers_replaces_existing_headers_extractor() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Sample"), one_byte);
        let correlation = nautilus_core::UUID4::new();
        registry.register_headers::<Sample, _>(move |_| Headers {
            correlation_id: Some(correlation),
            causation_id: None,
        });
        let causation = nautilus_core::UUID4::new();
        let causation_captured = causation;
        registry.register_with_headers::<Sample, _, _>(
            Ustr::from("Sample"),
            one_byte,
            move |_| Headers {
                correlation_id: None,
                causation_id: Some(causation_captured),
            },
        );
        let headers = registry
            .headers_for_any(&Sample(1) as &dyn Any)
            .expect("hit");
        // The earlier extractor (which set correlation_id) must be gone.
        assert_eq!(headers.correlation_id, None);
        assert_eq!(headers.causation_id, Some(causation));
    }

    #[rstest]
    fn register_headers_overrides_default_extractor_post_register() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Sample"), one_byte);
        let correlation = nautilus_core::UUID4::new();
        let correlation_captured = correlation;
        registry.register_headers::<Sample, _>(move |_| Headers {
            correlation_id: Some(correlation_captured),
            causation_id: None,
        });
        let headers = registry
            .headers_for_any(&Sample(1) as &dyn Any)
            .expect("hit");
        assert_eq!(headers.correlation_id, Some(correlation));
    }

    #[rstest]
    fn register_headers_for_unregistered_type_is_silent_noop() {
        let mut registry = EncoderRegistry::new();
        registry.register_headers::<Sample, _>(|_| Headers::empty());
        assert!(!registry.contains::<Sample>());
        assert!(registry.headers_for_any(&Sample(1) as &dyn Any).is_none());
    }

    #[rstest]
    fn re_registering_preserves_existing_headers_extractor() {
        let mut registry = EncoderRegistry::new();
        registry.register::<Sample, _>(Ustr::from("Old"), one_byte);
        let causation = nautilus_core::UUID4::new();
        let causation_captured = causation;
        registry.register_headers::<Sample, _>(move |_| Headers {
            correlation_id: None,
            causation_id: Some(causation_captured),
        });
        registry.register::<Sample, _>(Ustr::from("New"), two_bytes);
        let (tag, _) = registry.encode(&Sample(3)).expect("encode").expect("hit");
        assert_eq!(tag.as_str(), "New");
        let headers = registry
            .headers_for_any(&Sample(3) as &dyn Any)
            .expect("hit");
        assert_eq!(headers.causation_id, Some(causation));
    }

    #[rstest]
    fn register_encoder_preserves_existing_headers_extractor() {
        let mut registry = EncoderRegistry::new();
        let correlation = nautilus_core::UUID4::new();
        let correlation_captured = correlation;
        registry.register_with_headers::<Sample, _, _>(
            Ustr::from("Old"),
            one_byte,
            move |_| Headers {
                correlation_id: Some(correlation_captured),
                causation_id: None,
            },
        );
        let encoder: Arc<dyn Encode> = Arc::new(TypedEncoder::<Sample, _>::new(two_bytes));
        registry.register_encoder::<Sample>(Ustr::from("New"), encoder);
        let (tag, encoded) = registry.encode(&Sample(4)).expect("encode").expect("hit");
        assert_eq!(tag.as_str(), "New");
        assert_eq!(encoded.payload.as_ref(), &[4, 4]);
        let headers = registry
            .headers_for_any(&Sample(4) as &dyn Any)
            .expect("hit");
        assert_eq!(headers.correlation_id, Some(correlation));
        assert_eq!(registry.len(), 1);
    }
}