p2panda_sync/cbor.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Utility methods to encode or decode wire protocol messages in [CBOR] format.
4//!
5//! [CBOR]: https://cbor.io/
6use std::marker::PhantomData;
7
8use futures::{AsyncRead, AsyncWrite, Sink, Stream};
9use p2panda_core::cbor::{DecodeError, decode_cbor, encode_cbor};
10use serde::de::DeserializeOwned;
11use serde::{Deserialize, Serialize};
12use tokio_util::bytes::{Buf, BytesMut};
13use tokio_util::codec::{Decoder, Encoder};
14use tokio_util::codec::{FramedRead, FramedWrite};
15use tokio_util::compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt};
16
17use crate::SyncError;
18
/// Implementation of the tokio codec traits to encode and decode CBOR data as a stream.
///
/// CBOR allows message framing based on initial "headers" for each "data item", which indicate the
/// type of data and the expected "body" length to be followed. A stream-based decoder can attempt
/// parsing these headers and then reason about whether it has enough information to proceed.
///
/// Read more on CBOR in streaming applications here:
/// <https://www.rfc-editor.org/rfc/rfc8949.html#section-5.1>
pub struct CborCodec<T> {
    // Type marker only; the codec never stores a value of type `T`.
    _phantom: PhantomData<T>,
}

// `Clone` and `Debug` are implemented by hand: `#[derive]` would add `T: Clone` / `T: Debug`
// bounds, although the codec never holds a `T` and can be cloned / printed for any message type.
impl<T> Clone for CborCodec<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for CborCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Produces the same output the derived implementation did.
        f.debug_struct("CborCodec")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
31
32impl<M> CborCodec<M> {
33 pub fn new() -> Self {
34 CborCodec {
35 _phantom: PhantomData {},
36 }
37 }
38}
39
40impl<M> Default for CborCodec<M> {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl<T> Encoder<T> for CborCodec<T>
47where
48 T: Serialize,
49{
50 type Error = SyncError;
51
52 /// Encodes a serializable item into CBOR bytes and adds them to the buffer.
53 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
54 let bytes = encode_cbor(&item).map_err(|err| {
55 // When we've failed encoding our _own_ messages something seriously went wrong.
56 SyncError::Critical(format!("CBOR codec failed encoding message, {err}"))
57 })?;
58 // Append the encoded CBOR bytes to the buffer instead of replacing it, we might already
59 // have previously encoded items in it.
60 dst.extend_from_slice(&bytes);
61 Ok(())
62 }
63}
64
65impl<T> Decoder for CborCodec<T>
66where
67 T: Serialize + DeserializeOwned,
68{
69 type Item = T;
70 type Error = SyncError;
71
72 /// CBOR decoder method taking as an argument the bytes that have been read so far; when called,
73 /// it will be in one of the following situations:
74 ///
75 /// 1. The buffer contains less than a full frame.
76 /// 2. The buffer contains exactly a full frame.
77 /// 3. The buffer contains more than a full frame.
78 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
79 // Keep a reference of the buffer to not advance the main buffer itself (yet).
80 let mut bytes: &[u8] = src.as_ref();
81 let starting = bytes.len();
82
83 // Attempt decoding the buffer and remember how many bytes we've advanced it doing that.
84 //
85 // This will succeed in case 2. and 3.
86 let result: Result<Self::Item, _> = decode_cbor(&mut bytes);
87 let ending = bytes.len();
88
89 match result {
90 Ok(item) => {
91 // We've successfully read one full frame from the buffer. We're finally
92 // advancing it for the next decode iteration and yield the resulting data item to
93 // the stream.
94 src.advance(starting - ending);
95 Ok(Some(item))
96 }
97 // Note that the buffer is not further advanced in case of an error.
98 Err(ref error) => match error {
99 DecodeError::Io(err) => {
100 if err.kind() == std::io::ErrorKind::UnexpectedEof {
101 // EOF errors indicate that our buffer doesn't contain enough data to
102 // decode a whole CBOR frame. We're yielding no data item and re-try
103 // decoding in the next iteration.
104 //
105 // This is handling case 1.
106 Ok(None)
107 } else {
108 // An I/O error during decoding usually indicates something wrong with our
109 // system (lack of system memory etc.).
110 Err(SyncError::Critical(format!(
111 "CBOR codec failed decoding message due to i/o error, {err}"
112 )))
113 }
114 }
115 err => Err(SyncError::InvalidEncoding(err.to_string())),
116 },
117 }
118 }
119}
120
121/// Returns a reader for your data type, automatically decoding CBOR byte-streams and handling the
122/// message framing.
123///
124/// This can be used in various sync protocol implementations where we need to receive data via a
125/// wire protocol between two peers.
126///
127/// This is a convenience method if you want to use CBOR encoding and serde to handle your wire
128/// protocol message encoding and framing without implementing it yourself. If you're interested in
129/// your own approach you can either implement your own `FramedRead` or `Sink`.
130pub fn into_cbor_stream<'a, M>(
131 rx: Box<&'a mut (dyn AsyncRead + Send + Unpin)>,
132) -> impl Stream<Item = Result<M, SyncError>> + Send + Unpin + 'a
133where
134 M: for<'de> Deserialize<'de> + Serialize + Send + 'a,
135{
136 FramedRead::new(rx.compat(), CborCodec::<M>::new())
137}
138
139/// Returns a writer for your data type, automatically encoding it as CBOR for a framed
140/// byte-stream.
141///
142/// This can be used in various sync protocol implementations where we need to send data via a wire
143/// protocol between two peers.
144///
145/// This is a convenience method if you want to use CBOR encoding and serde to handle your wire
146/// protocol message encoding and framing without implementing it yourself. If you're interested in
147/// your own approach you can either implement your own `FramedWrite` or `Stream`.
148pub fn into_cbor_sink<'a, M>(
149 tx: Box<&'a mut (dyn AsyncWrite + Send + Unpin)>,
150) -> impl Sink<M, Error = SyncError> + Send + Unpin + 'a
151where
152 M: for<'de> Deserialize<'de> + Serialize + Send + 'a,
153{
154 FramedWrite::new(tx.compat_write(), CborCodec::<M>::new())
155}
156
#[cfg(test)]
mod tests {
    use futures::{FutureExt, SinkExt};
    use tokio::io::AsyncWriteExt;
    use tokio_stream::StreamExt;
    use tokio_util::codec::{FramedRead, FramedWrite};

    use crate::SyncError;

    use super::CborCodec;

    #[tokio::test]
    async fn decoding_exactly_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR header indicating that a text string (major type 3) of 5 bytes follows.
        // Hexadecimal representation = 65
        // Decimal representation = 101
        tx.write_all(&[101]).await.unwrap();

        // CBOR body, the actual string.
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("hello".into())));
    }

    #[tokio::test]
    async fn decoding_more_than_one_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR header indicating that a text string (major type 3) of 5 bytes follows.
        // Hexadecimal representation = 65
        // Decimal representation = 101
        tx.write_all(&[101]).await.unwrap();

        // CBOR body, the actual string.
        tx.write_all("hello".as_bytes()).await.unwrap();

        // Another CBOR header (frame) for another message (length of 9).
        // Hexadecimal representation = 69
        // Decimal representation = 105
        tx.write_all(&[105]).await.unwrap();
        tx.write_all("aquariums".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("hello".into())));

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("aquariums".into())));
    }

    #[tokio::test]
    async fn decoding_incomplete_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // CBOR header indicating that a text string (major type 3) of 5 bytes follows.
        // Hexadecimal representation = 65
        // Decimal representation = 101
        tx.write_all(&[101]).await.unwrap();

        // Attempt to decode an incomplete CBOR frame, the decoder should not yield anything.
        let message = stream.next().now_or_never();
        assert_eq!(message, None);

        // Complete the CBOR data item in the buffer.
        tx.write_all("hello".as_bytes()).await.unwrap();

        let message = stream.next().await;
        assert_eq!(message, Some(Ok("hello".into())));
    }

    #[tokio::test]
    async fn decoding_invalid_frame() {
        let (mut tx, rx) = tokio::io::duplex(64);
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // A complete CBOR data item for the unsigned integer 1 (major type 0), which can't be
        // decoded into a string.
        tx.write_all(&[0x01]).await.unwrap();

        let message = stream.next().await;
        assert!(matches!(message, Some(Err(SyncError::InvalidEncoding(_)))));
    }

    #[tokio::test]
    async fn encoding_round_trip() {
        let (tx, rx) = tokio::io::duplex(64);
        let mut sink = FramedWrite::new(tx, CborCodec::<String>::new());
        let mut stream = FramedRead::new(rx, CborCodec::<String>::new());

        // Every encoded message must come out as its own frame on the reading side.
        sink.send("hello".to_string()).await.unwrap();
        sink.send("aquariums".to_string()).await.unwrap();

        assert_eq!(stream.next().await, Some(Ok("hello".into())));
        assert_eq!(stream.next().await, Some(Ok("aquariums".into())));
    }
}