use crate::binary::AttachmentsBuilder;
use crate::client::Acknowledge;
use crate::error::{AckError, PayloadError};
use crate::marker::{BinaryMarker, HasBinary, NoBinary};
use crate::packet::Directive;
use crate::packet::DynAck;
use crate::payload::{DeserializePayload, SerializePayload, ack_from_json, ack_to_json};
use pin_project::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::Instant;
/// A type that can be carried as the payload of an acknowledgement.
///
/// The associated [`BinaryMarker`] (`Binary`) selects whether acks of this type travel
/// with binary attachments ([`HasBinary`]) or without ([`NoBinary`]). It fixes the type of
/// [`Ack::attachments`] and decides which [`Acknowledge`] impl applies when sending.
pub trait AckType: Sized {
    /// Marker selecting whether this ack carries binary attachments.
    type Binary: BinaryMarker;
}
/// The unit type is the empty ack: it serializes to an empty payload and carries no
/// binary attachments.
impl AckType for () {
    type Binary = NoBinary;
}
/// A decoded acknowledgement: the deserialized payload together with its attachments.
///
/// The type of `attachments` is chosen by `A`'s [`AckType::Binary`] marker, so binary and
/// non-binary ack types each get the attachment representation their marker defines.
pub struct Ack<A: AckType> {
    /// The decoded ack payload.
    pub payload: A,
    /// Attachments that accompanied the ack, typed by `A::Binary`.
    pub attachments: <A::Binary as BinaryMarker>::Attachments,
}
/// Formats an [`Ack`] as a map with a `"payload"` entry followed by whatever entries the
/// ack type's binary marker emits for its attachments.
impl<A> std::fmt::Debug for Ack<A>
where
    A: AckType + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut entries = f.debug_map();
        entries.entry(&"payload", &self.payload);
        // Attachment rendering differs per marker, so it is delegated to `A::Binary`.
        A::Binary::format(&self.attachments, &mut entries);
        entries.finish()
    }
}
impl<A> TryFrom<DynAck> for Ack<A>
where
A: AckType + DeserializePayload,
{
type Error = AckError;
fn try_from(value: DynAck) -> Result<Self, AckError> {
let payload = ack_from_json(&value.payload)?;
let attachments = A::Binary::parse(value.attachments)?;
Ok(Self {
payload,
attachments,
})
}
}
/// Non-binary ack values are sent directly: the value is serialized to JSON and wrapped
/// in a [`Directive::Ack`] for `id`, with no attachments.
impl<A> Acknowledge<A, NoBinary> for A
where
    A: AckType<Binary = NoBinary> + SerializePayload,
{
    /// # Errors
    ///
    /// Propagates a [`PayloadError`] if serializing the ack payload fails.
    fn into_directive(self, id: u64) -> Result<Directive, PayloadError> {
        let json = ack_to_json(&self)?;
        Ok(Directive::Ack {
            id,
            payload: json.into(),
            attachments: None,
        })
    }
}
/// Binary acks are sent via a closure. The closure receives a fresh
/// [`AttachmentsBuilder`], attaches binary data to it and returns the ack value. It runs
/// exactly once; the value is serialized and the builder's finished attachments are placed
/// in the resulting [`Directive::Ack`].
impl<F, A> Acknowledge<A, HasBinary> for F
where
    F: FnOnce(&mut AttachmentsBuilder) -> A,
    A: AckType<Binary = HasBinary> + SerializePayload,
{
    /// # Errors
    ///
    /// Propagates a [`PayloadError`] if serializing the returned ack value fails.
    fn into_directive(self, id: u64) -> Result<Directive, PayloadError> {
        let mut builder = AttachmentsBuilder::new();
        let json = {
            let ack_value = self(&mut builder);
            ack_to_json(&ack_value)?
        };
        let attachments = builder.finish();
        Ok(Directive::Ack {
            id,
            payload: json.into(),
            attachments: Some(attachments),
        })
    }
}
/// Future that resolves to a typed acknowledgement delivered over a `oneshot` channel.
///
/// Awaiting it yields an [`Ack<A>`] once a [`DynAck`] arrives and decodes as `A`. It
/// yields an [`AckError`] if the sending half is dropped first or if decoding fails. Use
/// [`AckHandle::timeout`] or [`AckHandle::timeout_at`] to bound the wait.
#[pin_project]
#[must_use = "AckHandle must be awaited to receive the ack"]
pub struct AckHandle<A> {
    /// Receiving half of the channel the raw ack is delivered on.
    #[pin]
    rx: oneshot::Receiver<DynAck>,
    // `fn() -> A` rather than `A`: the handle never stores an `A`, it only produces one,
    // so its `Send`/`Sync` must not depend on `A` (variance stays covariant).
    marker: PhantomData<fn() -> A>,
}

// Written by hand instead of derived so that `AckHandle<A>` is `Debug` for every `A`;
// `#[derive(Debug)]` would add an unnecessary `A: Debug` bound.
impl<A> std::fmt::Debug for AckHandle<A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AckHandle")
            .field("rx", &self.rx)
            .finish_non_exhaustive()
    }
}
impl<A> AckHandle<A>
where
    A: AckType,
{
    /// Wraps the receiving half of the channel on which the raw ack will be delivered.
    pub(crate) fn new(rx: oneshot::Receiver<DynAck>) -> Self {
        let marker = PhantomData;
        Self { marker, rx }
    }
}
impl<A> AckHandle<A>
where
    A: AckType + DeserializePayload,
{
    /// Awaits the ack, giving up after `duration`.
    ///
    /// # Errors
    ///
    /// Returns an [`AckError`] converted from tokio's elapsed-timeout error if `duration`
    /// passes first, or any error the handle itself resolves with.
    pub async fn timeout(self, duration: Duration) -> Result<Ack<A>, AckError> {
        tokio::time::timeout(duration, self)
            .await
            .unwrap_or_else(|elapsed| Err(elapsed.into()))
    }

    /// Awaits the ack, giving up at `deadline`.
    ///
    /// # Errors
    ///
    /// Returns an [`AckError`] converted from tokio's elapsed-timeout error if `deadline`
    /// is reached first, or any error the handle itself resolves with.
    pub async fn timeout_at(self, deadline: Instant) -> Result<Ack<A>, AckError> {
        tokio::time::timeout_at(deadline, self)
            .await
            .unwrap_or_else(|elapsed| Err(elapsed.into()))
    }
}
/// Resolves once the raw ack arrives (decoding it into [`Ack<A>`]) or the sender is dropped.
impl<A> Future for AckHandle<A>
where
    A: AckType + DeserializePayload,
{
    type Output = Result<Ack<A>, AckError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.rx.poll(cx) {
            Poll::Pending => Poll::Pending,
            // The sending half was dropped without delivering an ack.
            Poll::Ready(Err(closed)) => Poll::Ready(Err(closed.into())),
            Poll::Ready(Ok(raw)) => Poll::Ready(Ack::try_from(raw)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::marker::{AckMarker, HasAck, HasBinary};
    use bytes::Bytes;
    use bytestring::ByteString;

    /// Binary ack whose payload is a single boolean element.
    #[derive(Debug, PartialEq)]
    struct BinaryBoolAck(bool);
    impl AckType for BinaryBoolAck {
        type Binary = HasBinary;
    }
    impl SerializePayload for BinaryBoolAck {
        fn serialize_payload<S>(&self, seq: &mut S) -> std::result::Result<(), S::Error>
        where
            S: serde::ser::SerializeSeq,
        {
            seq.serialize_element(&self.0)?;
            Ok(())
        }
    }
    impl DeserializePayload for BinaryBoolAck {
        fn deserialize_payload<'de, S>(seq: &mut S) -> std::result::Result<Self, S::Error>
        where
            S: serde::de::SeqAccess<'de>,
        {
            let v = seq
                .next_element()?
                .ok_or_else(|| serde::de::Error::invalid_length(0, &"1 element"))?;
            Ok(Self(v))
        }
    }

    /// Binary ack with an empty payload.
    #[derive(Debug, PartialEq)]
    struct BinaryUnitAck;
    impl AckType for BinaryUnitAck {
        type Binary = HasBinary;
    }
    impl SerializePayload for BinaryUnitAck {
        fn serialize_payload<S>(&self, _seq: &mut S) -> std::result::Result<(), S::Error>
        where
            S: serde::ser::SerializeSeq,
        {
            Ok(())
        }
    }
    impl DeserializePayload for BinaryUnitAck {
        fn deserialize_payload<'de, S>(_seq: &mut S) -> std::result::Result<Self, S::Error>
        where
            S: serde::de::SeqAccess<'de>,
        {
            Ok(Self)
        }
    }

    #[test]
    fn serialize_unit_ack() {
        assert_eq!(ack_to_json(&()).unwrap(), "[]");
    }

    #[test]
    fn deserialize_unit_ack() {
        // Comparing against `()` always passes (clippy `unit_cmp`); what matters is success.
        ack_from_json::<()>("[]").expect("unit ack should deserialize from an empty array");
    }

    #[test]
    fn from_ack_with_binary() {
        let attachment = Bytes::from_static(b"\xDE\xAD");
        let ack = DynAck {
            payload: ByteString::from_static("[true]"),
            attachments: Some(vec![attachment.clone()]),
        };
        let ack: Ack<BinaryBoolAck> = ack.try_into().unwrap();
        assert_eq!(ack.payload, BinaryBoolAck(true));
        assert_eq!(ack.attachments.len(), 1);
        assert_eq!(ack.attachments[0], attachment);
    }

    #[test]
    fn from_ack_missing_binary_fails() {
        let ack = DynAck {
            payload: ByteString::from_static("[]"),
            attachments: None,
        };
        let result: Result<Ack<BinaryUnitAck>, _> = ack.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn from_ack_unexpected_binary_fails() {
        let ack = DynAck {
            payload: ByteString::from_static("[]"),
            attachments: Some(vec![Bytes::from_static(b"x")]),
        };
        let result: Result<Ack<()>, _> = ack.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn send_ack_into_directive_binary() {
        let id = <HasAck<BinaryBoolAck>>::parse(Some(3)).unwrap();
        let directive = Acknowledge::<BinaryBoolAck, HasBinary>::into_directive(
            |builder: &mut AttachmentsBuilder| {
                let _p = builder.attach(Bytes::from_static(b"\xCA\xFE"));
                BinaryBoolAck(true)
            },
            id.get(),
        )
        .unwrap();
        match directive {
            Directive::Ack {
                payload,
                id,
                attachments,
            } => {
                assert_eq!(&payload[..], "[true]");
                assert_eq!(id, 3);
                let att = attachments.expect("expected attachments");
                assert_eq!(att.len(), 1);
                assert_eq!(att[0], Bytes::from_static(b"\xCA\xFE"));
            }
            _ => panic!("expected Ack packet"),
        }
    }

    #[tokio::test]
    async fn ack_handle_closed_channel_errors() {
        let (tx, rx) = oneshot::channel::<DynAck>();
        let handle = AckHandle::<()>::new(rx);
        drop(tx);
        let result: Result<Ack<()>, _> = handle.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn ack_handle_resolves_with_delivered_ack() {
        let (tx, rx) = oneshot::channel::<DynAck>();
        let handle = AckHandle::<()>::new(rx);
        let sent = tx.send(DynAck {
            payload: ByteString::from_static("[]"),
            attachments: None,
        });
        assert!(sent.is_ok());
        let result: Result<Ack<()>, _> = handle.await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn ack_handle_timeout_elapses_without_ack() {
        // Keep the sender alive so the only way for the handle to finish is the timeout.
        let (_tx, rx) = oneshot::channel::<DynAck>();
        let handle = AckHandle::<()>::new(rx);
        let result = handle.timeout(Duration::from_millis(1)).await;
        assert!(result.is_err());
    }
}