pipenet 0.1.5

Non blocking tcp stream wrapper using channels
Documentation
//! A non-blocking TCP stream wrapper.
//!
//! This module is useful when you want to use the non-blocking mode of a
//! socket without having to depend on async.
//!
//! A [`NonBlockStream`] can be obtained from a [`TcpStream`] with a simple
//! `into()` call.
//!
//! [`NonBlockStream::read`] and [`NonBlockStream::write`] are called whenever
//! the user has time to check for messages or needs to write. Regardless of
//! the nature of the caller, the actual IO happens in the background, in a
//! separate thread maintained by the [`NonBlockStream`] struct.

mod packs;
pub use packs::Packs;

mod looper;

#[cfg(test)]
mod test;

use std::{
    io::{ErrorKind, Read, Write},
    net::{SocketAddr, TcpStream},
    sync::{
        Arc, Mutex,
        mpsc::{Receiver, Sender, TryRecvError, channel},
    },
    thread::JoinHandle,
};

use serde::{Deserialize, Serialize, de::DeserializeOwned};

/// The message type carried by a [`NonBlockStream`].
///
/// A message must be [`Serialize`] and [`Deserialize`] so it can cross the
/// wire, and [`Send`] + [`Sync`] + `'static` so it can be handed to and from
/// the background processing thread.
///
/// This trait is implemented automatically for every type meeting those bounds.
pub trait Message: Serialize + DeserializeOwned + Send + Sync + 'static {}

impl<T> Message for T where T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {}

/// Version configuration for the messages emitted and accepted by a stream.
///
/// With the default configuration, emitted messages carry version 1 and only
/// version 1 is accepted.
///
/// When converted from a single `u16`, that value is the only version emitted
/// and accepted.
///
/// When built with [`Versions::new`], the first argument is the version being
/// emitted, and the other two are the minimum and maximum versions accepted on
/// read (both inclusive).
///
/// `Versions` is `Eq` + `Hash`, so it can be used as a key in hashed
/// collections.
///
/// ```
/// use pipenet::Versions;
///
/// let v: Versions = Versions::default();
/// let v: Versions = 1.into();
/// let v: Versions = Versions::new(3, 1, 3);
/// ```
// `Eq` is derived alongside `Hash`: without it the type cannot be used as a
// `HashMap`/`HashSet` key, which makes deriving `Hash` pointless.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Versions {
    cur: u16,
    min: u16,
    max: u16,
}

impl Default for Versions {
    /// Emit version 1 and accept only version 1.
    fn default() -> Self {
        Self {
            cur: 1,
            min: 1,
            max: 1,
        }
    }
}

impl From<u16> for Versions {
    /// Emit `value` and accept only `value`.
    fn from(value: u16) -> Self {
        Self {
            cur: value,
            min: value,
            max: value,
        }
    }
}

impl Versions {
    /// Creates a new version compatibility object.
    /// - `cur`: the current version that will be written on a message.
    /// - `min`: the minimum version accepted when reading, discard otherwise.
    /// - `max`: the maximum version accepted when reading, discard otherwise.
    pub const fn new(cur: u16, min: u16, max: u16) -> Self {
        Self { cur, min, max }
    }
}

/// A non blocking wrapper for a [`TcpStream`].
///
/// Supports [`From<TcpStream>`] so it is built throgh [`Into::into`]. The
/// original stream will be consumed by this process as this instance will now
/// own the stream.
///
/// [`NonBlockStream`] maintains its own IO thread in the background which will
/// be terminated once this instance gets dropped, or if the underlying socket
/// gets closed by returning the original [`std::io::Error`].
///
/// Upon any error returned to the caller of [`NonBlockStream::read`] or
/// [`NonBlockStream::write`], the caller will have to consider the stream to
/// be broken and it is required to drop this instance: the background thread
/// will have been terminated at that point and this [`NonBlockStream`] is now
/// unusable and no other calls to read or write should be made.
///
/// Since it is based on [`TcpStream`], it is sequential and can handle only a
/// single stream. The [`NonBlockStream`] is in a way dual channel, but through
/// means of interleaving read/write buffering. The buffer is changing and it's
/// always the size of the next message being written/read.
///
/// The IO thread will keep processing the stream in the background, but it
/// will also sleep (using [`mio::Poll`]) and wake up when either read or write
/// operations are possible again. Whether that will happen depends on the size
/// of the internal buffers of the [`TcpStream`] being passed from creation.
///
/// The [`TcpStream`] is kept as it is when received in its configuration, with
/// one exception of making it non blocking. During initialization, a call to
/// [`TcpStream::set_nonblocking`] is made and if not successful, it will
/// panic. Make sure to pass in a [`TcpStream`] that is either capable of being
/// set to non blocking, or better yet, set it before converting it onto a
/// [`NonBlockStream`].
///
/// It is expected that the [`TcpStream`] being passsed on creation is already
/// in the connected state.
///
/// This type is generic over the message being passed in the stream and it is
/// message 'aware'. This means that it will send its own headers to determine
/// size and send and parse back the message until the proper chunks are
/// available.
///
/// The header is 10 bytes and is sent per every message.
/// Take that into consideration for how big the message type should be and if
/// it is advantaging to use this method for transmission.
///
/// Reads and writes will ingest or return boxed instances of the message.
/// The [`Message`] type needs to be serializable + deserializable with
/// [`serde`].
///
/// In order to write to the stream use the [`NonBlockStream::write`]. This
/// will add the message to an internal channel (mpsc).
/// *The call to write does not block.*
///
/// To check if there is a message available call [`NonBlockStream::read`].
/// This will check another internal channel if some message is ready. If none
/// is, then the call to read will return [`None`].
/// *The call to read does not block.*
///
/// The [`NonBlockStream`] can be cloned and is [`Send`] and [`Sync`] so it can
/// be used across frameworks that require it.
///
/// ```no_run
/// use std::net::{TcpStream, SocketAddr};
/// use serde::{Serialize, Deserialize};
/// use pipenet::NonBlockStream;
///
/// #[derive(Serialize, Deserialize)]
/// struct MyType {
///     x: i32,
///     y: i32,
/// }
///
/// let stream = TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 9999))).unwrap();
/// let mut nbstream: NonBlockStream<MyType> = stream.into();
///
/// // A simple, one time, echo example
/// if let Some(msg) = nbstream.read().unwrap() {
///     nbstream.write(msg).unwrap();
/// }
/// ```
///
/// Versioning is supported. Marks the current version, and discards versions
/// that are outside the min/max version range.
///
/// To add more wrappers on the message, such as encryption or compresssion,
/// use the constructor with the encapsulations.
///
/// To use those methods the features "compression" and/or "encryption" will be
/// required.
///
/// ```ignore
/// use std::net::{TcpStream, SocketAddr};
/// use serde::{Serialize, Deserialize};
/// use pipenet::NonBlockStream;
/// use pipenet::Versions;
/// use pipenet::Packs;
///
/// #[derive(Serialize, Deserialize)]
/// struct MyType {
///     x: i32,
///     y: i32,
/// }
///
/// let key = &[0u8; 32];
/// let stream = TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 9999))).unwrap();
/// let mut nbstream: NonBlockStream<MyType> = NonBlockStream::from_version_encapsulations(
///     Versions::new(2, 1, 3), // Current version 2, supports from 1 to 3
///     Packs::default()
///         .compress()
///         .encrypt(key),
///     stream);
///
/// // A simple, one time, echo example
/// if let Some(msg) = nbstream.read().unwrap() {
///     nbstream.write(msg).unwrap();
/// }
/// ```
#[derive(Clone)]
pub struct NonBlockStream<M: Message> {
    rx_reader: Arc<Mutex<Receiver<Box<M>>>>,
    tx_writer: Sender<Box<M>>,
    rx_err: Arc<Mutex<Receiver<std::io::Error>>>,
    local_addr: SocketAddr,
    remote_addr: SocketAddr,
    _handle: Arc<JoinHandle<()>>,
}

impl<M: Message> From<TcpStream> for NonBlockStream<M> {
    /// Wraps `stream` using the default [`Versions`] and the default packs.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`NonBlockStream::from_versions`].
    fn from(stream: TcpStream) -> Self {
        let versions = Versions::default();
        Self::from_versions(versions, stream)
    }
}

impl<M: Message> NonBlockStream<M> {
    /// Wraps a connected [`TcpStream`] using the given message [`Versions`]
    /// and [`Packs`] encapsulations.
    ///
    /// The stream is switched to non-blocking mode and handed over to a
    /// background IO thread, which owns it from then on.
    ///
    /// # Panics
    ///
    /// Panics if the local or peer address of `stream` cannot be obtained
    /// (for example because it is not connected), or if the socket cannot be
    /// set to non-blocking mode.
    pub fn from_version_packs(v: Versions, packs: Packs, stream: TcpStream) -> Self {
        let local_addr = stream
            .local_addr()
            .expect("Could not obtain local_addr from stream");
        let remote_addr = stream
            .peer_addr()
            .expect("Could not obtain peer_addr from stream");
        stream
            .set_nonblocking(true)
            .expect("Could not set socket to nonblocking. It is required for communication.");

        let (tx_reader, rx_reader) = channel::<Box<M>>();
        let (tx_writer, rx_writer) = channel::<Box<M>>();
        let (tx_err, rx_err) = channel::<std::io::Error>();

        // The looper consumes the TcpStream.
        let looper = looper::StreamLooper::<M>::new(v, packs, stream, tx_reader, rx_writer, tx_err);
        let handle = std::thread::spawn(move || {
            looper.stream_loop();
        });

        Self {
            rx_reader: Arc::new(Mutex::new(rx_reader)),
            tx_writer,
            rx_err: Arc::new(Mutex::new(rx_err)),
            local_addr,
            remote_addr,
            _handle: Arc::new(handle),
        }
    }

    /// Wraps a connected [`TcpStream`] using the given message [`Versions`]
    /// and the default [`Packs`].
    ///
    /// # Panics
    ///
    /// Same conditions as [`NonBlockStream::from_version_packs`].
    pub fn from_versions(v: Versions, stream: TcpStream) -> Self {
        Self::from_version_packs(v, Default::default(), stream)
    }

    /// The address of the local tcp stream.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// The address of the remote end of the tcp stream.
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }

    /// Queues `msg` to be written by the background IO thread.
    ///
    /// Does not block: the message is only handed over to an internal channel.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the background thread if the stream has
    /// failed, or an [`ErrorKind::ConnectionAborted`] error if the background
    /// thread's channels have been closed.
    pub fn write(&mut self, msg: Box<M>) -> Result<(), std::io::Error> {
        // Refuse to queue more data once a fault has been reported.
        self.trap_fault()?;
        self.trap_write(msg)
    }

    /// Returns the next received message, if one is ready.
    ///
    /// Does not block: `Ok(None)` means no message is available yet.
    /// Messages fully received before the background thread hit an error are
    /// still delivered first. The error is reported once none are left.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the background thread if the stream has
    /// failed, or an [`ErrorKind::ConnectionAborted`] error if the background
    /// thread's channels have been closed.
    pub fn read(&mut self) -> Result<Option<Box<M>>, std::io::Error> {
        // Check the message queue before the fault queue. The IO thread may have
        // queued complete messages just before the socket failed (e.g. a final
        // message followed by the peer closing). Checking the fault first would
        // silently discard them.
        match self.trap_recv() {
            found @ Ok(Some(_)) => found,
            // Nothing queued: surface a pending fault if any, otherwise "no message yet".
            Ok(None) => self.trap_fault().map(|()| None),
            // Message channel gone: prefer the thread's own error over the generic abort.
            Err(aborted) => {
                self.trap_fault()?;
                Err(aborted)
            }
        }
    }

    /// Hands `msg` to the IO thread, mapping a closed channel to `ConnectionAborted`.
    fn trap_write(&mut self, msg: Box<M>) -> Result<(), std::io::Error> {
        self.tx_writer
            .send(msg)
            .map_err(|e| std::io::Error::new(ErrorKind::ConnectionAborted, e))
    }

    /// Polls the incoming-message channel without blocking.
    ///
    /// `Ok(None)` when empty; `ConnectionAborted` when the channel is closed.
    fn trap_recv(&mut self) -> Result<Option<Box<M>>, std::io::Error> {
        let op = self.rx_reader.lock().unwrap().try_recv();
        match op {
            Ok(msg) => Ok(Some(msg)),
            Err(e) => match e {
                TryRecvError::Empty => Ok(None),
                TryRecvError::Disconnected => {
                    Err(std::io::Error::new(ErrorKind::ConnectionAborted, e))
                }
            },
        }
    }

    /// Polls the error channel without blocking.
    ///
    /// Returns the reported error if there is one, `ConnectionAborted` if the
    /// channel is closed, and `Ok(())` otherwise.
    fn trap_fault(&mut self) -> Result<(), std::io::Error> {
        let op = self.rx_err.lock().unwrap().try_recv();
        match op {
            Ok(f) => Err(f),
            Err(e) => match e {
                TryRecvError::Empty => Ok(()),
                TryRecvError::Disconnected => {
                    Err(std::io::Error::new(ErrorKind::ConnectionAborted, e))
                }
            },
        }
    }
}