protosocket 1.0.0

Message-oriented, nonblocking TCP stream
Documentation
use std::ops::{Deref, DerefMut};

use crate::Encoder;

/// Serializer that writes outbound messages into plain, reusable byte vectors.
///
/// This is the low-effort way to get an [`Encoder`]: wrap an implementation of this
/// trait in a [`PooledEncoder`], which supplies recycled `Vec<u8>` buffers so that
/// steady-state encoding avoids allocation. If a message type can be written to the
/// network without an intermediate copy, implement [`Encoder`] directly instead.
pub trait Serialize {
    /// The outbound message type that gets written into a buffer.
    type Message;

    /// Write `message` into `buffer`.
    ///
    /// [`PooledEncoder`] always passes a buffer with length zero, either a recycled one
    /// or a freshly created `Vec`.
    fn serialize_into_buffer(&mut self, message: Self::Message, buffer: &mut Vec<u8>);
}

/// An encoder that wraps a serializer, offering it raw byte vectors. These vectors
/// are reset and reused to minimize allocation cost.
///
/// Buffers handed back through `return_buffer` are emptied (capacity kept) and stored
/// in a bounded pool. Later `encode` calls pop from that pool before creating a new `Vec`.
pub struct PooledEncoder<TSerializer> {
    /// The user-provided serializer that writes each message into a buffer.
    serializer: TSerializer,
    /// Idle, already-emptied buffers waiting to be reused by `encode`.
    buffer_pool: Vec<Vec<u8>>,
    /// Upper bound on `buffer_pool.len()`; buffers returned beyond this are dropped.
    max_pooled: usize,
}

impl<TSerializer> PooledEncoder<TSerializer>
where
    TSerializer: Serialize,
{
    /// Number of idle buffers retained by [`PooledEncoder::new`].
    const DEFAULT_POOL_SIZE: usize = 8;

    /// Create a pooled encoder that retains up to 8 idle buffers.
    ///
    /// Small pool sizes are likely to do well even on heavily utilized systems.
    /// If demand for concurrency exceeds the pool's size, the pool still settles into
    /// reuse, with the additional concurrency held in flight. In other words, allocations
    /// happen in response to the first derivative of demand increase.
    pub fn new(serializer: TSerializer) -> Self {
        Self::new_with_pool_size(Self::DEFAULT_POOL_SIZE, serializer)
    }

    /// Create a pooled encoder that retains up to `pool_size` idle buffers.
    ///
    /// Buffers returned while the pool already holds `pool_size` entries are dropped.
    /// A `pool_size` of 0 disables pooling, so every encode starts from a fresh `Vec`.
    pub fn new_with_pool_size(pool_size: usize, serializer: TSerializer) -> Self {
        Self {
            serializer,
            // Reserve the whole pool up front so returning buffers never reallocates it.
            buffer_pool: Vec::with_capacity(pool_size),
            max_pooled: pool_size,
        }
    }
}

/// Builds a `PooledEncoder` around the serializer's `Default` value, using the same
/// pool size as [`PooledEncoder::new`].
impl<TSerializer> Default for PooledEncoder<TSerializer>
where
    TSerializer: Serialize + Default,
{
    fn default() -> Self {
        // Delegate to `new` so the default pool size is defined in exactly one place.
        Self::new(TSerializer::default())
    }
}

impl<TSerializer> Encoder for PooledEncoder<TSerializer>
where
    TSerializer: Serialize,
{
    type Message = TSerializer::Message;
    type Serialized = Reusable;

    /// Serialize `message` into a pooled buffer, or a fresh `Vec` if the pool is empty.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip_all, name = "pooled_encode")
    )]
    fn encode(&mut self, message: Self::Message) -> Self::Serialized {
        // Pooled buffers are emptied in `return_buffer`, so the serializer always
        // starts from a zero-length buffer.
        let mut buffer = self.buffer_pool.pop().unwrap_or_default();
        self.serializer.serialize_into_buffer(message, &mut buffer);
        Reusable::new(buffer)
    }

    /// Hand a buffer produced by `encode` back for reuse.
    ///
    /// The buffer is emptied but keeps its allocation. If the pool already holds
    /// `max_pooled` buffers, this one is dropped instead.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip_all, name = "pooled_return")
    )]
    fn return_buffer(&mut self, buffer: Self::Serialized) {
        if self.buffer_pool.len() < self.max_pooled {
            let mut buffer: Vec<u8> = buffer.inner;
            // `clear` keeps the capacity. For `u8` (no drop glue) it only resets the
            // length, so it costs the same as `set_len(0)` without needing `unsafe`.
            buffer.clear();
            self.buffer_pool.push(buffer);
        }
    }
}

/// An owned, serialized message buffer that can be consumed as a `bytes::Buf`.
///
/// After the bytes are written out, the encoder can take the wrapper back and recycle
/// the underlying `Vec<u8>`.
#[derive(Debug, Clone)]
pub struct Reusable {
    inner: Vec<u8>,
    cursor: usize,
}

impl Reusable {
    /// Wrap a freshly serialized buffer, with the read position at its start.
    fn new(inner: Vec<u8>) -> Self {
        Self { cursor: 0, inner }
    }
}

impl bytes::Buf for Reusable {
    /// Number of serialized bytes not yet consumed by `advance`.
    #[inline(always)]
    fn remaining(&self) -> usize {
        self.inner.len() - self.cursor
    }

    /// The unconsumed bytes, as one contiguous slice.
    #[inline(always)]
    fn chunk(&self) -> &[u8] {
        &self.inner[self.cursor..]
    }

    /// Mark `cnt` bytes as consumed.
    ///
    /// # Panics
    ///
    /// Panics if `cnt` exceeds `self.remaining()`.
    #[inline(always)]
    fn advance(&mut self, cnt: usize) {
        // Compare against `remaining()` rather than computing `cursor + cnt`. That sum
        // can overflow for huge `cnt`, and in release builds it silently wraps and
        // passes the check.
        let remaining = self.remaining();
        assert!(
            cnt <= remaining,
            "cannot advance past end of buffer: cnt {} > remaining {}",
            cnt,
            remaining
        );
        self.cursor += cnt;
    }

    /// Fill at most one `IoSlice`, because the data is a single contiguous `Vec`.
    #[inline(always)]
    fn chunks_vectored<'a>(&'a self, dst: &mut [std::io::IoSlice<'a>]) -> usize {
        if dst.is_empty() || self.cursor == self.inner.len() {
            0
        } else {
            dst[0] = std::io::IoSlice::new(self.chunk());
            1
        }
    }

    /// Whether any unconsumed bytes remain.
    #[inline(always)]
    fn has_remaining(&self) -> bool {
        self.inner.len() != self.cursor
    }
}

impl Deref for Reusable {
    type Target = [u8];

    /// View the entire serialized buffer, including bytes already consumed via `advance`.
    fn deref(&self) -> &[u8] {
        self.inner.as_slice()
    }
}

impl DerefMut for Reusable {
    /// Mutable view of the entire serialized buffer, ignoring the read cursor.
    fn deref_mut(&mut self) -> &mut [u8] {
        self.inner.as_mut_slice()
    }
}