tokio_bincode/
lib.rs

//! Tokio codec for use with bincode
//!
//! This crate provides a `bincode` based codec that can be used with
//! tokio's `Framed`, `FramedRead`, and `FramedWrite`.
//!
//! # Example
//!
//! ```
//! # use futures::{Stream, Sink};
//! # use tokio::io::{AsyncRead, AsyncWrite};
//! # use tokio_codec::Framed;
//! # use serde::{Serialize, Deserialize};
//! # use tokio_bincode::BinCodec;
//! # use serde_derive::{Serialize, Deserialize};
//! # fn sd<'a>(transport: impl AsyncRead + AsyncWrite) {
//! #[derive(Serialize, Deserialize)]
//! struct MyProtocol;
//!
//! // Create the codec based on your custom protocol
//! let codec = BinCodec::<MyProtocol>::new();
//!
//! // Frame the transport with the codec to produce a stream/sink
//! let (sink, stream) = Framed::new(transport, codec).split();
//! # }
//! ```
26
27#![deny(missing_docs, missing_debug_implementations)]
28
29use bincode::Config;
30use bytes::{BufMut, BytesMut};
31use serde::{Deserialize, Serialize};
32use std::fmt;
33use std::io::{self, Read};
34use std::marker::PhantomData;
35use tokio_codec::{Decoder, Encoder};
36
37/// Bincode based codec for use with `tokio-codec`
38pub struct BinCodec<T> {
39    config: Config,
40    _pd: PhantomData<T>,
41}
42
43impl<T> BinCodec<T> {
44    /// Provides a bincode based codec
45    pub fn new() -> Self {
46        let config = bincode::config();
47        BinCodec::with_config(config)
48    }
49
50    /// Provides a bincode based codec from the bincode config
51    pub fn with_config(config: Config) -> Self {
52        BinCodec {
53            config,
54            _pd: PhantomData,
55        }
56    }
57}
58
59impl<T> Decoder for BinCodec<T>
60where
61    for<'de> T: Deserialize<'de>,
62{
63    type Item = T;
64    type Error = bincode::Error;
65
66    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
67        if !buf.is_empty() {
68            let mut reader = Reader::new(&buf[..]);
69            let message = self.config.deserialize_from(&mut reader)?;
70            buf.split_to(reader.amount());
71            Ok(Some(message))
72        } else {
73            Ok(None)
74        }
75    }
76}
77
78impl<T> Encoder for BinCodec<T>
79where
80    T: Serialize,
81{
82    type Item = T;
83    type Error = bincode::Error;
84
85    fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
86        let size = self.config.serialized_size(&item)?;
87        buf.reserve(size as usize);
88        let message = self.config.serialize(&item)?;
89        buf.put(&message[..]);
90        Ok(())
91    }
92}
93
94impl<T> fmt::Debug for BinCodec<T> {
95    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96        f.debug_struct("BinCodec").finish()
97    }
98}
99
/// `Read` adapter over a byte slice that counts how many bytes have been
/// consumed.
///
/// `BinCodec::decode` passes a `&mut Reader` to bincode, then uses `amount`
/// to learn how many bytes of its input buffer the decoded value occupied.
#[derive(Debug)]
struct Reader<'buf> {
    /// Bytes not yet read. Reading advances this slice past the bytes
    /// returned.
    buf: &'buf [u8],
    /// Total number of bytes handed out by `read` so far.
    amount: usize,
}

impl<'buf> Reader<'buf> {
    /// Creates a reader positioned at the start of `buf`, with nothing
    /// consumed yet.
    pub fn new(buf: &'buf [u8]) -> Self {
        Reader { buf, amount: 0 }
    }

    /// Returns the number of bytes consumed through `read` so far.
    pub fn amount(&self) -> usize {
        self.amount
    }
}

// Implemented on `Reader` itself, which is the idiomatic form. `&mut Reader`
// still implements `Read` via the standard library's blanket
// `impl<R: Read + ?Sized> Read for &mut R`, so `deserialize_from(&mut reader)`
// keeps working.
impl<'buf> Read for Reader<'buf> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `<&[u8] as Read>::read` copies and advances `self.buf`.
        let bytes_read = self.buf.read(buf)?;
        self.amount += bytes_read;
        Ok(bytes_read)
    }
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{Future, Sink, Stream};
    use serde_derive::{Deserialize, Serialize};
    use std::net::SocketAddr;
    use tokio::{
        codec::Framed,
        net::{TcpListener, TcpStream},
        runtime::current_thread,
    };

    #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
    enum Mock {
        One,
        Two,
    }

    /// Round-trips two messages through a single-connection echo server that
    /// is framed with `BinCodec`.
    #[test]
    fn it_works() {
        // Port 0 lets the OS pick a free port, so a busy fixed port (or a
        // parallel test run) cannot make the bind fail.
        let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), 0);
        let echo = TcpListener::bind(&addr).unwrap();
        let addr = echo.local_addr().unwrap();

        let jh = std::thread::spawn(move || {
            current_thread::run(
                echo.incoming()
                    .map_err(bincode::Error::from)
                    .take(1)
                    .for_each(|stream| {
                        let (w, r) = Framed::new(stream, BinCodec::<Mock>::new()).split();
                        r.forward(w).map(|_| ())
                    })
                    .map_err(|_| ()),
            )
        });

        let client = TcpStream::connect(&addr).wait().unwrap();
        let client = Framed::new(client, BinCodec::<Mock>::new());

        let client = client.send(Mock::One).wait().unwrap();

        let (got, client) = match client.into_future().wait() {
            Ok(x) => x,
            Err((e, _)) => panic!("{}", e),
        };

        assert_eq!(got, Some(Mock::One));

        let client = client.send(Mock::Two).wait().unwrap();

        let (got2, client) = match client.into_future().wait() {
            Ok(x) => x,
            Err((e, _)) => panic!("{}", e),
        };

        assert_eq!(got2, Some(Mock::Two));

        // Closing the client ends the server's single connection so its
        // runtime (and thread) can finish.
        drop(client);
        jh.join().unwrap();
    }
}
184}