Skip to main content

pipenet/
lib.rs

//! A non-blocking TCP stream wrapper.
//!
//! This crate is useful when you want to use the non-blocking mode of a
//! socket without having to depend on async.
//!
//! A [`NonBlockStream`] can be obtained from a [`TcpStream`] with a plain
//! `into()` call.
//!
//! [`NonBlockStream::read`] and [`NonBlockStream::write`] can be called
//! whenever the user has time to check for messages or needs to write.
//! Regardless of the nature of the caller, the actual IO happens in the
//! background, on a separate thread maintained by the [`NonBlockStream`].
14
15mod packs;
16pub use packs::Packs;
17
18mod looper;
19
20mod metrics;
21
22#[cfg(test)]
23mod test;
24
25use std::{
26    io::{ErrorKind, Read, Write},
27    net::{SocketAddr, TcpStream},
28    sync::{
29        Arc, Mutex,
30        mpsc::{Receiver, Sender, TryRecvError, channel},
31    },
32    thread::JoinHandle,
33};
34
35use crate::metrics::Metrics;
36
/// Version compatibility settings: the version stamped on outgoing messages
/// and the range of versions accepted on incoming ones.
///
/// With [`Versions::default`], messages are emitted with version 1 and only
/// version 1 is accepted.
///
/// When built from a single `u16` (via [`From<u16>`]), that value is both the
/// only version emitted and the only version accepted.
///
/// [`Versions::new`] takes three values: the version to emit, followed by the
/// minimum and maximum version accepted on read (both inclusive).
///
/// ```
/// use pipenet::Versions;
///
/// let v: Versions = Versions::default();
/// let v: Versions = 1.into();
/// let v: Versions = Versions::new(3, 1, 3);
/// ```
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Versions {
    /// Version written on every outgoing message.
    cur: u16,
    /// Lowest version accepted on read (inclusive).
    min: u16,
    /// Highest version accepted on read (inclusive).
    max: u16,
}

impl Default for Versions {
    /// Emits version 1 and accepts only version 1.
    fn default() -> Self {
        Self::new(1, 1, 1)
    }
}

impl From<u16> for Versions {
    /// Emits `value` and accepts only `value`.
    fn from(value: u16) -> Self {
        Self::new(value, value, value)
    }
}

impl Versions {
    /// Creates a new version compatibility object.
    ///
    /// - `cur`: the version written on every outgoing message.
    /// - `min`: the minimum version accepted when reading; others are discarded.
    /// - `max`: the maximum version accepted when reading; others are discarded.
    ///
    /// The values are stored as given: no check is made that `min <= max`
    /// or that `cur` lies within `min..=max`.
    pub const fn new(cur: u16, min: u16, max: u16) -> Self {
        Self { cur, min, max }
    }
}
93
94/// A non blocking wrapper for a [`TcpStream`].
95///
96/// Supports [`From<TcpStream>`] so it is built throgh [`Into::into`]. The
97/// original stream will be consumed by this process as this instance will now
98/// own the stream.
99///
100/// [`NonBlockStream`] maintains its own IO thread in the background which will
101/// be terminated once this instance gets dropped, or if the underlying socket
102/// gets closed by returning the original [`std::io::Error`].
103///
104/// Upon any error returned to the caller of [`NonBlockStream::read`] or
105/// [`NonBlockStream::write`], the caller will have to consider the stream to
106/// be broken and it is required to drop this instance: the background thread
107/// will have been terminated at that point and this [`NonBlockStream`] is now
108/// unusable and no other calls to read or write should be made.
109///
110/// Since it is based on [`TcpStream`], it is sequential and can handle only a
111/// single stream. The [`NonBlockStream`] is in a way dual channel, but through
112/// means of interleaving read/write buffering. The buffer is changing and it's
113/// always the size of the next message being written/read.
114///
115/// The IO thread will keep processing the stream in the background, but it
116/// will also sleep (using [`mio::Poll`]) and wake up when either read or write
117/// operations are possible again. Whether that will happen depends on the size
118/// of the internal buffers of the [`TcpStream`] being passed from creation.
119///
120/// The [`TcpStream`] is kept as it is when received in its configuration, with
121/// one exception of making it non blocking. During initialization, a call to
122/// [`TcpStream::set_nonblocking`] is made and if not successful, it will
123/// panic. Make sure to pass in a [`TcpStream`] that is either capable of being
124/// set to non blocking, or better yet, set it before converting it onto a
125/// [`NonBlockStream`].
126///
127/// It is expected that the [`TcpStream`] being passsed on creation is already
128/// in the connected state.
129///
130/// The header is 10 bytes and is sent per every message.
131/// Take that into consideration for how big the message type should be and if
132/// it is advantaging to use this method for transmission.
133///
134/// Reads and writes will ingest or return boxed instances of the message.
135///
136/// In order to write to the stream use the [`NonBlockStream::write`]. This
137/// will add the message to an internal channel (mpsc).
138/// *The call to write does not block.*
139///
140/// To check if there is a message available call [`NonBlockStream::read`].
141/// This will check another internal channel if some message is ready. If none
142/// is, then the call to read will return [`None`].
143/// *The call to read does not block.*
144///
145/// The [`NonBlockStream`] can be cloned and is [`Send`] and [`Sync`] so it can
146/// be used across frameworks that require it.
147///
148/// ```no_run
149/// use std::net::{TcpStream, SocketAddr};
150/// use pipenet::NonBlockStream;
151///
152/// let stream = TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 9999))).unwrap();
153/// let mut nbstream: NonBlockStream = stream.into();
154///
155/// // A simple, one time, echo example
156/// if let Some(msg) = nbstream.read().unwrap() {
157///     nbstream.write(vec![1,2,3]).unwrap();
158/// }
159/// ```
160///
161/// Versioning is supported. Marks the current version, and discards versions
162/// that are outside the min/max version range.
163///
164/// To add more wrappers on the message, such as encryption or compresssion,
165/// use the constructor with the encapsulations.
166///
167/// To use those methods the features "compression" and/or "encryption" will be
168/// required.
169///
170/// ```ignore
171/// use std::net::{TcpStream, SocketAddr};
172/// use pipenet::NonBlockStream;
173/// use pipenet::Versions;
174/// use pipenet::Packs;
175///
176/// let msg = Msg { data:vec![] };
177/// let key = &[0u8; 32];
178/// let stream = TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 9999))).unwrap();
179/// let mut nbstream: NonBlockStream<Msg> = NonBlockStream::from_version_encapsulations(
180///     Versions::new(2, 1, 3), // Current version 2, supports from 1 to 3
181///     Packs::default()
182///         .compress()
183///         .encrypt(key),
184///     stream);
185///
186/// // A simple, one time, echo example
187/// if let Some(msg) = nbstream.read().unwrap() {
188///     nbstream.write(vec![1,2,3]).unwrap();
189/// }
190/// ```
191#[derive(Clone)]
192pub struct NonBlockStream {
193    rx_reader: Arc<Mutex<Receiver<Vec<u8>>>>,
194    tx_writer: Sender<Vec<u8>>,
195    rx_err: Arc<Mutex<Receiver<std::io::Error>>>,
196    local_addr: SocketAddr,
197    remote_addr: SocketAddr,
198    _handle: Arc<JoinHandle<()>>,
199    metrics: Metrics,
200}
201
202impl From<TcpStream> for NonBlockStream {
203    fn from(stream: TcpStream) -> Self {
204        NonBlockStream::from_versions(Versions::default(), stream)
205    }
206}
207
208impl NonBlockStream {
209    pub fn from_version_packs(v: Versions, packs: Packs, stream: TcpStream) -> Self {
210        let local_addr = stream
211            .local_addr()
212            .expect("Could not obtain local_addr from stream");
213        let remote_addr = stream
214            .peer_addr()
215            .expect("Could not obtain peer_addr from stream");
216        stream
217            .set_nonblocking(true)
218            .expect("Could not set socket to nonblocking. It is required for communication.");
219
220        let (tx_reader, rx_reader) = channel::<Vec<u8>>();
221        let (tx_writer, rx_writer) = channel::<Vec<u8>>();
222        let (tx_err, rx_err) = channel::<std::io::Error>();
223
224        let (metrics, metrics_tx) = Metrics::new();
225
226        // The looper consumes the TcpStream.
227        let looper =
228            looper::StreamLooper::new(v, packs, stream, tx_reader, rx_writer, tx_err, metrics_tx);
229        let handle = std::thread::spawn(move || {
230            looper.stream_loop();
231        });
232
233        Self {
234            rx_reader: Arc::new(Mutex::new(rx_reader)),
235            tx_writer,
236            rx_err: Arc::new(Mutex::new(rx_err)),
237            local_addr,
238            remote_addr,
239            _handle: Arc::new(handle),
240            metrics,
241        }
242    }
243
244    pub fn from_versions(v: Versions, stream: TcpStream) -> Self {
245        Self::from_version_packs(v, Default::default(), stream)
246    }
247
248    /// The address of the local tcp stream.
249    pub fn local_addr(&self) -> SocketAddr {
250        self.local_addr
251    }
252
253    /// The address of the remote end of the tcp stream.
254    pub fn remote_addr(&self) -> SocketAddr {
255        self.remote_addr
256    }
257
258    /// Queue a new message for write.
259    pub fn write(&mut self, msg: Vec<u8>) -> Result<(), std::io::Error> {
260        self.trap_fault()?;
261        self.trap_write(msg)
262    }
263
264    /// Check if there is a message available to read and return it.
265    pub fn read(&mut self) -> Result<Option<Vec<u8>>, std::io::Error> {
266        self.trap_fault()?;
267        self.trap_recv()
268    }
269
270    pub fn total_read(&self) -> usize {
271        self.metrics.read()
272    }
273
274    pub fn total_sent(&self) -> usize {
275        self.metrics.sent()
276    }
277
278    fn trap_write(&mut self, msg: Vec<u8>) -> Result<(), std::io::Error> {
279        self.tx_writer
280            .send(msg)
281            .map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, e))
282    }
283
284    fn trap_recv(&mut self) -> Result<Option<Vec<u8>>, std::io::Error> {
285        let op = self.rx_reader.lock().unwrap().try_recv();
286        match op {
287            Ok(msg) => Ok(Some(msg)),
288            Err(e) => match e {
289                TryRecvError::Empty => Ok(None),
290                TryRecvError::Disconnected => {
291                    Err(std::io::Error::new(ErrorKind::ConnectionAborted, e))
292                }
293            },
294        }
295    }
296
297    fn trap_fault(&mut self) -> Result<(), std::io::Error> {
298        let op = self.rx_err.lock().unwrap().try_recv();
299        match op {
300            Ok(f) => Err(f),
301            Err(e) => match e {
302                TryRecvError::Empty => Ok(()),
303                TryRecvError::Disconnected => {
304                    Err(std::io::Error::new(ErrorKind::ConnectionAborted, e))
305                }
306            },
307        }
308    }
309}