use crate::error::Error;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use tokio::{select, sync::oneshot};
use tokio_tungstenite::tungstenite;
use tracing::warn;
/// A single channel message.
///
/// On the wire it is the positional array `[join_ref, reference, topic, event, payload]`
/// (see the hand-written `Serialize`/`Deserialize` impls for this type).
///
/// Type parameters: `T` is the topic type, `V` the custom event type, `P` the type of a
/// push reply's `response`, and `R` the type of a custom payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Message<T, V, P, R> {
    /// Join reference; `None` is encoded as `null`.
    pub join_ref: Option<u64>,
    /// Message reference (presumably used to pair a reply with its request — confirm);
    /// `None` is encoded as `null`.
    pub reference: Option<u64>,
    /// Topic the message is addressed to.
    pub topic: T,
    /// Either a built-in protocol event or a custom event of type `V`.
    pub event: Event<V>,
    /// Message body; `None` is encoded as `null`.
    pub payload: Option<Payload<P, R>>,
}
/// Body of a [`Message`].
///
/// Encoded untagged. When deserializing, serde tries [`Payload::PushReply`] first (an object
/// with `status` and `response` fields). Anything else that decodes as `R` becomes
/// [`Payload::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Payload<P, R> {
    /// A reply body made of a [`PushStatus`] and a `response` of type `P`.
    PushReply {
        status: PushStatus,
        response: P,
    },
    /// Any other payload, decoded as `R`.
    Custom(R),
}
impl<P, R> Payload<P, R> {
    /// Applies `f` to the `response` of a [`Payload::PushReply`], keeping its `status`.
    ///
    /// A [`Payload::Custom`] value is passed through untouched and `f` is never invoked.
    pub fn map_push_reply<Q, F>(self, f: F) -> Payload<Q, R>
    where
        F: FnOnce(P) -> Q,
    {
        match self {
            Self::Custom(custom) => Payload::Custom(custom),
            Self::PushReply { status, response } => {
                let response = f(response);
                Payload::PushReply { status, response }
            }
        }
    }

    /// Applies `f` to the value inside a [`Payload::Custom`].
    ///
    /// A [`Payload::PushReply`] is passed through untouched and `f` is never invoked.
    pub fn map_custom<S, F>(self, f: F) -> Payload<P, S>
    where
        F: FnOnce(R) -> S,
    {
        match self {
            Self::Custom(custom) => Payload::Custom(f(custom)),
            Self::PushReply { status, response } => Payload::PushReply { status, response },
        }
    }

    /// Fallible counterpart of [`Payload::map_push_reply`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` when mapping a push reply's `response`.
    pub fn try_map_push_reply<Q, F, E>(self, f: F) -> Result<Payload<Q, R>, E>
    where
        F: FnOnce(P) -> Result<Q, E>,
    {
        let mapped = match self {
            Self::Custom(custom) => Payload::Custom(custom),
            Self::PushReply { status, response } => Payload::PushReply {
                status,
                response: f(response)?,
            },
        };
        Ok(mapped)
    }

    /// Fallible counterpart of [`Payload::map_custom`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` when mapping a custom payload.
    pub fn try_map_custom<S, F, E>(self, f: F) -> Result<Payload<P, S>, E>
    where
        F: FnOnce(R) -> Result<S, E>,
    {
        let mapped = match self {
            Self::Custom(custom) => Payload::Custom(f(custom)?),
            Self::PushReply { status, response } => Payload::PushReply { status, response },
        };
        Ok(mapped)
    }
}
/// Status carried by a [`Payload::PushReply`]; (de)serialized as `"ok"` or `"error"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PushStatus {
    /// Encoded as `"ok"`.
    Ok,
    /// Encoded as `"error"`.
    Error,
}
impl Message<String, Value, Value, Value> {
    /// Builds a `heartbeat` protocol message on the `"phoenix"` topic tagged with `reference`.
    ///
    /// The join reference is fixed at `0` and no payload is attached.
    pub fn heartbeat(reference: u64) -> Self {
        let topic = String::from("phoenix");
        Self {
            topic,
            event: Event::Protocol(ProtocolEvent::Heartbeat),
            join_ref: Some(0),
            reference: Some(reference),
            payload: None,
        }
    }
}
impl<T> Message<T, Value, Value, Value> {
    /// Builds a `phx_leave` protocol message for `topic`, tagged with `reference`.
    ///
    /// The join reference is fixed at `0` and no payload is attached.
    pub fn leave(topic: T, reference: u64) -> Self {
        let event = Event::Protocol(ProtocolEvent::Leave);
        Self {
            topic,
            event,
            payload: None,
            join_ref: Some(0),
            reference: Some(reference),
        }
    }
}
impl<T, V, P, R> Message<T, V, P, R>
where
    T: Serialize,
    V: Serialize,
    P: Serialize,
    R: Serialize,
{
    /// Creates a message with both `join_ref` and `reference` populated.
    pub fn new(
        join_ref: u64,
        reference: u64,
        topic: T,
        event: Event<V>,
        payload: Option<Payload<P, R>>,
    ) -> Self {
        let (join_ref, reference) = (Some(join_ref), Some(reference));
        Self {
            join_ref,
            reference,
            topic,
            event,
            payload,
        }
    }

    /// Creates a `phx_join` message for `topic` with join reference `0`.
    ///
    /// A supplied `payload` is wrapped as [`Payload::Custom`]; `None` sends no payload.
    pub(crate) fn join(reference: u64, topic: T, payload: Option<R>) -> Self {
        let join_event = Event::Protocol(ProtocolEvent::Join);
        Self::new(0, reference, topic, join_event, payload.map(Payload::Custom))
    }

    /// Consumes the message and returns its custom payload.
    ///
    /// Returns `None` when there is no payload or when the payload is a push reply.
    pub fn into_custom_payload(self) -> Option<R> {
        if let Some(Payload::Custom(custom)) = self.payload {
            Some(custom)
        } else {
            None
        }
    }
}
impl<T, V, P, R> TryFrom<Message<T, V, P, R>> for tungstenite::Message
where
    T: Serialize,
    V: Serialize,
    P: Serialize,
    R: Serialize,
{
    type Error = serde_json::Error;

    /// Encodes `value` as JSON and wraps the resulting string in a text frame.
    ///
    /// # Errors
    ///
    /// Propagates any [`serde_json::Error`] raised while serializing the message.
    fn try_from(value: Message<T, V, P, R>) -> Result<Self, Self::Error> {
        serde_json::to_string(&value).map(Self::Text)
    }
}
impl<T, V, P, R> TryFrom<tungstenite::Message> for Message<T, V, P, R>
where
T: DeserializeOwned,
V: DeserializeOwned,
P: DeserializeOwned,
R: DeserializeOwned,
{
type Error = serde_json::Error;
fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
match value {
tungstenite::Message::Text(t) => serde_json::from_str(&t),
_ => unreachable!(),
}
}
}
impl<T, V, P, R> Serialize for Message<T, V, P, R>
where
    T: Serialize,
    V: Serialize,
    P: Serialize,
    R: Serialize,
{
    /// Serializes the message as the positional tuple
    /// `(join_ref, reference, topic, event, payload)`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Self {
            join_ref,
            reference,
            topic,
            event,
            payload,
        } = self;
        (join_ref, reference, topic, event, payload).serialize(serializer)
    }
}
impl<'de, T, V, P, R> Deserialize<'de> for Message<T, V, P, R>
where
T: DeserializeOwned,
V: DeserializeOwned,
P: DeserializeOwned,
R: DeserializeOwned,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Deserialize::deserialize(deserializer).map(
|(join_ref, reference, topic, event, payload)| Message {
join_ref,
reference,
topic,
event,
payload,
},
)
}
}
/// Pairs some content with a oneshot sender used to report the outcome of handling it.
///
/// NOTE(review): given the `tungstenite::Error` type, the outcome is presumably the result
/// of writing the content to the websocket — confirm against the socket task.
#[derive(Debug)]
pub(crate) struct WithCallback<T> {
    /// The value being handed off.
    pub(crate) content: T,
    /// Sending half of the completion channel. Dropping it without sending makes the
    /// receiver observe a receive error.
    pub(crate) callback: oneshot::Sender<Result<(), tungstenite::Error>>,
}
impl<T> WithCallback<T> {
    /// Wraps `content` together with a fresh oneshot completion channel.
    ///
    /// Returns the wrapper, which holds the sending half, and the matching receiver.
    pub(crate) fn new(content: T) -> (Self, oneshot::Receiver<Result<(), tungstenite::Error>>) {
        let (callback, receiver) = oneshot::channel();
        let wrapped = Self { content, callback };
        (wrapped, receiver)
    }

    /// Replaces the content with `f(content)` while keeping the same callback sender.
    pub(crate) fn map<U, F>(self, f: F) -> WithCallback<U>
    where
        F: FnOnce(T) -> U,
    {
        let Self { content, callback } = self;
        WithCallback {
            content: f(content),
            callback,
        }
    }

    /// Fallible counterpart of [`WithCallback::map`].
    ///
    /// # Errors
    ///
    /// Returns the error from `f`. In that case the callback sender is dropped.
    pub(crate) fn try_map<U, F, E>(self, f: F) -> Result<WithCallback<U>, E>
    where
        F: FnOnce(T) -> Result<U, E>,
    {
        let Self { content, callback } = self;
        f(content).map(|content| WithCallback { content, callback })
    }
}
/// Waits for a message to be written and then for its reply, all within `timeout`.
///
/// `send_callback` reports whether the frame was written. `receive_callback` delivers the
/// reply (or an error) for this message. A single deadline covers both phases.
///
/// # Errors
///
/// - [`Error::Timeout`] if the deadline elapses before the reply arrives.
/// - [`Error::SocketDropped`] if either callback's sender is dropped without sending.
/// - [`Error::WebSocket`] if writing the frame failed.
/// - Any error delivered through `receive_callback`.
pub(crate) async fn run_message_with_timeout<T, V, P, R>(
    send_callback: oneshot::Receiver<Result<(), tungstenite::Error>>,
    mut receive_callback: oneshot::Receiver<Result<Message<T, V, P, R>, Error>>,
    timeout: Duration,
) -> Result<Message<T, V, P, R>, Error>
where
    T: Serialize,
    V: Serialize,
    P: Serialize,
    R: Serialize,
{
    let timeout = tokio::time::sleep(timeout);
    tokio::pin!(timeout);
    // Phase 1: wait for the send confirmation. The reply channel is raced alongside it,
    // and whatever it yields must be consumed here.
    // - A refutable pattern would silently drop an already-ready reply.
    // - A tokio oneshot receiver panics if polled again after it has completed.
    select! {
        _ = &mut timeout => {
            warn!("timeout on send callback");
            return Err(Error::Timeout);
        }
        reply = &mut receive_callback => {
            return reply.unwrap_or_else(|_| Err(Error::SocketDropped));
        }
        sent = send_callback => match sent {
            Err(_) => return Err(Error::SocketDropped),
            Ok(Err(e)) => return Err(Error::WebSocket(e)),
            Ok(Ok(())) => {}
        },
    }
    // Phase 2: the frame was written; wait for the reply on the remaining deadline.
    select! {
        _ = &mut timeout => Err(Error::Timeout),
        reply = receive_callback => reply.unwrap_or_else(|_| Err(Error::SocketDropped)),
    }
}
pub(crate) async fn run_message<T, V, P, R>(
send_callback: oneshot::Receiver<Result<(), tungstenite::Error>>,
mut receive_callback: oneshot::Receiver<Result<Message<T, V, P, R>, Error>>,
) -> Result<Message<T, V, P, R>, Error>
where
T: Serialize,
V: Serialize,
P: Serialize,
R: Serialize,
{
select! {
Err(_) = &mut receive_callback => Err(Error::SocketDropped),
v = send_callback => {
match v {
Err(_) => Err(Error::SocketDropped),
Ok(Err(e)) => Err(Error::WebSocket(e)),
_ => Ok(())
}
}
}?;
select! {
v = receive_callback => {
match v {
Err(_) => Err(Error::SocketDropped),
Ok(v) => v,
}
}
}
}
/// Event of a [`Message`]: either a built-in [`ProtocolEvent`] or a custom event `T`.
///
/// Encoded untagged. `Protocol` is listed first, so serde tries it first when deserializing.
/// A custom event whose encoding equals a protocol event name (for example the string
/// `"phx_reply"`) therefore decodes as `Protocol`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Event<T> {
    /// One of the built-in protocol events.
    Protocol(ProtocolEvent),
    /// An application-defined event.
    Event(T),
}
impl<T> Event<T> {
    /// Applies `f` to a custom event. Protocol events pass through unchanged.
    pub fn map<U, F>(self, f: F) -> Event<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Self::Event(custom) => Event::Event(f(custom)),
            Self::Protocol(protocol) => Event::Protocol(protocol),
        }
    }

    /// Fallible counterpart of [`Event::map`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` for a custom event.
    pub fn try_map<U, F, E>(self, f: F) -> Result<Event<U>, E>
    where
        F: FnOnce(T) -> Result<U, E>,
    {
        match self {
            Self::Event(custom) => f(custom).map(Event::Event),
            Self::Protocol(protocol) => Ok(Event::Protocol(protocol)),
        }
    }
}
impl<T, E> Event<Result<T, E>> {
    /// Lifts the `Result` inside a custom event to the outside.
    ///
    /// Protocol events always yield `Ok`. A custom event holding `Err(e)` yields `Err(e)`.
    pub fn transpose(self) -> Result<Event<T>, E> {
        match self {
            Self::Event(result) => result.map(Event::Event),
            Self::Protocol(protocol) => Ok(Event::Protocol(protocol)),
        }
    }
}
/// Built-in protocol event names. Each variant is (de)serialized as the string in its
/// `rename` attribute.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ProtocolEvent {
    /// Wire name `"heartbeat"`.
    #[serde(rename = "heartbeat")]
    Heartbeat,
    /// Wire name `"phx_close"`.
    #[serde(rename = "phx_close")]
    Close,
    /// Wire name `"phx_error"`.
    #[serde(rename = "phx_error")]
    Error,
    /// Wire name `"phx_join"`.
    #[serde(rename = "phx_join")]
    Join,
    /// Wire name `"phx_reply"`.
    #[serde(rename = "phx_reply")]
    Reply,
    /// Wire name `"phx_leave"`.
    #[serde(rename = "phx_leave")]
    Leave,
}