cubic_protocol_server/
read.rs

1use std::fs::read;
2use cubic_protocol::packet::{CustomError, InputPacketBytes, InputPacketBytesError, InputPacketBytesResult, PacketReadable, PacketReadableResult};
3use tokio::io::AsyncReadExt;
4use tokio::net::tcp::OwnedReadHalf;
5use cubic_protocol::types::VarInt;
6
/// Raw destination pointer handed to `ReadStreamQueue::copy_into`.
///
/// Raw pointers are neither `Send` nor `Sync`. This wrapper opts in explicitly so the
/// pointer can be carried across `.await` points inside the futures of this module.
#[derive(Copy, Clone)]
struct SlicePointer {
    /// Start of the destination byte region.
    ptr: *mut u8,
}

// SAFETY: the wrapper never dereferences `ptr` itself. In this module the pointer is only
// written through inside `ReadStreamQueue::copy_into`, whose callers must keep the target
// region exclusively borrowed until the copy future completes.
unsafe impl Send for SlicePointer {}

// SAFETY: a shared `&SlicePointer` only exposes the pointer value, never the pointee; see the
// `Send` justification above for who may write through it.
unsafe impl Sync for SlicePointer {}
15
16pub struct ReadStreamQueue<const BUFFER_SIZE: usize> {
17    read_half: OwnedReadHalf,
18    packet_length: usize,
19    packet_offset: usize,
20    buffer: [u8; BUFFER_SIZE],
21    buffer_size: usize,
22    buffer_offset: usize,
23}
24
25impl<const BUFFER_SIZE: usize> From<OwnedReadHalf> for ReadStreamQueue<BUFFER_SIZE> {
26    fn from(read_half: OwnedReadHalf) -> Self {
27        ReadStreamQueue::new(read_half)
28    }
29}
30
31impl<const BUFFER_SIZE: usize> From<ReadStreamQueue<BUFFER_SIZE>> for OwnedReadHalf {
32    fn from(queue: ReadStreamQueue<BUFFER_SIZE>) -> Self {
33        queue.read_half
34    }
35}
36
37impl<const BUFFER_SIZE: usize> ReadStreamQueue<BUFFER_SIZE> {
38    pub fn new(read_half: OwnedReadHalf) -> Self {
39        Self {
40            read_half,
41            packet_length: 0,
42            packet_offset: 0,
43            buffer: [0; BUFFER_SIZE],
44            buffer_size: 0,
45            buffer_offset: 0,
46        }
47    }
48
49    pub fn close(self) -> (OwnedReadHalf, Box<[u8]>) {
50        (
51            self.read_half,
52            self.buffer[self.buffer_offset..self.buffer_size].into(),
53        )
54    }
55
56    async fn read_next_bytes(&mut self) -> InputPacketBytesResult<()> {
57        match self.read_half.read(&mut self.buffer).await {
58            Ok(0) | Err(_) => Err(
59                CustomError::StaticStr("Connection was closed during reading").into()
60            ),
61            Ok(len) => {
62                self.buffer_size = len;
63                self.buffer_offset = 0;
64                log::debug!("Received bytes: {:?}", &self.buffer[0..self.buffer_size]);
65                Ok(())
66            }
67        }
68    }
69
70    async fn read_next_bytes_if_need(&mut self) -> InputPacketBytesResult<()> {
71        match self.buffer_offset == self.buffer_size {
72            true => self.read_next_bytes().await,
73            false => Ok(())
74        }
75    }
76
77    pub async fn next_packet(&mut self) -> PacketReadableResult<()> {
78        self.packet_length = 5; // maximum VarInt length
79        self.packet_offset = 0;
80        self.packet_length = <VarInt as PacketReadable>::read(self).await?.0 as usize;
81        self.packet_offset = 0;
82        Ok(())
83    }
84
85    async unsafe fn copy_into(&mut self, mut dst: SlicePointer, count: usize) -> InputPacketBytesResult<()> {
86        let mut offset: usize = 0;
87        loop {
88            let can_copy = self.buffer_size - self.buffer_offset;
89            let need_copy = count - offset;
90            match need_copy > can_copy {
91                true => {
92                    std::ptr::copy_nonoverlapping(
93                        self.buffer.as_ptr().add(self.buffer_offset), dst.ptr, can_copy,
94                    );
95                    dst.ptr = dst.ptr.add(can_copy);
96                    offset += can_copy;
97                    self.read_next_bytes().await?
98                }
99                false => {
100                    std::ptr::copy_nonoverlapping(
101                        self.buffer.as_ptr().add(self.buffer_offset), dst.ptr, need_copy,
102                    );
103                    self.buffer_offset += need_copy;
104                    break Ok(());
105                }
106            }
107        }
108    }
109}
110
111#[async_trait::async_trait]
112impl<const BUFFER_SIZE: usize> InputPacketBytes for ReadStreamQueue<BUFFER_SIZE> {
113    async fn take_byte(&mut self) -> InputPacketBytesResult<u8> {
114        match self.packet_offset == self.packet_length {
115            true => Err(InputPacketBytesError::NoBytes(self.packet_length)),
116            false => {
117                self.read_next_bytes_if_need().await?;
118                let byte = self.buffer[self.buffer_offset];
119                self.buffer_offset += 1;
120                self.packet_offset += 1;
121                Ok(byte)
122            }
123        }
124    }
125
126    async fn take_slice(&mut self, slice: &mut [u8]) -> InputPacketBytesResult<()> {
127        match self.has_bytes(slice.len()) {
128            true => unsafe {
129                let slice_pointer = SlicePointer { ptr: slice.as_mut_ptr() };
130                self.copy_into(slice_pointer, slice.len()).await
131            },
132            false => Err(InputPacketBytesError::NoBytes(self.packet_length)),
133        }
134    }
135
136    async fn take_vec(&mut self, vec: &mut Vec<u8>, count: usize) -> InputPacketBytesResult<()> {
137        match self.has_bytes(count) {
138            true => unsafe {
139                vec.resize(count, 0);
140                let slice_pointer = SlicePointer { ptr: vec.as_mut_ptr() };
141                self.copy_into(slice_pointer, count).await
142            },
143            false => Err(InputPacketBytesError::NoBytes(self.packet_length))
144        }
145    }
146
147    fn has_bytes(&self, count: usize) -> bool {
148        self.remaining_bytes() >= count
149    }
150
151    fn remaining_bytes(&self) -> usize {
152        match self.packet_length > self.packet_offset {
153            true => self.packet_length - self.packet_offset,
154            false => 0
155        }
156    }
157}