async_io_bufpool/
lib.rs

1#![doc = pretty_readme::docify!("README.md", "https://docs.rs/super-cool-crate/latest/super-cool-crate/", "./")]
2
3use std::{cell::RefCell, future::Future};
4
5use bytes::Bytes;
6
7use crossbeam_queue::SegQueue;
8use futures_util::{AsyncRead, AsyncWrite};
9use pin_project_lite::pin_project;
10
/// Length in bytes of the per-thread scratch buffer.
const SCRATCH_LEN: usize = 65536;

thread_local! {
    /// Zero-initialised per-thread scratch space. Reads are staged here before the
    /// bytes are handed to a caller-supplied callback.
    static BUFFER: RefCell<[u8; SCRATCH_LEN]> = const { RefCell::new([0; SCRATCH_LEN]) };
}
14
15/// Read an async reader into a buffer. This is done in a memory-efficient way, avoiding consuming any memory before the read unblocks.
16///
17/// An empty return value indicates EOF.
18pub async fn pooled_read(
19    rdr: impl AsyncRead,
20    limit: usize,
21) -> Result<Option<Bytes>, std::io::Error> {
22    PooledOnceReader {
23        rdr,
24        resolve: |b: &[u8]| Bytes::copy_from_slice(b),
25        limit,
26    }
27    .await
28}
29
30/// Read an async reader into a buffer, but instead of allocating memory, call a callback.
31///
32/// An empty return value indicates EOF.
33pub async fn pooled_read_callback<T>(
34    rdr: impl AsyncRead,
35    limit: usize,
36    resolve: impl FnMut(&[u8]) -> T,
37) -> Result<Option<T>, std::io::Error> {
38    PooledOnceReader {
39        rdr,
40        resolve,
41        limit,
42    }
43    .await
44}
45
46/// Copy data from an async reader to an async writer using a thread-local buffer.
47/// Returns the total number of bytes copied.
48pub async fn pooled_copy<R, W>(mut reader: R, mut writer: W) -> std::io::Result<u64>
49where
50    R: AsyncRead + Unpin,
51    W: AsyncWrite + Unpin,
52{
53    let mut total_bytes = 0u64;
54
55    static BUFFS: SegQueue<Box<[u8; 8192]>> = SegQueue::new();
56
57    loop {
58        let (buff, n) = match pooled_read_callback(&mut reader, 8192, |bts| {
59            let mut buff = BUFFS.pop().unwrap_or_else(|| Box::new([0u8; 8192]));
60            buff[..bts.len()].copy_from_slice(bts);
61            (buff, bts.len())
62        })
63        .await?
64        {
65            Some(x) => x,
66            None => break, // End of file
67        };
68
69        let bytes_read = n as u64;
70        futures_util::AsyncWriteExt::write_all(&mut writer, &buff[..n]).await?;
71        total_bytes += bytes_read;
72    }
73
74    Ok(total_bytes)
75}
76
77pin_project! {
78struct PooledOnceReader<T, F>{
79    #[pin]
80    rdr: T,
81    resolve: F,
82    limit: usize
83}
84}
85impl<T: AsyncRead, U, F: FnMut(&[u8]) -> U> Future for PooledOnceReader<T, F> {
86    type Output = Result<Option<U>, std::io::Error>;
87
88    fn poll(
89        self: std::pin::Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> std::task::Poll<Self::Output> {
92        BUFFER.with(|buf| {
93            let mut buf = buf.borrow_mut();
94            let this = self.project();
95            let limit = (*this.limit).min(buf.len());
96            match this.rdr.poll_read(cx, &mut buf[..limit]) {
97                std::task::Poll::Ready(Ok(n)) => {
98                    if n == 0 {
99                        std::task::Poll::Ready(Ok(None))
100                    } else {
101                        std::task::Poll::Ready(Ok(Some((this.resolve)(&buf[..n]))))
102                    }
103                }
104                std::task::Poll::Ready(Err(err)) => std::task::Poll::Ready(Err(err)),
105                std::task::Poll::Pending => std::task::Poll::Pending,
106            }
107        })
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    use pollster::FutureExt;

    /// A read with a generous limit returns the whole in-memory input.
    #[test]
    fn test_pooled_read() {
        let test_data = b"Hello, World!";

        let bytes = pooled_read(&test_data[..], 10000)
            .block_on()
            .expect("reading from an in-memory slice cannot fail");

        assert_eq!(bytes, Some(Bytes::from_static(test_data)));
    }

    /// An exhausted reader is reported as `None`, not as an empty buffer.
    #[test]
    fn test_pooled_read_eof() {
        let empty: &[u8] = &[];

        let bytes = pooled_read(empty, 16)
            .block_on()
            .expect("reading from an in-memory slice cannot fail");

        assert_eq!(bytes, None);
    }

    /// No more than `limit` bytes are returned by a single read.
    #[test]
    fn test_pooled_read_respects_limit() {
        let bytes = pooled_read(&b"Hello, World!"[..], 5)
            .block_on()
            .expect("reading from an in-memory slice cannot fail");

        assert_eq!(bytes, Some(Bytes::from_static(b"Hello")));
    }

    /// The callback sees exactly the bytes that were read and its result is returned.
    #[test]
    fn test_pooled_read_callback() {
        let seen = pooled_read_callback(&b"abc"[..], 10, |chunk: &[u8]| chunk.to_vec())
            .block_on()
            .expect("reading from an in-memory slice cannot fail");

        assert_eq!(seen, Some(b"abc".to_vec()));
    }

    /// A copy spanning several chunks transfers every byte in order.
    #[test]
    fn test_pooled_copy() {
        let data: Vec<u8> = b"0123456789".iter().copied().cycle().take(20_000).collect();
        let mut out: Vec<u8> = Vec::new();

        let copied = pooled_copy(&data[..], &mut out)
            .block_on()
            .expect("copying between in-memory buffers cannot fail");

        assert_eq!(copied, 20_000);
        assert_eq!(out, data);
    }
}