use std::borrow::Cow;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use serde::de::Visitor;
use serde::ser::SerializeSeq;
use serde::{Deserialize, Serialize};
/// An owned Phoenix channel message.
///
/// Every text field is an owned `String`, and the payload has the generic type `P`.
/// Instances can be produced from the crate-internal `ChannelMsg` via `From`.
/// The type is marked `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Message<P> {
/// Join reference of the message; `None` when the message carries none.
pub join_reference: Option<String>,
/// Reference of this particular message.
pub message_reference: String,
/// Name of the topic the message belongs to.
pub topic_name: String,
/// Name of the event carried by the message.
pub event_name: String,
/// Event payload.
pub payload: P,
}
impl<'a, P> From<ChannelMsg<'a, P>> for Message<P> {
fn from(value: ChannelMsg<'a, P>) -> Self {
Self {
join_reference: value.join_reference.map(Cow::into),
message_reference: value.message_reference.into(),
topic_name: value.topic_name.into(),
event_name: value.event_name.into(),
payload: value.payload,
}
}
}
/// Renders the message as a bracketed, comma-separated list:
/// `[join_reference, message_reference, topic_name, event_name, payload]`.
///
/// Each element is written as JSON when it serializes successfully. Otherwise its
/// `Debug` representation is written instead. A missing join reference renders as
/// JSON `null`.
impl<P> Display for Message<P>
where
    P: Serialize + Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        ser_or_debug(&self.join_reference, f)?;

        // The three plain string fields share one type, so emit them in a loop.
        for text in [&self.message_reference, &self.topic_name, &self.event_name] {
            f.write_str(", ")?;
            ser_or_debug(text, f)?;
        }

        f.write_str(", ")?;
        ser_or_debug(&self.payload, f)?;
        f.write_str("]")
    }
}
/// Writes `v` into `f` as compact JSON.
///
/// If JSON serialization fails, the `Debug` representation of `v` is written instead,
/// so a `Display` impl built on this helper never fails because of serialization.
fn ser_or_debug<T>(v: &T, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
where
    T: Serialize + Debug,
{
    match serde_json::to_string(v) {
        Ok(json) => f.write_str(&json),
        Err(_) => write!(f, "{v:?}"),
    }
}
/// Crate-internal Phoenix channel message whose text fields may either borrow or own
/// their data (`Cow<'a, str>`).
///
/// Its `Serialize` impl writes it as a five-element sequence
/// `[join_reference, message_reference, topic_name, event_name, payload]`.
/// It converts into the public, fully owned `Message` via `From`.
#[derive(Debug)]
pub(crate) struct ChannelMsg<'a, P> {
/// Join reference of the message; `None` when absent (first sequence element).
pub(crate) join_reference: Option<Cow<'a, str>>,
/// Reference of this particular message (second sequence element).
pub(crate) message_reference: Cow<'a, str>,
/// Topic name (third sequence element).
pub(crate) topic_name: Cow<'a, str>,
/// Event name (fourth sequence element).
pub(crate) event_name: Cow<'a, str>,
/// Event payload (fifth sequence element).
pub(crate) payload: P,
}
impl<P> ChannelMsg<'_, P> {
pub(crate) fn into_err(self) -> Message<()> {
Message {
join_reference: self.join_reference.map(Cow::into),
message_reference: self.message_reference.into(),
topic_name: self.topic_name.into(),
event_name: self.event_name.into(),
payload: (),
}
}
}
/// Serializes the message as a five-element sequence:
/// `[join_reference, message_reference, topic_name, event_name, payload]`.
///
/// A missing join reference is serialized as the serializer's representation of `None`.
impl<P> Serialize for ChannelMsg<'_, P>
where
    P: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Number of elements announced to the serializer up front.
        const FIELD_COUNT: usize = 5;

        let mut seq = serializer.serialize_seq(Some(FIELD_COUNT))?;
        seq.serialize_element(&self.join_reference)?;
        // The three text fields share a type, so emit them in order from an array.
        for text in [&self.message_reference, &self.topic_name, &self.event_name] {
            seq.serialize_element(text)?;
        }
        seq.serialize_element(&self.payload)?;
        seq.end()
    }
}
impl<'de, 'a, P> Deserialize<'de> for ChannelMsg<'a, P>
where
P: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
#[derive(Debug)]
struct ChannelMsgVisitor<'a, P> {
_marker: PhantomData<(Cow<'a, str>, P)>,
}
impl<'de, 'a, P> Visitor<'de> for ChannelMsgVisitor<'a, P>
where
P: Deserialize<'de>,
{
type Value = ChannelMsg<'a, P>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
formatter,
"a sequence of 5 elements for a valid Phoenix channel"
)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
if let Some(len) = seq.size_hint() {
if len != 5 {
return Err(A::Error::invalid_length(len, &"5"));
}
}
let Some(join_reference) = seq.next_element()? else {
return Err(A::Error::invalid_length(0, &"5"));
};
let Some(message_reference) = seq.next_element()? else {
return Err(A::Error::invalid_length(1, &"5"));
};
let Some(topic_name) = seq.next_element()? else {
return Err(A::Error::invalid_length(2, &"5"));
};
let Some(event_name) = seq.next_element()? else {
return Err(A::Error::invalid_length(3, &"5"));
};
let Some(payload) = seq.next_element()? else {
return Err(A::Error::invalid_length(4, &"5"));
};
Ok(ChannelMsg::<P> {
join_reference,
message_reference,
topic_name,
event_name,
payload,
})
}
}
deserializer.deserialize_seq(ChannelMsgVisitor::<'a, P> {
_marker: PhantomData,
})
}
}
/// Renders the message as the compact JSON produced by its `Serialize` impl.
///
/// If JSON serialization fails, it falls back to a bracketed list of each field's
/// `Debug` representation, in the same field order.
impl<P> Display for ChannelMsg<'_, P>
where
    P: Serialize + Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match serde_json::to_string(self) {
            Ok(json) => write!(f, "{json}"),
            Err(_) => write!(
                f,
                "[{:?}, {:?}, {:?}, {:?}, {:?}]",
                self.join_reference,
                self.message_reference,
                self.topic_name,
                self.event_name,
                self.payload,
            ),
        }
    }
}