bp3d_net/tcp/util/
buffer.rs

1// Copyright (c) 2025, BlockProject 3D
2//
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without modification,
6// are permitted provided that the following conditions are met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//     * Redistributions in binary form must reproduce the above copyright notice,
11//       this list of conditions and the following disclaimer in the documentation
12//       and/or other materials provided with the distribution.
13//     * Neither the name of BlockProject 3D nor the names of its contributors
14//       may be used to endorse or promote products derived from this software
15//       without specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
29//! Async buffered receiver based on mpsc channels.
30
31use crate::tcp::BYTES_BUFFER_SIZE;
32use std::fmt::{Debug, Formatter};
33use std::io::{Error, ErrorKind};
34use std::net::SocketAddr;
35use std::ops::Deref;
36use std::pin::Pin;
37use std::task::{Context, Poll};
38use tokio::io::{AsyncRead, ReadBuf};
39use tokio::sync::mpsc;
40
/// A single buffer of bytes.
///
/// Owns a fixed-size array of `N` bytes, of which only the first `size` are meaningful;
/// dereferencing yields exactly that valid prefix.
pub struct Bytes<const N: usize> {
    bytes: [u8; N],
    size: usize,
}

impl<const N: usize> Bytes<N> {
    /// Creates a new owned byte buffer.
    ///
    /// # Arguments
    ///
    /// * `bytes`: the array of bytes to store.
    /// * `size`: the number of valid bytes in the buffer.
    ///
    /// returns: Bytes<{ N }>
    ///
    /// # Panics
    ///
    /// Panics if `size` is greater than `N`. Such a buffer could never be dereferenced, so the
    /// mistake is reported here, where it is made, instead of on the first read.
    pub fn new(bytes: [u8; N], size: usize) -> Self {
        assert!(
            size <= N,
            "Bytes::new: size ({}) exceeds buffer capacity ({})",
            size,
            N
        );
        Self { bytes, size }
    }
}

impl<const N: usize> Deref for Bytes<N> {
    type Target = [u8];

    /// Returns the valid prefix of the buffer (its first `size` bytes).
    fn deref(&self) -> &Self::Target {
        &self.bytes[..self.size]
    }
}
68
69/// The main async channel buffer.
70pub struct ChannelBuffer<const N: usize> {
71    receiver: mpsc::Receiver<Bytes<N>>,
72}
73
74impl<const N: usize> ChannelBuffer<N> {
75    /// Creates a new [ChannelBuffer] by wrapping a mpsc channel.
76    ///
77    /// # Arguments
78    ///
79    /// * `receiver`: the channel to wrap.
80    ///
81    /// returns: ChannelBuffer<{ N }>
82    pub fn new(receiver: mpsc::Receiver<Bytes<N>>) -> Self {
83        Self { receiver }
84    }
85
86    /// Closes this [ChannelBuffer].
87    pub fn close(mut self) {
88        self.receiver.close();
89    }
90}
91
92impl<const N: usize> AsyncRead for ChannelBuffer<N> {
93    fn poll_read(
94        mut self: Pin<&mut Self>,
95        cx: &mut Context<'_>,
96        buf: &mut ReadBuf<'_>,
97    ) -> Poll<std::io::Result<()>> {
98        let msg = self.receiver.poll_recv(cx);
99        match msg {
100            Poll::Ready(v) => match v {
101                Some(bytes) => {
102                    buf.put_slice(&bytes);
103                    Poll::Ready(Ok(()))
104                }
105                None => Poll::Ready(Err(Error::new(
106                    ErrorKind::BrokenPipe,
107                    "channel buffer is closed",
108                ))),
109            },
110            Poll::Pending => Poll::Pending,
111        }
112    }
113}
114
115/// Represents a network receiver which wraps a [ChannelBuffer].
116pub struct NetReceiver {
117    channel_buffer: ChannelBuffer<BYTES_BUFFER_SIZE>,
118    addr: SocketAddr,
119    id: usize,
120}
121
122impl Debug for NetReceiver {
123    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
124        write!(
125            f,
126            "NetReceiver {{ addr: {:?}, id: {:?} }}",
127            self.addr, self.id
128        )
129    }
130}
131
132impl NetReceiver {
133    /// Creates a new instance of a NetReceiver.
134    ///
135    /// # Arguments
136    ///
137    /// * `channel_buffer`: the channel buffer to read bytes from.
138    /// * `addr`: the peer address.
139    /// * `id`: the network id associated to the client.
140    ///
141    /// returns: NetReceiver
142    pub fn new(
143        channel_buffer: ChannelBuffer<BYTES_BUFFER_SIZE>,
144        addr: SocketAddr,
145        id: usize,
146    ) -> Self {
147        Self {
148            channel_buffer,
149            addr,
150            id,
151        }
152    }
153
154    /// Returns the socket address.
155    pub fn addr(&self) -> &SocketAddr {
156        &self.addr
157    }
158
159    /// Returns the unique network ID.
160    pub fn id(&self) -> usize {
161        self.id
162    }
163
164    /// Closes this [ChannelBuffer].
165    pub fn close(self) {
166        self.channel_buffer.close();
167    }
168}
169
170impl AsyncRead for NetReceiver {
171    fn poll_read(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &mut ReadBuf<'_>,
175    ) -> Poll<std::io::Result<()>> {
176        unsafe {
177            self.map_unchecked_mut(|v| &mut v.channel_buffer)
178                .poll_read(cx, buf)
179        }
180    }
181}