// atm0s_small_p2p/stream.rs

1use std::{fmt::Debug, marker::PhantomData};
2use std::{
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use anyhow::anyhow;
9use serde::{de::DeserializeOwned, Serialize};
10use tokio_util::codec::LengthDelimitedCodec;
11use tokio_util::codec::{Decoder, Encoder};
12
13use quinn::{RecvStream, SendStream};
14use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
15
16#[derive(Debug)]
17pub struct P2pQuicStream {
18    read: RecvStream,
19    write: SendStream,
20}
21
22impl PartialEq for P2pQuicStream {
23    fn eq(&self, other: &Self) -> bool {
24        self.read.id() == other.read.id() && self.write.id() == other.write.id()
25    }
26}
27
28impl Eq for P2pQuicStream {}
29
30impl P2pQuicStream {
31    pub fn new(read: RecvStream, write: SendStream) -> Self {
32        Self { read, write }
33    }
34}
35
36impl AsyncRead for P2pQuicStream {
37    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<io::Result<()>> {
38        Pin::new(&mut self.get_mut().read).poll_read(cx, buf)
39    }
40}
41
42impl AsyncWrite for P2pQuicStream {
43    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
44        let w: &mut (dyn AsyncWrite + Unpin) = &mut self.get_mut().write;
45        Pin::new(w).poll_write(cx, buf)
46    }
47
48    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
49        Pin::new(&mut self.get_mut().write).poll_flush(cx)
50    }
51
52    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
53        Pin::new(&mut self.get_mut().write).poll_shutdown(cx)
54    }
55}
56
57pub struct BincodeCodec<Item> {
58    length_decode: LengthDelimitedCodec,
59    _tmp: PhantomData<Item>,
60}
61
62impl<Item> Default for BincodeCodec<Item> {
63    fn default() -> Self {
64        Self {
65            length_decode: LengthDelimitedCodec::default(),
66            _tmp: Default::default(),
67        }
68    }
69}
70
71impl<Item: Serialize> Encoder<Item> for BincodeCodec<Item> {
72    type Error = std::io::Error;
73
74    fn encode(&mut self, item: Item, dst: &mut tokio_util::bytes::BytesMut) -> Result<(), Self::Error> {
75        let data: Vec<u8> = bincode::serialize(&item).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "bincode serialize failure"))?;
76        self.length_decode.encode(data.into(), dst)
77    }
78}
79
80impl<Item: DeserializeOwned + Debug> Decoder for BincodeCodec<Item> {
81    type Error = std::io::Error;
82    type Item = Item;
83
84    fn decode(&mut self, src: &mut tokio_util::bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
85        match self.length_decode.decode(src)? {
86            Some(buf) => Ok(Some(
87                bincode::deserialize(&buf).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "bincode deserialize failure"))?,
88            )),
89            None => Ok(None),
90        }
91    }
92}
93
94pub async fn wait_object<R: AsyncRead + Unpin, O: DeserializeOwned, const MAX_SIZE: usize>(reader: &mut R) -> anyhow::Result<O> {
95    let mut len_buf = [0; 2];
96    let mut data_buf = [0; MAX_SIZE];
97    reader.read_exact(&mut len_buf).await?;
98    let handshake_len = u16::from_be_bytes([len_buf[0], len_buf[1]]) as usize;
99    if handshake_len > data_buf.len() {
100        return Err(anyhow!("packet to big {} vs {MAX_SIZE}", data_buf.len()));
101    }
102
103    reader.read_exact(&mut data_buf[0..handshake_len]).await?;
104
105    Ok(bincode::deserialize(&data_buf[0..handshake_len])?)
106}
107
108pub async fn write_object<W: AsyncWrite + Send + Unpin, O: Serialize, const MAX_SIZE: usize>(writer: &mut W, object: &O) -> anyhow::Result<()> {
109    let data_buf: Vec<u8> = bincode::serialize(&object).expect("Should convert to binary");
110    if data_buf.len() > MAX_SIZE {
111        return Err(anyhow!("buffer to big {} vs {MAX_SIZE}", data_buf.len()));
112    }
113    let len_buf = (data_buf.len() as u16).to_be_bytes();
114
115    writer.write_all(&len_buf).await?;
116    writer.write_all(&data_buf).await?;
117    Ok(())
118}