1use bytes::Bytes;
2use futures::{SinkExt, Stream, StreamExt};
3use skrillax_codec::{SilkroadCodec, SilkroadFrame};
4use skrillax_packet::{
5 AsFrames, FramingError, FromFrames, IncomingPacket, OutgoingPacket, Packet, PacketError,
6 ReframingError, SecurityBytes, SecurityContext, TryFromPacket,
7};
8use skrillax_security::SilkroadEncryption;
9use std::io;
10use std::sync::Arc;
11use thiserror::Error;
12use tokio::io::{AsyncRead, AsyncWrite};
13use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
14use tokio::net::TcpStream;
15use tokio_util::codec::{FramedRead, FramedWrite};
16
17#[derive(Debug, Error)]
22pub enum OutStreamError {
23 #[error("Some IO level error occurred")]
26 IoError(#[from] io::Error),
27 #[error("Error occurred when trying to create frames")]
31 Framing(#[from] FramingError),
32}
33
34#[derive(Debug, Error)]
39pub enum InStreamError {
40 #[error("Some IO level error occurred")]
45 IoError(#[from] io::Error),
46 #[error("Error occurred at the packet level")]
47 PacketError(#[from] PacketError),
48 #[error("Error when trying to turn frames into packets")]
49 ReframingError(#[from] ReframingError),
50 #[error("Reached the end of the stream")]
52 EndOfStream,
53 #[error("Received unexpected opcode: {0:#06x}")]
56 UnmatchedOpcode(u16),
57}
58
59pub trait InputProtocol {
61 type Proto: Send;
62
63 fn create_from(opcode: u16, data: &[u8]) -> Result<(usize, Self::Proto), InStreamError>;
64}
65
66impl<T: TryFromPacket + Packet + Send> InputProtocol for T {
67 type Proto = T;
68
69 fn create_from(opcode: u16, data: &[u8]) -> Result<(usize, T), InStreamError> {
70 if opcode != T::ID {
71 return Err(InStreamError::UnmatchedOpcode(opcode));
72 }
73
74 Ok(T::try_deserialize(data)?)
75 }
76}
77
78pub trait SilkroadTcpExt {
81 fn into_silkroad_stream(
96 self,
97 ) -> (
98 SilkroadStreamRead<OwnedReadHalf>,
99 SilkroadStreamWrite<OwnedWriteHalf>,
100 );
101}
102
103impl SilkroadTcpExt for TcpStream {
104 fn into_silkroad_stream(
105 self,
106 ) -> (
107 SilkroadStreamRead<OwnedReadHalf>,
108 SilkroadStreamWrite<OwnedWriteHalf>,
109 ) {
110 let (read, write) = self.into_split();
111 let reader = FramedRead::new(read, SilkroadCodec);
112 let writer = FramedWrite::new(write, SilkroadCodec);
113
114 let stream_reader = SilkroadStreamRead::new(reader);
115 let stream_writer = SilkroadStreamWrite::new(writer);
116
117 (stream_reader, stream_writer)
118 }
119}
120
121pub struct SilkroadStreamWrite<T: AsyncWrite + Unpin> {
126 writer: FramedWrite<T, SilkroadCodec>,
127 encryption: Option<Arc<SilkroadEncryption>>,
128 security_bytes: Option<Arc<SecurityBytes>>,
129}
130
131impl<T: AsyncWrite + Unpin> SilkroadStreamWrite<T> {
132 fn new(writer: FramedWrite<T, SilkroadCodec>) -> Self {
133 Self {
134 writer,
135 encryption: None,
136 security_bytes: None,
137 }
138 }
139
140 #[allow(unused)]
141 fn with_encryption(
142 writer: FramedWrite<T, SilkroadCodec>,
143 encryption: Arc<SilkroadEncryption>,
144 security_bytes: Arc<SecurityBytes>,
145 ) -> Self {
146 Self {
147 writer,
148 encryption: Some(encryption),
149 security_bytes: Some(security_bytes),
150 }
151 }
152
153 pub fn enable_encryption(&mut self, encryption: Arc<SilkroadEncryption>) {
154 self.encryption = Some(encryption);
155 }
156
157 pub fn enable_security_checks(&mut self, security_bytes: Arc<SecurityBytes>) {
158 self.security_bytes = Some(security_bytes);
159 }
160
161 pub fn encryption(&self) -> Option<&SilkroadEncryption> {
162 self.encryption.as_deref()
163 }
164
165 pub fn security_bytes(&self) -> Option<&SecurityBytes> {
166 self.security_bytes.as_deref()
167 }
168
169 pub fn security_context(&self) -> SecurityContext {
170 SecurityContext::new(self.encryption(), self.security_bytes())
171 }
172
173 pub async fn write(&mut self, packet: OutgoingPacket) -> Result<(), OutStreamError> {
174 let frames = packet.as_frames(self.security_context())?;
175 for frame in frames {
176 self.writer.send(frame).await?;
177 }
178 Ok(())
179 }
180
181 pub async fn write_packet<S: Into<OutgoingPacket>>(
182 &mut self,
183 packet: S,
184 ) -> Result<(), OutStreamError> {
185 let outgoing_packet = packet.into();
186 self.write(outgoing_packet).await
187 }
188}
189
190pub struct SilkroadStreamRead<T: AsyncRead + Unpin> {
195 reader: FramedRead<T, SilkroadCodec>,
196 encryption: Option<Arc<SilkroadEncryption>>,
197 security_bytes: Option<Arc<SecurityBytes>>,
198 unconsumed: Option<(u16, Bytes)>,
199}
200
201impl<T: AsyncRead + Unpin> SilkroadStreamRead<T>
202where
203 FramedRead<T, SilkroadCodec>: Stream<Item = Result<SilkroadFrame, io::Error>>,
204{
205 fn new(reader: FramedRead<T, SilkroadCodec>) -> Self {
206 Self {
207 reader,
208 encryption: None,
209 security_bytes: None,
210 unconsumed: None,
211 }
212 }
213
214 #[allow(unused)]
215 fn with_encryption(
216 reader: FramedRead<T, SilkroadCodec>,
217 encryption: Arc<SilkroadEncryption>,
218 security_bytes: Arc<SecurityBytes>,
219 ) -> Self {
220 Self {
221 reader,
222 encryption: Some(encryption),
223 security_bytes: Some(security_bytes),
224 unconsumed: None,
225 }
226 }
227
228 pub fn enable_encryption(&mut self, encryption: Arc<SilkroadEncryption>) {
238 self.encryption = Some(encryption);
239 }
240
241 pub fn enable_security_checks(&mut self, security_bytes: Arc<SecurityBytes>) {
250 self.security_bytes = Some(security_bytes);
251 }
252
253 pub fn encryption(&self) -> Option<&SilkroadEncryption> {
255 self.encryption.as_deref()
256 }
257
258 pub fn security_bytes(&self) -> Option<&SecurityBytes> {
260 self.security_bytes.as_deref()
261 }
262
263 pub fn security_context(&self) -> SecurityContext {
271 SecurityContext::new(self.encryption(), self.security_bytes())
272 }
273
274 pub async fn next(&mut self) -> Result<IncomingPacket, InStreamError> {
290 let mut buffer = Vec::new();
291 let mut remaining = 1usize;
292 while let Some(res) = self.reader.next().await {
293 let frame = res?;
294 buffer.push(frame);
295 remaining -= 1;
296 if remaining == 0 {
297 match IncomingPacket::from_frames(&buffer, self.security_context()) {
298 Ok(packet) => return Ok(packet),
299 Err(ReframingError::Incomplete(required)) => {
300 remaining += required.unwrap_or(1);
301 },
302 Err(e) => return Err(InStreamError::ReframingError(e)),
303 }
304 }
305 }
306
307 Err(InStreamError::EndOfStream)
308 }
309
310 pub async fn next_packet<S: InputProtocol>(&mut self) -> Result<S::Proto, InStreamError> {
322 let (opcode, mut buffer) = match self.unconsumed.take() {
323 Some(inner) => inner,
324 _ => self.next().await?.consume(),
325 };
326
327 let (consumed, p) = S::create_from(opcode, &buffer)?;
328 let _ = buffer.split_to(consumed);
329 if !buffer.is_empty() {
330 self.unconsumed = Some((opcode, buffer));
331 }
332
333 Ok(p)
334 }
335}
336
#[cfg(test)]
mod test {
    use super::*;
    use skrillax_serde::{ByteSize, Deserialize, Serialize};

    /// Zero-sized packet with opcode `0x0042`, used for the framing round trip.
    #[derive(Packet, Deserialize, Serialize, ByteSize)]
    #[packet(opcode = 0x0042)]
    struct Empty;

    /// The bytes an unencrypted [`Empty`] packet is framed as.
    const EMPTY_FRAME: [u8; 6] = [0x00, 0x00, 0x42, 0x00, 0x00, 0x00];

    #[tokio::test]
    pub async fn test_read_packet_from_stream() {
        let input: &[u8] = &EMPTY_FRAME;
        let mut reader = SilkroadStreamRead::new(FramedRead::new(input, SilkroadCodec));
        let _empty: Empty = reader
            .next_packet::<Empty>()
            .await
            .expect("Should read empty packet.");
    }

    #[tokio::test]
    pub async fn test_write_packet_to_stream() {
        let mut output: Vec<u8> = Vec::new();
        {
            // Scoped so the writer's mutable borrow of `output` ends before the assertion.
            let mut writer =
                SilkroadStreamWrite::new(FramedWrite::new(&mut output, SilkroadCodec));
            writer
                .write_packet(Empty)
                .await
                .expect("Should write empty packet.");
        }
        let expected: &[u8] = &EMPTY_FRAME;
        assert_eq!(expected, output.as_slice());
    }
}