p2panda_sync/
cbor.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Utility methods to encode or decode wire protocol messages in [CBOR] format.
4//!
5//! [CBOR]: https://cbor.io/
6use std::marker::PhantomData;
7
8use futures::{AsyncRead, AsyncWrite, Sink, Stream};
9use p2panda_core::cbor::{DecodeError, decode_cbor, encode_cbor};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use tokio_util::bytes::{Buf, BytesMut};
13use tokio_util::codec::{Decoder, Encoder};
14use tokio_util::codec::{FramedRead, FramedWrite};
15use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
16
17use crate::SyncError;
18
/// Implementation of the tokio codec traits to encode and decode CBOR data as a stream.
///
/// CBOR allows message framing based on initial "headers" for each "data item", which indicate the
/// type of data and the expected "body" length to be followed. A stream-based decoder can attempt
/// parsing these headers and then reason about whether it has enough information to proceed.
///
/// The codec holds no data of its own: the message type `T` is only tracked at the type level.
///
/// Read more on CBOR in streaming applications here:
/// <https://www.rfc-editor.org/rfc/rfc8949.html#section-5.1>
pub struct CborCodec<T> {
    _phantom: PhantomData<T>,
}

// `Clone` and `Debug` are implemented by hand on purpose: `#[derive]` would add `T: Clone` and
// `T: Debug` bounds, even though no value of `T` is ever stored in the codec.
impl<T> Clone for CborCodec<T> {
    fn clone(&self) -> Self {
        CborCodec {
            _phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for CborCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CborCodec")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
31
32impl<M> CborCodec<M> {
33    pub fn new() -> Self {
34        CborCodec {
35            _phantom: PhantomData {},
36        }
37    }
38}
39
40impl<M> Default for CborCodec<M> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl<T> Encoder<T> for CborCodec<T>
47where
48    T: Serialize,
49{
50    type Error = SyncError;
51
52    /// Encodes a serializable item into CBOR bytes and adds them to the buffer.
53    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
54        let bytes = encode_cbor(&item).map_err(|err| {
55            // When we've failed encoding our _own_ messages something seriously went wrong.
56            SyncError::Critical(format!("CBOR codec failed encoding message, {err}"))
57        })?;
58        // Append the encoded CBOR bytes to the buffer instead of replacing it, we might already
59        // have previously encoded items in it.
60        dst.extend_from_slice(&bytes);
61        Ok(())
62    }
63}
64
65impl<T> Decoder for CborCodec<T>
66where
67    T: Serialize + DeserializeOwned,
68{
69    type Item = T;
70    type Error = SyncError;
71
72    /// CBOR decoder method taking as an argument the bytes that have been read so far; when called,
73    /// it will be in one of the following situations:
74    ///
75    /// 1. The buffer contains less than a full frame.
76    /// 2. The buffer contains exactly a full frame.
77    /// 3. The buffer contains more than a full frame.
78    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
79        // Keep a reference of the buffer to not advance the main buffer itself (yet).
80        let mut bytes: &[u8] = src.as_ref();
81        let starting = bytes.len();
82
83        // Attempt decoding the buffer and remember how many bytes we've advanced it doing that.
84        //
85        // This will succeed in case 2. and 3.
86        let result: Result<Self::Item, _> = decode_cbor(&mut bytes);
87        let ending = bytes.len();
88
89        match result {
90            Ok(item) => {
91                // We've successfully read one full frame from the buffer. We're finally
92                // advancing it for the next decode iteration and yield the resulting data item to
93                // the stream.
94                src.advance(starting - ending);
95                Ok(Some(item))
96            }
97            // Note that the buffer is not further advanced in case of an error.
98            Err(ref error) => match error {
99                DecodeError::Io(err) => {
100                    if err.kind() == std::io::ErrorKind::UnexpectedEof {
101                        // EOF errors indicate that our buffer doesn't contain enough data to
102                        // decode a whole CBOR frame. We're yielding no data item and re-try
103                        // decoding in the next iteration.
104                        //
105                        // This is handling case 1.
106                        Ok(None)
107                    } else {
108                        // An I/O error during decoding usually indicates something wrong with our
109                        // system (lack of system memory etc.).
110                        Err(SyncError::Critical(format!(
111                            "CBOR codec failed decoding message due to i/o error, {err}"
112                        )))
113                    }
114                }
115                err => Err(SyncError::InvalidEncoding(err.to_string())),
116            },
117        }
118    }
119}
120
121/// Returns a reader for your data type, automatically decoding CBOR byte-streams and handling the
122/// message framing.
123///
124/// This can be used in various sync protocol implementations where we need to receive data via a
125/// wire protocol between two peers.
126///
127/// This is a convenience method if you want to use CBOR encoding and serde to handle your wire
128/// protocol message encoding and framing without implementing it yourself. If you're interested in
129/// your own approach you can either implement your own `FramedRead` or `Sink`.
130pub fn into_cbor_stream<'a, M>(
131    rx: Box<&'a mut (dyn AsyncRead + Send + Unpin)>,
132) -> impl Stream<Item = Result<M, SyncError>> + Send + Unpin + 'a
133where
134    M: for<'de> Deserialize<'de> + Serialize + Send + 'a,
135{
136    FramedRead::new(rx.compat(), CborCodec::<M>::new())
137}
138
139/// Returns a writer for your data type, automatically encoding it as CBOR for a framed
140/// byte-stream.
141///
142/// This can be used in various sync protocol implementations where we need to send data via a wire
143/// protocol between two peers.
144///
145/// This is a convenience method if you want to use CBOR encoding and serde to handle your wire
146/// protocol message encoding and framing without implementing it yourself. If you're interested in
147/// your own approach you can either implement your own `FramedWrite` or `Stream`.
148pub fn into_cbor_sink<'a, M>(
149    tx: Box<&'a mut (dyn AsyncWrite + Send + Unpin)>,
150) -> impl Sink<M, Error = SyncError> + Send + Unpin + 'a
151where
152    M: for<'de> Deserialize<'de> + Serialize + Send + 'a,
153{
154    FramedWrite::new(tx.compat_write(), CborCodec::<M>::new())
155}
156
#[cfg(test)]
mod tests {
    use futures::{FutureExt, SinkExt};
    use tokio::io::AsyncWriteExt;
    use tokio_stream::StreamExt;
    use tokio_util::codec::{FramedRead, FramedWrite};

    use crate::SyncError;

    use super::CborCodec;

    #[tokio::test]
    async fn decoding_exactly_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR header of a text string with a length of 5 bytes: major type 3 (text string) is
        // stored in the three high bits (0x60), the length in the five low bits (0x05).
        // Hexadecimal representation = 65
        // Decimal representation = 101
        tx.write_all(&[101]).await.unwrap();

        // CBOR body, the actual string.
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("hello".into())));
    }

    #[tokio::test]
    async fn decoding_more_than_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR header of a text string (major type 3 -> 0x60) with a length of 5 bytes.
        // Hexadecimal representation = 65
        // Decimal representation = 101
        tx.write_all(&[101]).await.unwrap();

        // CBOR body, the actual string.
        tx.write_all("hello".as_bytes()).await.unwrap();

        // Another CBOR header (frame) for another text string with a length of 9 bytes.
        // Hexadecimal representation = 69
        // Decimal representation = 105
        tx.write_all(&[105]).await.unwrap();
        tx.write_all("aquariums".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("hello".into())));

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("aquariums".into())));
    }

    #[tokio::test]
    async fn decoding_incomplete_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR header of a text string (major type 3 -> 0x60) with a length of 5 bytes.
        // Hexadecimal representation = 65
        // Decimal representation = 101
        tx.write_all(&[101]).await.unwrap();

        // Attempt to decode an incomplete CBOR frame, the decoder should not yield anything.
        let message = stream.next().now_or_never();
        assert_eq!(message, None);

        // Complete the CBOR data item in the buffer.
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("hello".into())));
    }

    #[tokio::test]
    async fn decoding_invalid_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // A complete CBOR data item (unsigned integer 1, major type 0) which is not a string.
        tx.write_all(&[0x01]).await.unwrap();

        let message = stream.next().await;
        assert!(matches!(message, Some(Err(SyncError::InvalidEncoding(_)))));
    }

    #[tokio::test]
    async fn encoding_and_decoding_roundtrip() {
        let (tx, rx) = tokio::io::duplex(64);
        let mut sink = FramedWrite::new(tx, CborCodec::<String>::new());
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        sink.send("hello".to_string()).await.unwrap();
        sink.send("aquariums".to_string()).await.unwrap();

        assert_eq!(stream.next().await, Some(Ok("hello".into())));
        assert_eq!(stream.next().await, Some(Ok("aquariums".into())));
    }
}