Skip to main content

p2panda_net/
cbor.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! Utilities to encode and decode wire protocol messages in [CBOR] format.
//!
//! [CBOR]: https://cbor.io/
6use std::marker::PhantomData;
7
8use futures_util::{Sink, Stream};
9use p2panda_core::cbor::{DecodeError, EncodeError, decode_cbor, encode_cbor};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13use tokio::io::{AsyncRead, AsyncWrite};
14use tokio_util::bytes::{Buf, BytesMut};
15use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
16
/// Implementation of the tokio codec traits to encode and decode CBOR data as a stream.
///
/// CBOR allows message framing based on the initial "header" of each "data item", which indicates
/// the type of data and the expected length of the "body" that follows. A stream-based decoder can
/// attempt parsing these headers and then reason about whether it has enough information to
/// proceed.
///
/// Read more on CBOR in streaming applications here:
/// <https://www.rfc-editor.org/rfc/rfc8949.html#section-5.1>
pub struct CborCodec<T> {
    // `fn() -> T` instead of `T`: the codec never stores a `T`, so its auto traits (`Send`,
    // `Sync`, `Unpin`) should not depend on the message type. Variance stays covariant.
    _phantom: PhantomData<fn() -> T>,
}

// Implemented by hand because `#[derive(Clone)]` would add a needless `T: Clone` bound; the
// codec is stateless and can always be cloned.
impl<T> Clone for CborCodec<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

// Implemented by hand because `#[derive(Debug)]` would add a needless `T: Debug` bound. Prints
// the same shape `#[derive(Debug)]` would produce for a `PhantomData<T>` field.
impl<T> std::fmt::Debug for CborCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CborCodec")
            .field("_phantom", &PhantomData::<T>)
            .finish()
    }
}
29
30impl<M> CborCodec<M> {
31    pub fn new() -> Self {
32        CborCodec {
33            _phantom: PhantomData {},
34        }
35    }
36}
37
38impl<M> Default for CborCodec<M> {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl<T> Encoder<T> for CborCodec<T>
45where
46    T: Serialize,
47{
48    type Error = CborCodecError;
49
50    /// Encodes a serializable item into CBOR bytes and adds them to the buffer.
51    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
52        // NOTE: If we've failed encoding our _own_ messages something seriously went wrong.
53        let bytes = encode_cbor(&item)?;
54        // Append the encoded CBOR bytes to the buffer instead of replacing it, we might already
55        // have previously encoded items in it.
56        dst.extend_from_slice(&bytes);
57        Ok(())
58    }
59}
60
61impl<T> Decoder for CborCodec<T>
62where
63    T: Serialize + DeserializeOwned,
64{
65    type Item = T;
66    type Error = CborCodecError;
67
68    /// CBOR decoder method taking as an argument the bytes that have been read so far; when called,
69    /// it will be in one of the following situations:
70    ///
71    /// 1. The buffer contains less than a full frame.
72    /// 2. The buffer contains exactly a full frame.
73    /// 3. The buffer contains more than a full frame.
74    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
75        // Keep a reference of the buffer to not advance the main buffer itself (yet).
76        let mut bytes: &[u8] = src.as_ref();
77        let starting = bytes.len();
78
79        // Attempt decoding the buffer and remember how many bytes we've advanced it doing that.
80        //
81        // This will succeed in case 2. and 3.
82        let result: Result<Self::Item, _> = decode_cbor(&mut bytes);
83        let ending = bytes.len();
84
85        match result {
86            Ok(item) => {
87                // We've successfully read one full frame from the buffer. We're finally
88                // advancing it for the next decode iteration and yield the resulting data item to
89                // the stream.
90                src.advance(starting - ending);
91                Ok(Some(item))
92            }
93            // Note that the buffer is not further advanced in case of an error.
94            Err(error) => match error {
95                DecodeError::Io(err) => {
96                    if err.kind() == std::io::ErrorKind::UnexpectedEof {
97                        // EOF errors indicate that our buffer doesn't contain enough data to
98                        // decode a whole CBOR frame. We're yielding no data item and re-try
99                        // decoding in the next iteration.
100                        //
101                        // This is handling case 1.
102                        Ok(None)
103                    } else {
104                        // An I/O error during decoding usually indicates something wrong with our
105                        // system (lack of system memory etc.).
106                        Err(CborCodecError::IO(format!(
107                            "CBOR codec failed decoding message due to i/o error, {err}"
108                        )))
109                    }
110                }
111                err => Err(CborCodecError::Decode(err)),
112            },
113        }
114    }
115}
116
117/// Returns a reader for your data type, automatically decoding CBOR byte-streams and handling the
118/// message framing.
119///
120/// This can be used in various sync protocol implementations where we need to receive data via a
121/// wire protocol between two peers.
122///
123/// This is a convenience method if you want to use CBOR encoding and serde to handle your wire
124/// protocol message encoding and framing without implementing it yourself. If you're interested in
125/// your own approach you can either implement your own `FramedRead` or `Sink`.
126pub fn into_cbor_stream<M, T>(
127    rx: T,
128) -> impl Stream<Item = Result<M, CborCodecError>> + Unpin + use<M, T>
129where
130    M: for<'de> Deserialize<'de> + Serialize + 'static,
131    T: AsyncRead + Unpin + 'static,
132{
133    FramedRead::new(rx, CborCodec::<M>::new())
134}
135
136/// Returns a writer for your data type, automatically encoding it as CBOR for a framed
137/// byte-stream.
138///
139/// This can be used in various sync protocol implementations where we need to send data via a wire
140/// protocol between two peers.
141///
142/// This is a convenience method if you want to use CBOR encoding and serde to handle your wire
143/// protocol message encoding and framing without implementing it yourself. If you're interested in
144/// your own approach you can either implement your own `FramedWrite` or `Stream`.
145pub fn into_cbor_sink<M, T>(tx: T) -> impl Sink<M, Error = CborCodecError>
146where
147    M: for<'de> Deserialize<'de> + Serialize + 'static,
148    T: AsyncWrite + Unpin + 'static,
149{
150    FramedWrite::new(tx, CborCodec::<M>::new())
151}
152
153/// Errors which can occur while decoding or encoding streams of cbor bytes.
154#[derive(Debug, Error)]
155pub enum CborCodecError {
156    #[error(transparent)]
157    Decode(#[from] DecodeError),
158
159    #[error(transparent)]
160    Encode(#[from] EncodeError),
161
162    #[error("{0}")]
163    IO(String),
164
165    #[error("{0}")]
166    BrokenPipe(String),
167}
168
169/// Converts critical I/O error (which occurs during codec stream handling) into [`CborCodecError`].
170///
171/// This is usually a critical system failure indicating an implementation bug or lacking resources
172/// on the user's machine.
173///
174/// See `Encoder` or `Decoder` `Error` trait type in tokio's codec for more information:
175/// <https://docs.rs/tokio-util/latest/tokio_util/codec/trait.Decoder.html#associatedtype.Error>
176impl From<std::io::Error> for CborCodecError {
177    fn from(err: std::io::Error) -> Self {
178        match err.kind() {
179            // Broken pipes usually indicate that the remote peer closed the connection
180            // unexpectedly, this is why we're not treating it as a critical error but as
181            // "unexpected behaviour" instead.
182            std::io::ErrorKind::BrokenPipe => Self::BrokenPipe("broken pipe".into()),
183            _ => Self::IO(format!("internal i/o stream error {err}")),
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use futures_util::{FutureExt, SinkExt, StreamExt};
    use p2panda_core::{Body, Hash, Header, PrivateKey};
    use tokio::io::AsyncWriteExt;
    use tokio_util::codec::FramedRead;

    use crate::timestamp::Timestamp;

    use super::{CborCodec, into_cbor_sink, into_cbor_stream};

    #[tokio::test]
    async fn decoding_exactly_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR initial byte 0x65 (decimal 101): major type 3 ("text string", upper three bits
        // 0b011) followed by a body of 5 bytes (lower five bits 0b00101).
        tx.write_all(&[0x65]).await.unwrap();

        // CBOR body, the actual UTF-8 string.
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "hello".to_string());
    }

    #[tokio::test]
    async fn decoding_more_than_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR initial byte 0x65 (decimal 101): text string (major type 3) of length 5.
        tx.write_all(&[0x65]).await.unwrap();

        // CBOR body, the actual UTF-8 string.
        tx.write_all("hello".as_bytes()).await.unwrap();

        // Another data item (frame) for a second message: initial byte 0x69 (decimal 105), a
        // text string (major type 3) of length 9.
        tx.write_all(&[0x69]).await.unwrap();
        tx.write_all("aquariums".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "hello".to_string());

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "aquariums".to_string());
    }

    #[tokio::test]
    async fn decoding_incomplete_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR initial byte 0x65 (decimal 101): text string (major type 3) of length 5.
        tx.write_all(&[0x65]).await.unwrap();

        // Attempt to decode an incomplete CBOR frame, the decoder should not yield anything.
        let message = stream.next().now_or_never();
        assert!(message.is_none());

        // Complete the CBOR data item in the buffer.
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message.unwrap().unwrap(), "hello".to_string());
    }

    #[tokio::test]
    async fn decoding_invalid_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR initial byte 0x01: the unsigned integer 1 (major type 0). It is a complete data
        // item, but not a text string, so decoding it as `String` must yield an error rather
        // than waiting for more data.
        tx.write_all(&[0x01]).await.unwrap();

        let message = stream.next().await;
        assert!(matches!(message, Some(Err(_))));
    }

    #[tokio::test]
    async fn operations_stream() {
        const OPERATIONS: u64 = 100;

        type Payload = (Header<()>, Option<Body>);

        fn create_operation(
            private_key: &PrivateKey,
            body: &[u8],
            seq_num: u64,
            backlink: Option<Hash>,
        ) -> Payload {
            let body = Body::from(body);
            let mut header = Header {
                version: 1,
                public_key: private_key.public_key(),
                signature: None,
                payload_size: body.size(),
                payload_hash: Some(body.hash()),
                timestamp: Timestamp::now().into(),
                seq_num,
                backlink,
                previous: vec![],
                extensions: (),
            };
            header.sign(private_key);
            (header, Some(body))
        }

        let (tx_inner, rx_inner) = tokio::io::duplex(64);

        let mut tx = into_cbor_sink::<Payload, _>(tx_inner);
        let mut rx = into_cbor_stream::<Payload, _>(rx_inner);

        // Create operations, encode them as CBOR and send the bytes to the receiver. Dropping the
        // sink when the task finishes closes the duplex stream, which ends the receiving stream.
        let sender = tokio::task::spawn(async move {
            let private_key = PrivateKey::new();
            let mut backlink = None;

            for seq_num in 0..OPERATIONS {
                let (header, body) =
                    create_operation(&private_key, b"boom boom boom", seq_num, backlink);
                backlink = Some(header.hash());

                tx.send((header, body)).await.unwrap();
            }
        });

        // Decode every header/body tuple and check they arrive complete and in order. `while
        // let` stops as soon as the stream ends (e.g. because the sender task failed), so a
        // failure surfaces as an assertion instead of a hanging test.
        let mut received = 0;
        while let Some(message) = rx.next().await {
            let (header, body) = message.unwrap_or_else(|err| panic!("{err}"));
            assert_eq!(header.seq_num, received);
            assert!(body.is_some());
            received += 1;
        }

        sender.await.expect("sender task panicked");
        assert_eq!(received, OPERATIONS);
    }
}