// tokio-connectors 0.1.6
//
// A collection of connectors for tokio-based clients and servers.
// (Documentation)
use std::{fmt::Debug, marker::PhantomData};

use tokio::{
    io::Ready,
    select,
    sync::{
        broadcast::Sender as BroadcastSender,
        mpsc::{UnboundedReceiver, UnboundedSender},
    },
    task::JoinHandle,
};
use tracing::{debug, error, trace};

use crate::{codecs::Codec, error::Error};

/// Generic [Readable] trait over different (split) transport types
///
/// Abstracts the read side of a connection so the reader task does not depend on a
/// concrete stream type.
pub(crate) trait Readable: Send + Sync + 'static {
    /// Wait for the underlying stream to be ready for reading (or closed/errored)
    ///
    /// Resolves to the reported [`Ready`] set, or to the I/O error raised while waiting.
    fn ready_internal(&mut self) -> impl Future<Output = Result<Ready, std::io::Error>> + Send;

    /// Try to read data from the underlying stream into the provided buffer
    ///
    /// Returns the number of bytes read. Callers in this file pass an empty buffer with
    /// spare capacity, then use the first `n` bytes of it, and treat an
    /// `ErrorKind::WouldBlock` error as "no data available right now".
    // NOTE(review): presumably a thin wrapper over tokio's non-blocking `try_read_buf` —
    // confirm against the implementors.
    fn try_read_buf_internal(&mut self, buf: &mut Vec<u8>) -> Result<usize, std::io::Error>;
}

/// Generic [Writeable] trait over different (split) transport types
///
/// Abstracts the write side of a connection so the writer task does not depend on a
/// concrete stream type.
pub(crate) trait Writeable: Send + Sync + 'static {
    /// Write the provided buffer to the underlying stream, ensuring all data is sent
    ///
    /// Resolves to `Ok(())` on success, or to the I/O error that interrupted the write.
    fn write_all_internal(
        &mut self,
        buf: &[u8],
    ) -> impl Future<Output = Result<(), std::io::Error>> + Send;
}

/// A handle for an active connection, used for both the client and server implementations.
///
/// Type parameters:
/// * `C` - codec used to encode outbound and decode inbound messages.
/// * `OUT` - message type written to the stream.
/// * `IN` - message type decoded from the stream.
/// * `ADDR` - address type identifying the connection.
pub(crate) struct Handle<
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
    ADDR: Debug + Clone + Sync + Send + 'static,
> {
    /// Address this handle was created with; handed out (cloned) by `addr()`.
    addr: ADDR,

    /// Broadcast sender for the exit signal shared by the reader task, the writer task,
    /// `close()` and any `on_closed` callbacks.
    exit_tx: BroadcastSender<()>,

    /// Join handle of the reader task spawned in `new`.
    _rx_handle: JoinHandle<()>,
    /// Join handle of the writer task spawned in `new`.
    _tx_handle: JoinHandle<()>,

    // Zero-sized markers: the codec and message types appear only in method bodies,
    // so they have to be tied to the struct explicitly.
    _c: PhantomData<C>,
    _out: PhantomData<OUT>,
    _in: PhantomData<IN>,
}

impl<
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
    ADDR: Debug + Clone + Sync + Send + 'static,
> Handle<C, OUT, IN, ADDR>
{
    /// Create tasks to read from and write to an existing stream and channels
    ///
    /// * `addr` - address of the connection; attached to every decoded message and to logs.
    /// * `reader` - read side of the transport, owned by the reader task.
    /// * `writer` - write side of the transport, owned by the writer task.
    /// * `out_rx` - outbound messages; each one is encoded with `C::encode` and written in full.
    /// * `in_tx` - destination for `(message, addr)` pairs decoded from the stream.
    ///
    /// Both tasks stop as soon as the exit signal is broadcast, and each task broadcasts
    /// the signal itself when it terminates, so a failure on one side also stops the other.
    /// Must be called from within a Tokio runtime because it spawns tasks.
    ///
    /// # Errors
    ///
    /// The current implementation has no failing path and always returns `Ok`.
    pub(crate) async fn new(
        addr: ADDR,
        mut reader: impl Readable + Unpin + Send + 'static,
        mut writer: impl Writeable + Unpin + Send + 'static,
        mut out_rx: UnboundedReceiver<OUT>,
        in_tx: UnboundedSender<(IN, ADDR)>,
    ) -> Result<Self, Error> {
        // Setup the exit channel. Both receivers are created here, BEFORE the tasks are
        // spawned: a broadcast receiver only observes values sent after it was created,
        // so subscribing inside the tasks could miss a `close()` issued before the tasks
        // were polled for the first time, leaving them running forever.
        let (exit_tx, mut rx_exit_rx) = tokio::sync::broadcast::channel::<()>(1);
        let mut tx_exit_rx = exit_tx.subscribe();

        // Setup a task to handle reading from the stream
        let rx_exit_tx = exit_tx.clone();
        let addr_ = addr.clone();
        let _rx_handle = tokio::task::spawn(async move {
            let mut accumulator = Vec::with_capacity(1024);

            debug!("Reader task started for {addr_:?}");

            loop {
                select! {
                    // `biased` makes the exit signal win over pending stream readiness
                    biased;
                    // Handle exit signal
                    _ = rx_exit_rx.recv() => {
                        debug!("Client reader exiting");
                        break;
                    },
                    // Poll for incoming data from the stream and accumulate it into the buffer
                    r = reader.ready_internal() => match r {
                        // Readiness is a bit set: test the readable bit instead of comparing
                        // against exactly `READABLE`, otherwise a combined
                        // `READABLE | READ_CLOSED` (peer sent data and then closed) would be
                        // treated as "closed" and the pending bytes would be thrown away.
                        Ok(r) if r.is_readable() => {
                            if let Err(e) = Self::handle_read(&mut reader, &mut accumulator, &addr_, &in_tx).await {
                                error!("Error handling read for {addr_:?}: {e:?}");
                                break;
                            }
                            // The pending data has been drained; now honour the other flags
                            if r.is_read_closed() {
                                debug!("Stream for {addr_:?} closed");
                                break;
                            }
                            if r.is_error() {
                                debug!("Stream for {addr_:?} encountered an error");
                                break;
                            }
                        },
                        Ok(r) if r.is_read_closed() => {
                            debug!("Stream for {addr_:?} closed");
                            break;
                        }
                        Ok(r) if r.is_error() => {
                            debug!("Stream for {addr_:?} encountered an error");
                            break;
                        }
                        Ok(r) => {
                            // Unexpected readiness state, continue polling
                            debug!("Unexpected readiness state for {addr_:?}: {r:?}");
                            continue;
                        }
                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            // No more data to read
                            continue;
                        }
                        Err(e) => {
                            error!("Failed to read from stream: {e:?}");
                            break;
                        }
                    },
                }
            }

            // Single exit notification for every way out of the loop; the result is
            // ignored because having no remaining receivers is not an error here.
            rx_exit_tx.send(()).ok();

            drop(reader);
        });

        // Setup a task to handle writing events to the stream
        let tx_exit_tx = exit_tx.clone();
        let addr_ = addr.clone();
        let _tx_handle = tokio::task::spawn(async move {
            debug!("Writer task started for {addr_:?}");

            loop {
                select! {
                    biased;
                    _ = tx_exit_rx.recv() => {
                        debug!("Client writer exiting");
                        break;
                    }
                    e = out_rx.recv() => match e {
                        Some(event) => {
                            // Serialize the event with the codec
                            let data = match C::encode(&event) {
                                Ok(data) => data,
                                Err(e) => {
                                    // A message that cannot be encoded is dropped; the
                                    // connection itself stays up.
                                    error!("Failed to serialize event: {:?}", e);
                                    continue;
                                }
                            };
                            // Write the event data to the stream
                            if let Err(e) = writer.write_all_internal(&data).await {
                                error!("Failed to write to stream: {:?}", e);
                                break;
                            }
                        },
                        None => {
                            // Every sender of the outbound channel has been dropped
                            debug!("Failed to receive event, channel closed");
                            break;
                        }
                    },
                }
            }

            // Single exit notification for every way out of the loop
            tx_exit_tx.send(()).ok();

            drop(writer);
        });

        Ok(Self {
            addr,
            _rx_handle,
            _tx_handle,
            exit_tx,
            _c: PhantomData,
            _in: PhantomData,
            _out: PhantomData,
        })
    }

    /// Pull the currently available bytes out of `reader`, append them to `accumulator`
    /// and forward every complete message the codec can decode to `in_tx`.
    ///
    /// Reading stops at the first `WouldBlock` error or at the first read that does not
    /// fill the scratch buffer. If nothing was read, the accumulator is left untouched and
    /// no decoding is attempted.
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` for any read error other than `WouldBlock`, and passes through
    /// the error reported by `C::try_decode`.
    // NOTE(review): a zero-length read is not told apart from any other short read, so an
    // end-of-stream reported that way is not surfaced to the caller — confirm this is intended.
    async fn handle_read<READER: Readable>(
        reader: &mut READER,
        accumulator: &mut Vec<u8>,
        addr_: &ADDR,
        in_tx: &UnboundedSender<(IN, ADDR)>,
    ) -> Result<(), Error> {
        // Scratch space for one read call; its contents are copied into the accumulator
        let mut scratch = Vec::with_capacity(1024 * 1024);
        let mut total = 0;

        loop {
            scratch.clear();

            let count = match reader.try_read_buf_internal(&mut scratch) {
                Ok(count) => count,
                // The stream has nothing more for us right now
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
                Err(e) => {
                    error!("Failed to read from stream: {:?}", e);
                    return Err(Error::Io(e));
                }
            };

            // Keep the fresh bytes around for decoding
            accumulator.extend_from_slice(&scratch[..count]);
            total += count;

            // Only a completely filled scratch buffer suggests more data is waiting
            if count >= scratch.capacity() {
                continue;
            }
            break;
        }

        // Nothing new arrived, so the accumulator cannot contain a new complete message
        if total == 0 {
            return Ok(());
        }

        trace!("Read {total} bytes from stream");
        trace!("Accumulated buffer size: {}", accumulator.len());

        // Decode until the codec reports that it needs more bytes
        loop {
            let decoded = match C::try_decode(accumulator) {
                Ok(decoded) => decoded,
                Err(e) => {
                    error!("Failed to decode message: {:?}", e);
                    return Err(e);
                }
            };

            let Some(cmd) = decoded else {
                // Incomplete message: wait for the next read
                return Ok(());
            };

            debug!("Decoded message from {:?}: {:?}", addr_, cmd);
            // Delivery is best effort: a closed inbound channel is deliberately ignored
            _ = in_tx.send((cmd, addr_.clone()));
        }
    }

    /// Fetch the target address of the TCP connection
    pub fn addr(&self) -> ADDR {
        self.addr.clone()
    }

    /// Register a callback to be called when the connection is closed
    ///
    /// The callback runs once, on a freshly spawned Tokio task, after the exit signal has
    /// been broadcast. If one of the connection tasks has already finished by the time
    /// this method is called, the callback is scheduled immediately instead of waiting
    /// for a signal that was sent before the subscription existed.
    ///
    /// Must be called from within a Tokio runtime because it spawns a task.
    pub fn on_closed<F: FnOnce() + Send + 'static>(&self, callback: F) {
        // Subscribe first, then inspect the tasks: a signal sent after this point is
        // received through `exit_rx`, while a task that stopped earlier is caught by the
        // `is_finished` check below.
        let mut exit_rx = self.exit_tx.subscribe();

        // Each task broadcasts the exit signal before it finishes, so a finished task
        // means the connection is already closed (or closing) and the broadcast may have
        // been missed by the receiver created above.
        let already_closed = self._rx_handle.is_finished() || self._tx_handle.is_finished();

        tokio::task::spawn(async move {
            debug!("Registering on_closed callback");
            if !already_closed {
                // Any outcome of `recv` (value, lag or closed channel) means it is time to fire
                let _ = exit_rx.recv().await;
            }
            debug!("Connection closed, executing callback");
            callback();
        });
    }

    /// Signal the reader and writer tasks to exit, which makes them drop their halves of
    /// the connection
    ///
    /// Consumes the handle. The send result is ignored, so this always returns `Ok(())`.
    pub fn close(self) -> Result<(), Error> {
        // A failed send only means that no receiver is subscribed at the moment
        self.exit_tx.send(()).ok();
        Ok(())
    }
}