use std::marker::PhantomData;
use futures_util::{Sink, Stream};
use p2panda_core::cbor::{DecodeError, EncodeError, decode_cbor, encode_cbor};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::bytes::{Buf, BytesMut};
use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
/// Codec which encodes and decodes values of type `T` as CBOR.
///
/// The codec holds no data of its own; `T` is only tracked at the type level through
/// [`PhantomData`].
///
/// `Clone` and `Debug` are implemented by hand rather than derived: the derives would require
/// `T: Clone` and `T: Debug` even though no value of `T` is ever stored in the codec.
pub struct CborCodec<T> {
    _phantom: PhantomData<T>,
}

impl<T> Clone for CborCodec<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for CborCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output: `CborCodec { _phantom: PhantomData<..> }`.
        f.debug_struct("CborCodec")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
impl<M> CborCodec<M> {
pub fn new() -> Self {
CborCodec {
_phantom: PhantomData {},
}
}
}
impl<M> Default for CborCodec<M> {
fn default() -> Self {
Self::new()
}
}
impl<T> Encoder<T> for CborCodec<T>
where
    T: Serialize,
{
    type Error = CborCodecError;

    /// Serializes `item` to CBOR and appends the resulting bytes to `dst`.
    ///
    /// # Errors
    ///
    /// Returns the error from `encode_cbor`, converted into [`CborCodecError`]; `dst` is not
    /// written to in that case.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
        match encode_cbor(&item) {
            Ok(encoded) => {
                dst.extend_from_slice(&encoded);
                Ok(())
            }
            Err(err) => Err(err.into()),
        }
    }
}
impl<T> Decoder for CborCodec<T>
where
T: Serialize + DeserializeOwned,
{
type Item = T;
type Error = CborCodecError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut bytes: &[u8] = src.as_ref();
let starting = bytes.len();
let result: Result<Self::Item, _> = decode_cbor(&mut bytes);
let ending = bytes.len();
match result {
Ok(item) => {
src.advance(starting - ending);
Ok(Some(item))
}
Err(error) => match error {
DecodeError::Io(err) => {
if err.kind() == std::io::ErrorKind::UnexpectedEof {
Ok(None)
} else {
Err(CborCodecError::IO(format!(
"CBOR codec failed decoding message due to i/o error, {err}"
)))
}
}
err => Err(CborCodecError::Decode(err)),
},
}
}
}
pub fn into_cbor_stream<M, T>(
rx: T,
) -> impl Stream<Item = Result<M, CborCodecError>> + Unpin + use<M, T>
where
M: for<'de> Deserialize<'de> + Serialize + 'static,
T: AsyncRead + Unpin + 'static,
{
FramedRead::new(rx, CborCodec::<M>::new())
}
pub fn into_cbor_sink<M, T>(tx: T) -> impl Sink<M, Error = CborCodecError>
where
M: for<'de> Deserialize<'de> + Serialize + 'static,
T: AsyncWrite + Unpin + 'static,
{
FramedWrite::new(tx, CborCodec::<M>::new())
}
/// Errors returned by the CBOR codec and by the streams and sinks built on top of it.
#[derive(Debug, Error)]
pub enum CborCodecError {
/// A CBOR decoding error, converted from [`DecodeError`].
#[error(transparent)]
Decode(#[from] DecodeError),
/// A CBOR encoding error, converted from [`EncodeError`].
#[error(transparent)]
Encode(#[from] EncodeError),
/// An i/o error, carried as a formatted description of the underlying failure.
#[error("{0}")]
IO(String),
/// An i/o error whose kind was `std::io::ErrorKind::BrokenPipe`.
#[error("{0}")]
BrokenPipe(String),
}
impl From<std::io::Error> for CborCodecError {
    /// Converts an i/o error into a codec error.
    ///
    /// Errors of kind `BrokenPipe` become [`CborCodecError::BrokenPipe`]; every other kind is
    /// turned into [`CborCodecError::IO`] with the original error formatted into the message.
    fn from(err: std::io::Error) -> Self {
        if err.kind() == std::io::ErrorKind::BrokenPipe {
            return Self::BrokenPipe("broken pipe".into());
        }

        Self::IO(format!("internal i/o stream error {err}"))
    }
}
#[cfg(test)]
mod tests {
    use futures_util::{FutureExt, SinkExt, StreamExt};
    use p2panda_core::{Body, Hash, Header, SigningKey, Timestamp};
    use tokio::io::AsyncWriteExt;
    use tokio_util::codec::FramedRead;

    use super::{CborCodec, into_cbor_sink, into_cbor_stream};

    /// A single, completely written CBOR text string is decoded as one message.
    #[tokio::test]
    async fn decoding_exactly_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // 101 (0x65) is the CBOR header of a text string with 5 bytes.
        tx.write_all(&[101]).await.unwrap();
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "hello".to_string());
    }

    /// Two CBOR values written back-to-back are decoded as two separate messages.
    #[tokio::test]
    async fn decoding_more_than_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        tx.write_all(&[101]).await.unwrap();
        tx.write_all("hello".as_bytes()).await.unwrap();

        // 105 (0x69) is the CBOR header of a text string with 9 bytes.
        tx.write_all(&[105]).await.unwrap();
        tx.write_all("aquariums".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "hello".to_string());

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "aquariums".to_string());
    }

    /// A value whose bytes arrive in two writes is only yielded once it is complete.
    #[tokio::test]
    async fn decoding_incomplete_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // Only the header has been written: the stream must not yield anything yet.
        tx.write_all(&[101]).await.unwrap();
        let message = stream.next().now_or_never();
        assert!(message.is_none());

        tx.write_all("hello".as_bytes()).await.unwrap();
        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "hello".to_string());
    }

    /// Sends a chain of signed operations through the sink and expects every one of them to
    /// arrive on the stream without error.
    #[tokio::test]
    async fn operations_stream() {
        type Payload = (Header<()>, Option<Body>);

        const NUM_OPERATIONS: u64 = 100;

        fn create_operation(
            signing_key: &SigningKey,
            body: &[u8],
            seq_num: u64,
            backlink: Option<Hash>,
        ) -> Payload {
            let body = Body::from(body);
            let mut header = Header {
                version: 1,
                verifying_key: signing_key.verifying_key(),
                signature: None,
                payload_size: body.size(),
                payload_hash: Some(body.hash()),
                timestamp: Timestamp::now().into(),
                seq_num,
                backlink,
                extensions: (),
            };
            header.sign(signing_key);
            (header, Some(body))
        }

        let (tx_inner, rx_inner) = tokio::io::duplex(64);
        let mut tx = into_cbor_sink::<Payload, _>(tx_inner);
        let mut rx = into_cbor_stream::<Payload, _>(rx_inner);

        let sender = tokio::task::spawn(async move {
            let signing_key = SigningKey::generate();
            let mut backlink = None;
            for seq_num in 0..NUM_OPERATIONS {
                let (header, body) =
                    create_operation(&signing_key, b"boom boom boom", seq_num, backlink);
                backlink = Some(header.hash());
                tx.send((header, body)).await.unwrap();
            }
        });

        // Count from zero so that all sent operations are checked, and fail loudly if the
        // stream ends early instead of polling a finished stream forever.
        let mut received = 0;
        while received < NUM_OPERATIONS {
            match rx.next().await {
                Some(Ok(_)) => received += 1,
                Some(Err(err)) => panic!("{err}"),
                None => panic!(
                    "stream ended after {} of {} operations",
                    received, NUM_OPERATIONS
                ),
            }
        }

        // Surface a panic inside the sender task as a test failure.
        sender.await.expect("sender task panicked");
    }
}