atm0s_small_p2p/
stream.rs1use std::{fmt::Debug, marker::PhantomData};
2use std::{
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use anyhow::anyhow;
9use serde::{de::DeserializeOwned, Serialize};
10use tokio_util::codec::LengthDelimitedCodec;
11use tokio_util::codec::{Decoder, Encoder};
12
13use quinn::{RecvStream, SendStream};
14use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
15
16#[derive(Debug)]
17pub struct P2pQuicStream {
18 read: RecvStream,
19 write: SendStream,
20}
21
22impl PartialEq for P2pQuicStream {
23 fn eq(&self, other: &Self) -> bool {
24 self.read.id() == other.read.id() && self.write.id() == other.write.id()
25 }
26}
27
28impl Eq for P2pQuicStream {}
29
30impl P2pQuicStream {
31 pub fn new(read: RecvStream, write: SendStream) -> Self {
32 Self { read, write }
33 }
34}
35
36impl AsyncRead for P2pQuicStream {
37 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<io::Result<()>> {
38 Pin::new(&mut self.get_mut().read).poll_read(cx, buf)
39 }
40}
41
42impl AsyncWrite for P2pQuicStream {
43 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
44 let w: &mut (dyn AsyncWrite + Unpin) = &mut self.get_mut().write;
45 Pin::new(w).poll_write(cx, buf)
46 }
47
48 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
49 Pin::new(&mut self.get_mut().write).poll_flush(cx)
50 }
51
52 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
53 Pin::new(&mut self.get_mut().write).poll_shutdown(cx)
54 }
55}
56
57pub struct BincodeCodec<Item> {
58 length_decode: LengthDelimitedCodec,
59 _tmp: PhantomData<Item>,
60}
61
62impl<Item> Default for BincodeCodec<Item> {
63 fn default() -> Self {
64 Self {
65 length_decode: LengthDelimitedCodec::default(),
66 _tmp: Default::default(),
67 }
68 }
69}
70
71impl<Item: Serialize> Encoder<Item> for BincodeCodec<Item> {
72 type Error = std::io::Error;
73
74 fn encode(&mut self, item: Item, dst: &mut tokio_util::bytes::BytesMut) -> Result<(), Self::Error> {
75 let data: Vec<u8> = bincode::serialize(&item).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "bincode serialize failure"))?;
76 self.length_decode.encode(data.into(), dst)
77 }
78}
79
80impl<Item: DeserializeOwned + Debug> Decoder for BincodeCodec<Item> {
81 type Error = std::io::Error;
82 type Item = Item;
83
84 fn decode(&mut self, src: &mut tokio_util::bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
85 match self.length_decode.decode(src)? {
86 Some(buf) => Ok(Some(
87 bincode::deserialize(&buf).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "bincode deserialize failure"))?,
88 )),
89 None => Ok(None),
90 }
91 }
92}
93
94pub async fn wait_object<R: AsyncRead + Unpin, O: DeserializeOwned, const MAX_SIZE: usize>(reader: &mut R) -> anyhow::Result<O> {
95 let mut len_buf = [0; 2];
96 let mut data_buf = [0; MAX_SIZE];
97 reader.read_exact(&mut len_buf).await?;
98 let handshake_len = u16::from_be_bytes([len_buf[0], len_buf[1]]) as usize;
99 if handshake_len > data_buf.len() {
100 return Err(anyhow!("packet to big {} vs {MAX_SIZE}", data_buf.len()));
101 }
102
103 reader.read_exact(&mut data_buf[0..handshake_len]).await?;
104
105 Ok(bincode::deserialize(&data_buf[0..handshake_len])?)
106}
107
108pub async fn write_object<W: AsyncWrite + Send + Unpin, O: Serialize, const MAX_SIZE: usize>(writer: &mut W, object: &O) -> anyhow::Result<()> {
109 let data_buf: Vec<u8> = bincode::serialize(&object).expect("Should convert to binary");
110 if data_buf.len() > MAX_SIZE {
111 return Err(anyhow!("buffer to big {} vs {MAX_SIZE}", data_buf.len()));
112 }
113 let len_buf = (data_buf.len() as u16).to_be_bytes();
114
115 writer.write_all(&len_buf).await?;
116 writer.write_all(&data_buf).await?;
117 Ok(())
118}