cubic_protocol_server/
read.rs1use std::fs::read;
2use cubic_protocol::packet::{CustomError, InputPacketBytes, InputPacketBytesError, InputPacketBytesResult, PacketReadable, PacketReadableResult};
3use tokio::io::AsyncReadExt;
4use tokio::net::tcp::OwnedReadHalf;
5use cubic_protocol::types::VarInt;
6
/// Raw destination pointer for a byte copy that spans `.await` points.
///
/// A bare `*mut u8` is neither `Send` nor `Sync`; wrapping it here lets the
/// copy future be `Send`, which `#[async_trait]` requires by default
/// (presumably the reason this wrapper exists — confirm if that changes).
#[derive(Copy, Clone)]
struct SlicePointer {
    ptr: *mut u8,
}

// SAFETY: a `SlicePointer` is only built from a `&mut [u8]` or `&mut Vec<u8>`
// that the creating method keeps exclusively borrowed until the copy into it
// has finished, so no other owner can touch the pointee while it is in use.
unsafe impl Send for SlicePointer {}

// SAFETY: same reasoning as the `Send` impl above — the pointee is never
// accessed concurrently through a shared `SlicePointer`.
unsafe impl Sync for SlicePointer {}

16pub struct ReadStreamQueue<const BUFFER_SIZE: usize> {
17 read_half: OwnedReadHalf,
18 packet_length: usize,
19 packet_offset: usize,
20 buffer: [u8; BUFFER_SIZE],
21 buffer_size: usize,
22 buffer_offset: usize,
23}
24
25impl<const BUFFER_SIZE: usize> From<OwnedReadHalf> for ReadStreamQueue<BUFFER_SIZE> {
26 fn from(read_half: OwnedReadHalf) -> Self {
27 ReadStreamQueue::new(read_half)
28 }
29}
30
31impl<const BUFFER_SIZE: usize> From<ReadStreamQueue<BUFFER_SIZE>> for OwnedReadHalf {
32 fn from(queue: ReadStreamQueue<BUFFER_SIZE>) -> Self {
33 queue.read_half
34 }
35}
36
37impl<const BUFFER_SIZE: usize> ReadStreamQueue<BUFFER_SIZE> {
38 pub fn new(read_half: OwnedReadHalf) -> Self {
39 Self {
40 read_half,
41 packet_length: 0,
42 packet_offset: 0,
43 buffer: [0; BUFFER_SIZE],
44 buffer_size: 0,
45 buffer_offset: 0,
46 }
47 }
48
49 pub fn close(self) -> (OwnedReadHalf, Box<[u8]>) {
50 (
51 self.read_half,
52 self.buffer[self.buffer_offset..self.buffer_size].into(),
53 )
54 }
55
56 async fn read_next_bytes(&mut self) -> InputPacketBytesResult<()> {
57 match self.read_half.read(&mut self.buffer).await {
58 Ok(0) | Err(_) => Err(
59 CustomError::StaticStr("Connection was closed during reading").into()
60 ),
61 Ok(len) => {
62 self.buffer_size = len;
63 self.buffer_offset = 0;
64 log::debug!("Received bytes: {:?}", &self.buffer[0..self.buffer_size]);
65 Ok(())
66 }
67 }
68 }
69
70 async fn read_next_bytes_if_need(&mut self) -> InputPacketBytesResult<()> {
71 match self.buffer_offset == self.buffer_size {
72 true => self.read_next_bytes().await,
73 false => Ok(())
74 }
75 }
76
77 pub async fn next_packet(&mut self) -> PacketReadableResult<()> {
78 self.packet_length = 5; self.packet_offset = 0;
80 self.packet_length = <VarInt as PacketReadable>::read(self).await?.0 as usize;
81 self.packet_offset = 0;
82 Ok(())
83 }
84
85 async unsafe fn copy_into(&mut self, mut dst: SlicePointer, count: usize) -> InputPacketBytesResult<()> {
86 let mut offset: usize = 0;
87 loop {
88 let can_copy = self.buffer_size - self.buffer_offset;
89 let need_copy = count - offset;
90 match need_copy > can_copy {
91 true => {
92 std::ptr::copy_nonoverlapping(
93 self.buffer.as_ptr().add(self.buffer_offset), dst.ptr, can_copy,
94 );
95 dst.ptr = dst.ptr.add(can_copy);
96 offset += can_copy;
97 self.read_next_bytes().await?
98 }
99 false => {
100 std::ptr::copy_nonoverlapping(
101 self.buffer.as_ptr().add(self.buffer_offset), dst.ptr, need_copy,
102 );
103 self.buffer_offset += need_copy;
104 break Ok(());
105 }
106 }
107 }
108 }
109}
110
111#[async_trait::async_trait]
112impl<const BUFFER_SIZE: usize> InputPacketBytes for ReadStreamQueue<BUFFER_SIZE> {
113 async fn take_byte(&mut self) -> InputPacketBytesResult<u8> {
114 match self.packet_offset == self.packet_length {
115 true => Err(InputPacketBytesError::NoBytes(self.packet_length)),
116 false => {
117 self.read_next_bytes_if_need().await?;
118 let byte = self.buffer[self.buffer_offset];
119 self.buffer_offset += 1;
120 self.packet_offset += 1;
121 Ok(byte)
122 }
123 }
124 }
125
126 async fn take_slice(&mut self, slice: &mut [u8]) -> InputPacketBytesResult<()> {
127 match self.has_bytes(slice.len()) {
128 true => unsafe {
129 let slice_pointer = SlicePointer { ptr: slice.as_mut_ptr() };
130 self.copy_into(slice_pointer, slice.len()).await
131 },
132 false => Err(InputPacketBytesError::NoBytes(self.packet_length)),
133 }
134 }
135
136 async fn take_vec(&mut self, vec: &mut Vec<u8>, count: usize) -> InputPacketBytesResult<()> {
137 match self.has_bytes(count) {
138 true => unsafe {
139 vec.resize(count, 0);
140 let slice_pointer = SlicePointer { ptr: vec.as_mut_ptr() };
141 self.copy_into(slice_pointer, count).await
142 },
143 false => Err(InputPacketBytesError::NoBytes(self.packet_length))
144 }
145 }
146
147 fn has_bytes(&self, count: usize) -> bool {
148 self.remaining_bytes() >= count
149 }
150
151 fn remaining_bytes(&self) -> usize {
152 match self.packet_length > self.packet_offset {
153 true => self.packet_length - self.packet_offset,
154 false => 0
155 }
156 }
157}