jetstream_wireformat/
wire_format_extensions.rs

1// Copyright (c) 2024, Sevki <s@sevki.io>
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4use std::{
5    io::{self},
6    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
7};
8
9use bytes::Bytes;
10
11use super::WireFormat;
12
/// Trait for values that can be encoded to / decoded from a peer through futures.
///
/// NOTE(review): both `writer` and `reader` are bounded by `AsyncWireFormat`
/// itself rather than by an async I/O trait (e.g. `AsyncWrite` / `AsyncRead`).
/// That looks unintentional, but tightening or replacing the bounds would break
/// every existing implementor, so they are kept exactly as declared.
pub trait AsyncWireFormat: Sized {
    /// Encodes `self` into `writer`; the returned future resolves once encoding is done.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Decodes a new value of this type from `reader`.
    fn decode_async<R>(reader: &mut R) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
22
23#[cfg(not(target_arch = "wasm32"))]
24pub mod tokio {
25    use std::{future::Future, io};
26
27    use tokio::io::{AsyncRead, AsyncWrite};
28
29    use crate::WireFormat;
30    /// Extension trait for asynchronous wire format encoding and decoding.
31    pub trait AsyncWireFormatExt
32    where
33        Self: WireFormat + Send,
34    {
35        /// Encodes the object asynchronously into the provided writer.
36        ///
37        /// # Arguments
38        ///
39        /// * `writer` - The writer to encode the object into.n
40        ///
41        /// # Returns
42        ///
43        /// A future that resolves to an `io::Result<()>` indicating the success or failure of the encoding operation.
44        fn encode_async<W>(self, writer: W) -> impl Future<Output = io::Result<()>>
45        where
46            Self: Sync,
47            W: AsyncWrite + Unpin + Send,
48        {
49            let mut writer = tokio_util::io::SyncIoBridge::new(writer);
50            async { tokio::task::block_in_place(move || self.encode(&mut writer)) }
51        }
52
53        /// Decodes an object asynchronously from the provided reader.
54        ///
55        /// # Arguments
56        ///
57        /// * `reader` - The reader to decode the object from.
58        ///
59        /// # Returns
60        ///
61        /// A future that resolves to an `io::Result<Self>` indicating the success or failure of the decoding operation.
62        fn decode_async<R>(reader: R) -> impl Future<Output = io::Result<Self>> + Send
63        where
64            Self: Sync,
65            R: AsyncRead + Unpin + Send,
66        {
67            let mut reader = tokio_util::io::SyncIoBridge::new(reader);
68            async { tokio::task::block_in_place(move || Self::decode(&mut reader)) }
69        }
70    }
71    /// Implements the `AsyncWireFormatExt` trait for types that implement the `WireFormat` trait and can be sent across threads.
72    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
73}
74
75/// A trait for converting types to and from a wire format.
76pub trait ConvertWireFormat: WireFormat {
77    /// Converts the type to a byte representation.
78    ///
79    /// # Returns
80    ///
81    /// A `Bytes` object representing the byte representation of the type.
82    fn to_bytes(&self) -> Bytes;
83
84    /// Converts a byte buffer to the type.
85    ///
86    /// # Arguments
87    ///
88    /// * `buf` - A mutable reference to a `Bytes` object containing the byte buffer.
89    ///
90    /// # Returns
91    ///
92    /// A `Result` containing the converted type or an `std::io::Error` if the conversion fails.
93    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>;
94
95    /// AsRef<[u8]> for the type.
96    ///
97    /// # Returns
98    ///
99    /// A reference to the byte representation of the type.
100    fn as_bytes(&self) -> Vec<u8> {
101        self.to_bytes().to_vec()
102    }
103}
104
105/// Implements the `ConvertWireFormat` trait for types that implement `jetstream_p9::WireFormat`.
106/// This trait provides methods for converting the type to and from bytes.
107impl<T> ConvertWireFormat for T
108where
109    T: WireFormat,
110{
111    /// Converts the type to bytes.
112    /// Returns a `Bytes` object containing the encoded bytes.
113    fn to_bytes(&self) -> Bytes {
114        let mut buf = vec![];
115        let res = self.encode(&mut buf);
116        if let Err(e) = res {
117            panic!("Failed to encode: {}", e);
118        }
119        Bytes::from(buf)
120    }
121
122    /// Converts bytes to the type.
123    /// Returns a `Result` containing the decoded type or an `std::io::Error` if decoding fails.
124    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
125        let buf = buf.to_vec();
126        T::decode(&mut buf.as_slice())
127    }
128}
129
#[cfg(feature = "std")]
impl WireFormat for Ipv4Addr {
    /// An IPv4 address is encoded as its raw octets, so this is the octet count.
    fn byte_size(&self) -> u32 {
        let octets = self.octets();
        octets.len() as u32
    }

    /// Writes the address octets in the order returned by `octets()`.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly four bytes and builds the address from them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 4];
        reader.read_exact(&mut octets).map(|()| Self::from(octets))
    }
}
146
#[cfg(feature = "std")]
impl WireFormat for Ipv6Addr {
    /// An IPv6 address is encoded as its raw octets, so this is the octet count.
    fn byte_size(&self) -> u32 {
        let octets = self.octets();
        octets.len() as u32
    }

    /// Writes the address octets in the order returned by `octets()`.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly sixteen bytes and builds the address from them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 16];
        reader.read_exact(&mut octets).map(|()| Self::from(octets))
    }
}
163
#[cfg(feature = "std")]
impl WireFormat for SocketAddrV4 {
    /// Encoded IP address size plus two bytes for the port.
    fn byte_size(&self) -> u32 {
        let port_len = 2;
        self.ip().byte_size() + port_len
    }

    /// Writes the IP address followed by the port, each via its own `WireFormat` impl.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let (ip, port) = (self.ip(), self.port());
        ip.encode(writer)?;
        port.encode(writer)
    }

    /// Reads the IP address and then the port, in the same order `encode` writes them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv4Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(Self::new(ip, port))
    }
}
180
#[cfg(feature = "std")]
impl WireFormat for SocketAddrV6 {
    /// Encoded IP address size plus two bytes for the port.
    fn byte_size(&self) -> u32 {
        let port_len = 2;
        self.ip().byte_size() + port_len
    }

    /// Writes the IP address followed by the port.
    ///
    /// NOTE(review): `flowinfo` and `scope_id` are not written, so a link-local
    /// address's scope does not survive a round trip. Adding them would change
    /// the wire format, so the current layout is kept.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let (ip, port) = (self.ip(), self.port());
        ip.encode(writer)?;
        port.encode(writer)
    }

    /// Reads the IP address and then the port; `flowinfo` and `scope_id` are set to 0.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv6Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(Self::new(ip, port, 0, 0))
    }
}