jetstream_wireformat/
wire_format_extensions.rs1use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::io::{self};
6
7use bytes::Bytes;
8
9use super::WireFormat;
10
/// Asynchronous encode/decode interface whose methods return `Send` futures.
///
/// `encode_async` consumes `self` and is handed a mutable `writer`;
/// `decode_async` is handed a mutable `reader` and yields a new `Self`.
///
/// NOTE(review): the `W`/`R` type parameters are bounded by
/// `AsyncWireFormat` itself rather than by an async I/O trait (such as
/// `tokio::io::AsyncWrite` / `AsyncRead`). That looks unintended — confirm
/// against the implementors. Changing the bound would break every existing
/// implementation, so it is preserved here.
pub trait AsyncWireFormat: Sized {
    /// Encodes `self` into `writer`.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Decodes a value of `Self` from `reader`.
    fn decode_async<R>(
        reader: &mut R,
    ) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
20
#[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
pub mod tokio {
    //! Tokio adapters that run the synchronous [`WireFormat`] codec over
    //! `tokio::io::AsyncRead` / `AsyncWrite` streams.

    use std::{
        future::Future,
        io::{self, Write},
    };

    // Leading `::` makes it explicit that this is the external `tokio` crate,
    // not this module (which is also named `tokio`).
    use ::tokio::io::{AsyncRead, AsyncWrite};

    use crate::WireFormat;

    /// Extension trait giving every `WireFormat + Send` type async
    /// encode/decode helpers.
    ///
    /// Both helpers wrap the async stream in `tokio_util::io::SyncIoBridge`
    /// (a blocking `Read`/`Write` adapter) and run the synchronous codec inside
    /// `tokio::task::block_in_place`, so the calling worker thread is blocked
    /// for the whole encode/decode.
    ///
    /// # Panics
    ///
    /// Per tokio's documentation, `block_in_place` panics when used from a
    /// `current_thread` runtime, so the returned futures must be polled on a
    /// multi-threaded runtime.
    pub trait AsyncWireFormatExt
    where
        Self: WireFormat + Send,
    {
        /// Encodes `self` into `writer`, then flushes `writer`.
        ///
        /// `writer` is consumed. Flushing before it is dropped ensures that any
        /// data still buffered inside `W` is not silently discarded.
        ///
        /// # Errors
        ///
        /// Returns the first error reported by [`WireFormat::encode`] or by the
        /// final flush.
        fn encode_async<W>(self, writer: W) -> impl Future<Output = io::Result<()>> + Send
        where
            Self: Sync + Sized,
            W: AsyncWrite + Unpin + Send,
        {
            async move {
                // Created on first poll (inside the runtime) rather than when
                // the future is constructed; `SyncIoBridge::new` captures the
                // current runtime handle.
                let mut writer = tokio_util::io::SyncIoBridge::new(writer);
                ::tokio::task::block_in_place(move || {
                    self.encode(&mut writer)?;
                    writer.flush()
                })
            }
        }

        /// Decodes a value of `Self` from `reader`.
        ///
        /// `reader` is consumed; bytes past the decoded value that the bridge
        /// has not read remain in the underlying stream.
        ///
        /// # Errors
        ///
        /// Returns any error reported by [`WireFormat::decode`].
        fn decode_async<R>(reader: R) -> impl Future<Output = io::Result<Self>> + Send
        where
            Self: Sync + Sized,
            R: AsyncRead + Unpin + Send,
        {
            async move {
                // See `encode_async` for why the bridge is created here.
                let mut reader = tokio_util::io::SyncIoBridge::new(reader);
                ::tokio::task::block_in_place(move || Self::decode(&mut reader))
            }
        }
    }

    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
}
81
82pub trait ConvertWireFormat: WireFormat {
84 fn to_bytes(&self) -> Bytes;
90
91 fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>
101 where
102 Self: Sized;
103
104 fn as_bytes(&self) -> Vec<u8> {
110 self.to_bytes().to_vec()
111 }
112}
113
114impl<T> ConvertWireFormat for T
117where
118 T: WireFormat,
119{
120 fn to_bytes(&self) -> Bytes {
123 let mut buf = vec![];
124 let res = self.encode(&mut buf);
125 if let Err(e) = res {
126 panic!("Failed to encode: {}", e);
127 }
128 Bytes::from(buf)
129 }
130
131 fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
134 let buf = buf.to_vec();
135 T::decode(&mut buf.as_slice())
136 }
137}
138
139impl WireFormat for Ipv4Addr {
140 fn byte_size(&self) -> u32 {
141 self.octets().len() as u32
142 }
143
144 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
145 writer.write_all(&self.octets())
146 }
147
148 fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
149 let mut buf = [0u8; 4];
150 reader.read_exact(&mut buf)?;
151 Ok(Ipv4Addr::from(buf))
152 }
153}
154
155impl WireFormat for Ipv6Addr {
156 fn byte_size(&self) -> u32 {
157 self.octets().len() as u32
158 }
159
160 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
161 writer.write_all(&self.octets())
162 }
163
164 fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
165 let mut buf = [0u8; 16];
166 reader.read_exact(&mut buf)?;
167 Ok(Ipv6Addr::from(buf))
168 }
169}
170
171impl WireFormat for SocketAddrV4 {
172 fn byte_size(&self) -> u32 {
173 self.ip().byte_size() + 2
174 }
175
176 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
177 self.ip().encode(writer)?;
178 self.port().encode(writer)
179 }
180
181 fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
182 self::Ipv4Addr::decode(reader).and_then(|ip| {
183 u16::decode(reader).map(|port| SocketAddrV4::new(ip, port))
184 })
185 }
186}
187
188impl WireFormat for SocketAddrV6 {
189 fn byte_size(&self) -> u32 {
190 self.ip().byte_size() + 2
191 }
192
193 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
194 self.ip().encode(writer)?;
195 self.port().encode(writer)
196 }
197
198 fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
199 self::Ipv6Addr::decode(reader).and_then(|ip| {
200 u16::decode(reader).map(|port| SocketAddrV6::new(ip, port, 0, 0))
201 })
202 }
203}
204
205impl WireFormat for SocketAddr {
206 fn byte_size(&self) -> u32 {
207 1 + match self {
208 SocketAddr::V4(socket_addr_v4) => socket_addr_v4.byte_size(),
209 SocketAddr::V6(socket_addr_v6) => socket_addr_v6.byte_size(),
210 }
211 }
212
213 fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()>
214 where
215 Self: Sized,
216 {
217 match self {
218 SocketAddr::V4(socket_addr_v4) => {
219 writer.write_all(&[0])?;
220 socket_addr_v4.encode(writer)
221 }
222 SocketAddr::V6(socket_addr_v6) => {
223 writer.write_all(&[1])?;
224 socket_addr_v6.encode(writer)
225 }
226 }
227 }
228
229 fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self>
230 where
231 Self: Sized,
232 {
233 let mut buf = [0u8; 1];
234 reader.read_exact(&mut buf)?;
235 match buf[0] {
236 0 => Ok(SocketAddr::V4(SocketAddrV4::decode(reader)?)),
237 1 => Ok(SocketAddr::V6(SocketAddrV6::decode(reader)?)),
238 _ => Err(std::io::Error::new(
239 std::io::ErrorKind::InvalidData,
240 "Invalid address type",
241 )),
242 }
243 }
244}