jetstream_wireformat/
wire_format_extensions.rs

1// Copyright (c) 2024, Sevki <s@sevki.io>
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::io::{self};
6
7use bytes::Bytes;
8
9use super::WireFormat;
10
/// Trait for types whose encoding and decoding is exposed as futures.
///
/// Both methods return futures that are `Send`, so they can be moved across
/// threads by the caller's executor.
///
/// NOTE(review): the `writer` / `reader` parameters are bounded by
/// `AsyncWireFormat` itself rather than by an async I/O trait. This looks
/// unintended — confirm before implementing or relying on this trait.
pub trait AsyncWireFormat: Sized {
    /// Encodes `self` into `writer`; the future resolves to `Ok(())` on success.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Decodes a new `Self` from `reader`; the future resolves to the decoded value.
    fn decode_async<R>(
        reader: &mut R,
    ) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
20
#[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
pub mod tokio {
    //! Tokio adapters that run the synchronous [`WireFormat`] codec against
    //! tokio `AsyncRead` / `AsyncWrite` streams.
    use std::{future::Future, io};

    use tokio::io::{AsyncRead, AsyncWrite};

    use crate::WireFormat;

    /// Extension trait for asynchronous wire format encoding and decoding.
    ///
    /// Both methods wrap the async stream in `tokio_util::io::SyncIoBridge` and
    /// run the blocking [`WireFormat`] codec inside `tokio::task::block_in_place`.
    /// They therefore need a multi-threaded tokio runtime.
    pub trait AsyncWireFormatExt
    where
        Self: WireFormat + Send,
    {
        /// Encodes the object asynchronously into the provided writer.
        ///
        /// # Arguments
        ///
        /// * `writer` - The writer to encode the object into.
        ///
        /// # Returns
        ///
        /// A `Send` future that resolves to an `io::Result<()>` indicating the success or
        /// failure of the encoding operation. No explicit flush of `writer` is performed.
        ///
        /// # Panics
        ///
        /// Panics if called outside the context of a tokio runtime (from
        /// `SyncIoBridge::new`). The returned future panics if it is polled on a
        /// `current_thread` runtime (from `block_in_place`).
        fn encode_async<W>(
            self,
            writer: W,
        ) -> impl Future<Output = io::Result<()>> + Send
        where
            Self: Sync + Sized,
            W: AsyncWrite + Unpin + Send,
        {
            // Built eagerly: the bridge grabs the current runtime handle here.
            let mut writer = tokio_util::io::SyncIoBridge::new(writer);
            async {
                // `block_in_place` lets the runtime hand this worker's other tasks to
                // another thread while the synchronous encoder blocks.
                tokio::task::block_in_place(move || self.encode(&mut writer))
            }
        }

        /// Decodes an object asynchronously from the provided reader.
        ///
        /// # Arguments
        ///
        /// * `reader` - The reader to decode the object from.
        ///
        /// # Returns
        ///
        /// A `Send` future that resolves to an `io::Result<Self>` indicating the success or
        /// failure of the decoding operation.
        ///
        /// # Panics
        ///
        /// Panics if called outside the context of a tokio runtime (from
        /// `SyncIoBridge::new`). The returned future panics if it is polled on a
        /// `current_thread` runtime (from `block_in_place`).
        fn decode_async<R>(
            reader: R,
        ) -> impl Future<Output = io::Result<Self>> + Send
        where
            Self: Sync + Sized,
            R: AsyncRead + Unpin + Send,
        {
            // Built eagerly: the bridge grabs the current runtime handle here.
            let mut reader = tokio_util::io::SyncIoBridge::new(reader);
            async {
                tokio::task::block_in_place(move || Self::decode(&mut reader))
            }
        }
    }

    /// Implements the `AsyncWireFormatExt` trait for types that implement the `WireFormat` trait and can be sent across threads.
    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
}
81
82/// A trait for converting types to and from a wire format.
83pub trait ConvertWireFormat: WireFormat {
84    /// Converts the type to a byte representation.
85    ///
86    /// # Returns
87    ///
88    /// A `Bytes` object representing the byte representation of the type.
89    fn to_bytes(&self) -> Bytes;
90
91    /// Converts a byte buffer to the type.
92    ///
93    /// # Arguments
94    ///
95    /// * `buf` - A mutable reference to a `Bytes` object containing the byte buffer.
96    ///
97    /// # Returns
98    ///
99    /// A `Result` containing the converted type or an `std::io::Error` if the conversion fails.
100    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>
101    where
102        Self: Sized;
103
104    /// AsRef<[u8]> for the type.
105    ///
106    /// # Returns
107    ///
108    /// A reference to the byte representation of the type.
109    fn as_bytes(&self) -> Vec<u8> {
110        self.to_bytes().to_vec()
111    }
112}
113
114/// Implements the `ConvertWireFormat` trait for types that implement `jetstream_p9::WireFormat`.
115/// This trait provides methods for converting the type to and from bytes.
116impl<T> ConvertWireFormat for T
117where
118    T: WireFormat,
119{
120    /// Converts the type to bytes.
121    /// Returns a `Bytes` object containing the encoded bytes.
122    fn to_bytes(&self) -> Bytes {
123        let mut buf = vec![];
124        let res = self.encode(&mut buf);
125        if let Err(e) = res {
126            panic!("Failed to encode: {}", e);
127        }
128        Bytes::from(buf)
129    }
130
131    /// Converts bytes to the type.
132    /// Returns a `Result` containing the decoded type or an `std::io::Error` if decoding fails.
133    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
134        let buf = buf.to_vec();
135        T::decode(&mut buf.as_slice())
136    }
137}
138
139impl WireFormat for Ipv4Addr {
140    fn byte_size(&self) -> u32 {
141        self.octets().len() as u32
142    }
143
144    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
145        writer.write_all(&self.octets())
146    }
147
148    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
149        let mut buf = [0u8; 4];
150        reader.read_exact(&mut buf)?;
151        Ok(Ipv4Addr::from(buf))
152    }
153}
154
155impl WireFormat for Ipv6Addr {
156    fn byte_size(&self) -> u32 {
157        self.octets().len() as u32
158    }
159
160    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
161        writer.write_all(&self.octets())
162    }
163
164    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
165        let mut buf = [0u8; 16];
166        reader.read_exact(&mut buf)?;
167        Ok(Ipv6Addr::from(buf))
168    }
169}
170
171impl WireFormat for SocketAddrV4 {
172    fn byte_size(&self) -> u32 {
173        self.ip().byte_size() + 2
174    }
175
176    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
177        self.ip().encode(writer)?;
178        self.port().encode(writer)
179    }
180
181    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
182        self::Ipv4Addr::decode(reader).and_then(|ip| {
183            u16::decode(reader).map(|port| SocketAddrV4::new(ip, port))
184        })
185    }
186}
187
188impl WireFormat for SocketAddrV6 {
189    fn byte_size(&self) -> u32 {
190        self.ip().byte_size() + 2
191    }
192
193    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
194        self.ip().encode(writer)?;
195        self.port().encode(writer)
196    }
197
198    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
199        self::Ipv6Addr::decode(reader).and_then(|ip| {
200            u16::decode(reader).map(|port| SocketAddrV6::new(ip, port, 0, 0))
201        })
202    }
203}
204
205impl WireFormat for SocketAddr {
206    fn byte_size(&self) -> u32 {
207        1 + match self {
208            SocketAddr::V4(socket_addr_v4) => socket_addr_v4.byte_size(),
209            SocketAddr::V6(socket_addr_v6) => socket_addr_v6.byte_size(),
210        }
211    }
212
213    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()>
214    where
215        Self: Sized,
216    {
217        match self {
218            SocketAddr::V4(socket_addr_v4) => {
219                writer.write_all(&[0])?;
220                socket_addr_v4.encode(writer)
221            }
222            SocketAddr::V6(socket_addr_v6) => {
223                writer.write_all(&[1])?;
224                socket_addr_v6.encode(writer)
225            }
226        }
227    }
228
229    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self>
230    where
231        Self: Sized,
232    {
233        let mut buf = [0u8; 1];
234        reader.read_exact(&mut buf)?;
235        match buf[0] {
236            0 => Ok(SocketAddr::V4(SocketAddrV4::decode(reader)?)),
237            1 => Ok(SocketAddr::V6(SocketAddrV6::decode(reader)?)),
238            _ => Err(std::io::Error::new(
239                std::io::ErrorKind::InvalidData,
240                "Invalid address type",
241            )),
242        }
243    }
244}