orcinus 0.3.0

Async-ready MySQL protocol implementation / wrapper libraries
Documentation
use std::{io::Read, pin::Pin, sync::atomic::AtomicUsize, task::Poll};

use tokio::io::AsyncRead;

/// Wrapper around a reader `R` that keeps a running total of bytes read
/// through it.
///
/// The total is held in an [`AtomicUsize`], so it can be queried through a
/// shared reference via [`ReadCounted::read_bytes`].
pub struct ReadCounted<R> {
    /// The wrapped reader that all reads are forwarded to.
    inner: R,
    /// Running total of bytes read so far.
    counter: AtomicUsize,
}

impl<R> ReadCounted<R> {
    /// Wraps `inner` with the byte counter initialised to zero.
    pub const fn new(inner: R) -> Self {
        let counter = AtomicUsize::new(0);
        Self { inner, counter }
    }

    /// Consumes the wrapper and hands back the wrapped reader; the byte
    /// count is discarded.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns the number of bytes counted so far.
    pub fn read_bytes(&self) -> usize {
        let total = &self.counter;
        total.load(std::sync::atomic::Ordering::Acquire)
    }
}
/// Asynchronous reads are forwarded to the wrapped reader, and the bytes
/// each successful poll appends to the buffer are added to the counter.
impl<R> AsyncRead for ReadCounted<R>
where
    R: AsyncRead + Unpin,
{
    /// Polls the wrapped reader and, on `Poll::Ready(Ok(()))`, adds the number
    /// of bytes newly appended to `buf` to the running total.
    ///
    /// `Poll::Pending` and `Poll::Ready(Err(_))` are passed through unchanged
    /// and leave the counter untouched.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();

        // `buf` may already hold data from earlier reads (a `ReadBuf` can be
        // reused across polls), so only the growth of its filled region
        // belongs to this call. Counting `filled().len()` directly would
        // count those earlier bytes again.
        let filled_before = buf.filled().len();

        // `R: Unpin`, so the safe `Pin::new` is sufficient here.
        match Pin::new(&mut this.inner).poll_read(cx, buf) {
            Poll::Ready(Ok(())) => {
                // Saturate so that a reader which shrinks the filled region
                // cannot cause an underflow.
                let newly_read = buf.filled().len().saturating_sub(filled_before);
                this.counter
                    .fetch_add(newly_read, std::sync::atomic::Ordering::AcqRel);
                Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}
/// Blocking reads are forwarded to the wrapped reader, and every successful
/// read adds its byte count to the counter.
impl<R: Read> Read for ReadCounted<R> {
    /// Reads into `buf` through the wrapped reader.
    ///
    /// On success the returned byte count is added to the running total and
    /// then returned. On error the error is returned and the counter is left
    /// untouched.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.inner.read(buf).map(|read_len| {
            self.counter
                .fetch_add(read_len, std::sync::atomic::Ordering::AcqRel);
            read_len
        })
    }
}

/// Wrapper around a blocking reader `R` that keeps a running total of bytes
/// read through it in a plain `usize`.
pub struct ReadCountedSync<R> {
    /// The wrapped reader that all reads are forwarded to.
    inner: R,
    /// Running total of bytes read so far.
    counter: usize,
}

impl<R> ReadCountedSync<R> {
    /// Wraps `inner` with the byte counter initialised to zero.
    pub const fn new(inner: R) -> Self {
        let counter = 0;
        Self { inner, counter }
    }

    /// Consumes the wrapper and hands back the wrapped reader; the byte
    /// count is discarded.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns the number of bytes counted so far.
    pub fn read_bytes(&self) -> usize {
        let Self { counter, .. } = self;
        *counter
    }
}

/// Reads are forwarded to the wrapped reader, and every successful read adds
/// its byte count to the counter.
impl<R: Read> Read for ReadCountedSync<R> {
    /// Reads into `buf` through the wrapped reader.
    ///
    /// On success the returned byte count is added to the running total and
    /// then returned. On error the error is returned and the counter is left
    /// untouched.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.inner.read(buf) {
            Ok(read_len) => {
                self.counter += read_len;
                Ok(read_len)
            }
            Err(err) => Err(err),
        }
    }
}